All Downloads are FREE. Search and download functionalities are using the official Maven repository.
Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price: only $1.
You can buy this project and download/modify it as often as you want.
org.deeplearning4j.aws.ec2.provision.HostProvisioner Maven / Gradle / Ivy
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.ec2.provision;
import com.jcraft.jsch.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collection;
/**
 * Uploads files to a remote server over SFTP and runs commands on it over SSH,
 * using JSch. The instance also acts as the JSch {@link UserInfo}, answering
 * every prompt non-interactively with the configured password.
 *
 * @author Adam Gibson
 */
public class HostProvisioner implements UserInfo {

    private static final Logger log = LoggerFactory.getLogger(HostProvisioner.class);

    /** Maximum number of attempts made to upload a single file before giving up. */
    private static final int MAX_UPLOAD_ATTEMPTS = 3;

    /** Pause between two polls of a running remote command, in milliseconds. */
    private static final long POLL_INTERVAL_MILLIS = 100L;

    private final String host;
    private final JSch jsch;
    private final String user;
    private final int port;
    private final String password;

    /**
     * @param host host to connect to (public facing dns)
     * @param user the user to connect with
     * @param password the password to use if any
     * @param port the port to connect to
     */
    public HostProvisioner(String host, String user, String password, int port) {
        this.host = host;
        this.user = user;
        this.port = port;
        this.password = password;
        this.jsch = new JSch();
    }

    /**
     * Connects to port 22.
     *
     * @param host host to connect to (public facing dns)
     * @param user the user to connect with
     * @param password the password to use if any
     */
    public HostProvisioner(String host, String user, String password) {
        this(host, user, password, 22);
    }

    /**
     * Connects to port 22 with an empty password.
     *
     * @param host host to connect to (public facing dns)
     * @param user the user to connect with
     */
    public HostProvisioner(String host, String user) {
        this(host, user, "", 22);
    }

    /**
     * Connects to port 22, user root, with an empty password.
     *
     * @param host host to connect to (public facing dns)
     */
    public HostProvisioner(String host) {
        this(host, "root", "", 22);
    }

    /**
     * Uploads a local script, marks it executable on the remote host and runs it.
     *
     * @param script path of the local script file
     * @param rootDir remote directory to place the script in; empty means the
     *                script is uploaded under its bare file name
     * @throws Exception if the upload or the remote command fails
     */
    public void uploadAndRun(String script, String rootDir) throws Exception {
        File scriptFile = new File(script);
        String remoteName = rootDir.isEmpty() ? scriptFile.getName() : rootDir + "/" + scriptFile.getName();
        upload(scriptFile, remoteName);
        // A relative path needs the "./" prefix so the shell executes the file
        // instead of searching for it on the PATH.
        String remoteCommand = remoteName.startsWith("/") ? remoteName : "./" + remoteName;
        runRemoteCommand("chmod +x " + remoteCommand + " && " + remoteCommand);
    }

    /**
     * Runs a command on the remote host, logging its output as it arrives and
     * its exit status once the channel is closed. Remote stderr is forwarded to
     * {@link System#err}.
     *
     * @param remoteCommand the command line to execute remotely
     * @throws Exception if connecting or reading the output fails, or if the
     *                   calling thread is interrupted while waiting
     */
    public void runRemoteCommand(String remoteCommand) throws Exception {
        Session session = getSession();
        try {
            session.connect();
            ChannelExec channel = (ChannelExec) session.openChannel("exec");
            try {
                channel.setCommand(remoteCommand);
                // dontclose=true: otherwise JSch closes System.err when the channel disconnects.
                channel.setErrStream(System.err, true);
                channel.setPty(true);
                // JSch requires the input stream to be obtained before connect();
                // connect() itself issues the exec request, so no explicit start() is needed.
                InputStream input = channel.getInputStream();
                channel.connect();

                byte[] buffer = new byte[60000];
                while (true) {
                    while (input.available() > 0) {
                        int read = input.read(buffer, 0, buffer.length);
                        if (read < 0) {
                            break;
                        }
                        log.info(new String(buffer, 0, read));
                    }
                    if (channel.isClosed()) {
                        // Output may have arrived between the drain loop and the close check.
                        if (input.available() > 0) {
                            continue;
                        }
                        log.info("exit-status: " + channel.getExitStatus());
                        break;
                    }
                    // Avoid spinning a core while the remote command is still running.
                    Thread.sleep(POLL_INTERVAL_MILLIS);
                }
            } finally {
                channel.disconnect();
            }
        } finally {
            session.disconnect();
        }
    }

    /**
     * Creates a new, not yet connected session for the configured user, host
     * and port, with this instance registered as its {@link UserInfo}.
     */
    private Session getSession() throws Exception {
        Session session = jsch.getSession(user, host, port);
        session.setUserInfo(this);
        return session;
    }

    /**
     * Creates the remote directory if necessary and uploads the file or the
     * files of the directory.
     *
     * @param from the local file or directory to upload from
     * @param to the destination on the remote server
     * @throws Exception if creating the directory or uploading fails
     */
    public void uploadForDeployment(String from, String to) throws Exception {
        File fromFile = new File(from);
        if (!to.isEmpty() && fromFile.isDirectory()) {
            mkDir(to);
        }
        // The upload always runs; previously a directory with a non-empty target
        // only got the remote directory created and nothing was transferred.
        upload(from, to);
    }

    /**
     * Registers a private key file to authenticate with.
     *
     * @param keyFile path of the private key file
     * @throws Exception if JSch rejects the key file
     */
    public void addKeyFile(String keyFile) throws Exception {
        jsch.addIdentity(keyFile);
    }

    /** Creates the remote directory to upload to, unless it already exists. */
    private void mkDir(String dir) throws Exception {
        Session session = getSession();
        try {
            session.connect();
            ChannelSftp sftp = (ChannelSftp) session.openChannel("sftp");
            try {
                sftp.connect();
                if (!fileExists(dir, sftp)) {
                    sftp.mkdir(dir);
                }
            } finally {
                sftp.exit();
            }
        } finally {
            session.disconnect();
        }
    }

    /**
     * Returns whether {@code stat} succeeds for the remote path. Any failure is
     * reported as "does not exist".
     */
    private boolean fileExists(String dir, ChannelSftp channel) {
        try {
            channel.stat(dir);
            return true;
        } catch (Exception e) {
            return false;
        }
    }

    /**
     * Uploads a file, a tar archive (which is then extracted remotely) or the
     * direct child files of a directory.
     *
     * NOTE(review): for single files and archives {@code uploadRootDir} is
     * passed on as the full remote file path, not as a directory - confirm
     * this is what callers expect.
     */
    private void upload(String fileOrDir, String uploadRootDir) throws Exception {
        if (uploadRootDir.isEmpty()) {
            uploadRootDir = ".";
        }
        File origin = new File(fileOrDir);
        if (fileOrDir.endsWith(".tar") || fileOrDir.endsWith(".tar.gz")) {
            upload(origin, uploadRootDir);
            untar(uploadRootDir);
        } else if (origin.isFile()) {
            upload(origin, uploadRootDir);
        } else {
            // listFiles() returns null when origin is not a readable directory.
            File[] childFiles = origin.listFiles();
            if (childFiles != null) {
                upload(Arrays.asList(childFiles), uploadRootDir);
            }
        }
    }

    /** Extracts the given remote archive by running {@code tar xvf} on the remote host. */
    private void untar(String targetRemoteFile) throws Exception {
        this.runRemoteCommand("tar xvf " + targetRemoteFile);
    }

    /**
     * Uploads every regular file of the collection into the remote directory
     * over a single SFTP session. Directories are skipped with a warning.
     */
    private void upload(Collection<File> files, String rootDir) throws Exception {
        Session session = getSession();
        try {
            session.connect();
            ChannelSftp sftp = (ChannelSftp) session.openChannel("sftp");
            try {
                sftp.connect();
                for (File f : files) {
                    if (f.isDirectory()) {
                        log.warn("Skipping " + f.getName());
                        continue;
                    }
                    log.info("Uploading " + f.getName());
                    try (InputStream in = new BufferedInputStream(new FileInputStream(f))) {
                        sftp.put(in, rootDir + "/" + f.getName());
                    }
                }
            } finally {
                sftp.disconnect();
            }
        } finally {
            session.disconnect();
        }
    }

    /**
     * Uploads a single file to the given remote path, replacing an existing
     * remote entry. Each attempt uses a fresh session; after
     * {@link #MAX_UPLOAD_ATTEMPTS} failed attempts the last failure is rethrown
     * instead of retrying forever.
     */
    private void upload(File f, String remoteFile) throws Exception {
        for (int attempt = 1; attempt <= MAX_UPLOAD_ATTEMPTS; attempt++) {
            Session session = getSession();
            try {
                session.connect();
                ChannelSftp sftp = (ChannelSftp) session.openChannel("sftp");
                try {
                    sftp.connect();
                    try (InputStream in = new BufferedInputStream(new FileInputStream(f))) {
                        if (fileExists(remoteFile, sftp)) {
                            if (f.isDirectory()) {
                                sftp.rmdir(remoteFile);
                            } else {
                                sftp.rm(remoteFile);
                            }
                        }
                        sftp.put(in, remoteFile);
                    }
                    return;
                } finally {
                    sftp.exit();
                }
            } catch (Exception e) {
                if (attempt == MAX_UPLOAD_ATTEMPTS) {
                    throw e;
                }
                log.info("Session was down...trying again", e);
            } finally {
                session.disconnect();
            }
        }
    }

    /** Returns the configured password, which doubles as the key passphrase. */
    @Override
    public String getPassphrase() {
        return this.password;
    }

    /** Returns the configured password. */
    @Override
    public String getPassword() {
        return this.password;
    }

    /** Always answers {@code true}; no interactive prompt is shown. */
    @Override
    public boolean promptPassphrase(String arg0) {
        return true;
    }

    /** Always answers {@code true}; no interactive prompt is shown. */
    @Override
    public boolean promptPassword(String arg0) {
        return true;
    }

    /**
     * Always answers yes.
     * NOTE(review): JSch presumably routes unknown-host-key confirmations
     * through this method, which would make every host key accepted - confirm.
     */
    @Override
    public boolean promptYesNo(String arg0) {
        return true;
    }

    /** Logs the message at info level. */
    @Override
    public void showMessage(String arg0) {
        log.info(arg0);
    }
}