
// org.deeplearning4j.aws.ec2.provision.HostProvisioner Maven / Gradle / Ivy
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.aws.ec2.provision;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.jcraft.jsch.Channel;
import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.ChannelSftp;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.Session;
import com.jcraft.jsch.UserInfo;
/**
 * Uploads files to a remote server over SFTP and runs commands on it over SSH,
 * using JSch.
 *
 * <p>Every operation opens its own SSH session and disconnects it when done.
 * The instance also acts as the JSch {@link UserInfo} callback: it supplies the
 * configured password for both password and passphrase prompts and answers
 * every prompt affirmatively.
 *
 * @author Adam Gibson
 */
public class HostProvisioner implements UserInfo {
    private static final Logger log = LoggerFactory.getLogger(HostProvisioner.class);

    /** Maximum number of attempts made to upload a single file before giving up. */
    private static final int MAX_UPLOAD_ATTEMPTS = 3;

    /** Pause between polls of a running remote command, so the wait loop does not spin. */
    private static final long POLL_INTERVAL_MILLIS = 100L;

    private final String host;
    private final JSch jsch;
    private final String user;
    private final int port;
    private final String password;

    /**
     * @param host host to connect to (public facing dns)
     * @param user the user to connect with
     * @param password the password to use if any
     * @param port the port to connect to
     */
    public HostProvisioner(String host, String user, String password, int port) {
        this.host = host;
        this.user = user;
        this.port = port;
        this.password = password;
        this.jsch = new JSch();
    }

    /**
     * Connects to port 22.
     *
     * @param host host to connect to (public facing dns)
     * @param user the user to connect with
     * @param password the password to use if any
     */
    public HostProvisioner(String host, String user, String password) {
        this(host, user, password, 22);
    }

    /**
     * Connects to port 22 with an empty password.
     *
     * @param host host to connect to (public facing dns)
     * @param user the user to connect with
     */
    public HostProvisioner(String host, String user) {
        this(host, user, "", 22);
    }

    /**
     * Connects to port 22, user root, with an empty password.
     *
     * @param host host to connect to (public facing dns)
     */
    public HostProvisioner(String host) {
        this(host, "root", "", 22);
    }

    /**
     * Uploads a local script, marks it executable on the remote host and runs it.
     *
     * @param script path of the local script file
     * @param rootDir remote directory to place the script in; when empty the
     *                script is uploaded under its own name with no directory prefix
     * @throws Exception if the upload or the remote execution fails
     */
    public void uploadAndRun(String script, String rootDir) throws Exception {
        File scriptFile = new File(script);
        String remoteName = rootDir.isEmpty()
                ? scriptFile.getName()
                : rootDir + "/" + scriptFile.getName();
        upload(scriptFile, remoteName);

        // A relative path needs the "./" prefix so the shell does not search PATH for it.
        // startsWith (rather than charAt(0)) also tolerates an empty name.
        String executable = remoteName.startsWith("/") ? remoteName : "./" + remoteName;
        runRemoteCommand("chmod +x " + executable + " && " + executable);
    }

    /**
     * Runs a command on the remote host and logs its output and exit status.
     *
     * <p>A pty is requested for the command. Output is read in chunks and written
     * to the log at info level; the error stream is directed to {@code System.err}.
     *
     * @param remoteCommand the command line to execute remotely
     * @throws Exception if the connection or the execution fails, or if the
     *                   calling thread is interrupted while waiting
     */
    public void runRemoteCommand(String remoteCommand) throws Exception {
        Session session = getSession();
        ChannelExec channel = null;
        try {
            session.connect();
            channel = (ChannelExec) session.openChannel("exec");
            channel.setCommand(remoteCommand);
            channel.setErrStream(System.err);
            channel.setPty(true);

            // JSch requires the input stream to be obtained before connect(); doing it
            // afterwards can lose early output. getInputStream() also replaces any
            // previously configured output sink, so no setOutputStream call is made.
            InputStream input = channel.getInputStream();

            // connect() issues the exec request itself; an extra start() call would
            // send the request a second time.
            channel.connect();

            byte[] buffer = new byte[60000];
            while (true) {
                while (input.available() > 0) {
                    int read = input.read(buffer, 0, buffer.length);
                    if (read < 0)
                        break;
                    log.info(new String(buffer, 0, read));
                }
                if (channel.isClosed()) {
                    // Output may have arrived between the drain above and the close check.
                    if (input.available() > 0)
                        continue;
                    log.info("exit-status: " + channel.getExitStatus());
                    break;
                }
                Thread.sleep(POLL_INTERVAL_MILLIS);
            }
        } finally {
            if (channel != null)
                channel.disconnect();
            session.disconnect();
        }
    }

    /** Creates a new, not yet connected session that uses this object for credentials. */
    private Session getSession() throws Exception {
        Session session = jsch.getSession(user, host, port);
        session.setUserInfo(this);
        return session;
    }

    /** Connects the given session and opens a connected SFTP channel on it. */
    private ChannelSftp openSftp(Session session) throws Exception {
        session.connect();
        ChannelSftp sftp = (ChannelSftp) session.openChannel("sftp");
        sftp.connect();
        return sftp;
    }

    /**
     * Creates the remote directory if necessary and uploads the file(s).
     *
     * <p>When {@code from} is a local directory, {@code to} is the remote directory:
     * it is created if it does not exist and the regular files directly inside
     * {@code from} are uploaded into it (sub-directories are skipped). When
     * {@code from} is a single file, {@code to} is used as the remote file path;
     * if {@code to} is empty the file is uploaded under its own name. A path
     * ending in {@code .tar} or {@code .tar.gz} is additionally extracted remotely.
     *
     * @param from the local file or directory to upload
     * @param to the destination on the remote server (may be empty)
     * @throws Exception if creating the directory or uploading fails
     */
    public void uploadForDeployment(String from, String to) throws Exception {
        File fromFile = new File(from);
        if (!to.isEmpty() && fromFile.isDirectory())
            mkDir(to);
        // Previously the upload was skipped whenever the directory was created.
        upload(from, to);
    }

    /**
     * Registers a private key file with JSch for public-key authentication.
     *
     * @param keyFile path of the private key file
     * @throws Exception if JSch cannot load the key
     */
    public void addKeyFile(String keyFile) throws Exception {
        jsch.addIdentity(keyFile);
    }

    /** Creates the remote directory to upload to, unless it already exists. */
    private void mkDir(String dir) throws Exception {
        Session session = getSession();
        try {
            ChannelSftp sftp = openSftp(session);
            if (!fileExists(dir, sftp))
                sftp.mkdir(dir);
            sftp.exit();
        } finally {
            session.disconnect();
        }
    }

    /**
     * Reports whether the remote path can be stat'ed. Any failure of the stat call
     * (not only "no such file") is reported as {@code false}.
     */
    private boolean fileExists(String dir, ChannelSftp channel) {
        try {
            channel.stat(dir);
            return true;
        } catch (Exception e) {
            return false;
        }
    }

    /**
     * Uploads a single file, a tar archive (which is then extracted remotely), or
     * the regular files listed directly in a directory.
     */
    private void upload(String fileOrDir, String uploadRootDir) throws Exception {
        File origin = new File(fileOrDir);
        boolean archive = fileOrDir.endsWith(".tar") || fileOrDir.endsWith(".tar.gz");
        if (archive || origin.isFile()) {
            // With no explicit target, fall back to the file's own name; "." is a
            // directory and cannot be used as the remote file path.
            String remoteFile = uploadRootDir.isEmpty() ? origin.getName() : uploadRootDir;
            upload(origin, remoteFile);
            if (archive)
                untar(remoteFile);
        } else {
            String remoteDir = uploadRootDir.isEmpty() ? "." : uploadRootDir;
            File[] childFiles = origin.listFiles();
            if (childFiles != null)
                upload(Arrays.asList(childFiles), remoteDir);
        }
    }

    /** Extracts the given remote archive by running {@code tar xvf} on the remote host. */
    private void untar(String targetRemoteFile) throws Exception {
        this.runRemoteCommand("tar xvf " + targetRemoteFile);
    }

    /**
     * Uploads every regular file of the collection into the remote directory over
     * a single SFTP session. Directories in the collection are skipped.
     */
    private void upload(Collection<File> files, String rootDir) throws Exception {
        Session session = getSession();
        try {
            ChannelSftp sftp = openSftp(session);
            for (File f : files) {
                if (f.isDirectory()) {
                    log.warn("Skipping " + f.getName());
                    continue;
                }
                log.info("Uploading " + f.getName());
                BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f));
                try {
                    sftp.put(bis, rootDir + "/" + f.getName());
                } finally {
                    bis.close();
                }
            }
            sftp.disconnect();
        } finally {
            session.disconnect();
        }
    }

    /**
     * Uploads one local file to the given remote file path, replacing an existing
     * remote file. Makes up to {@link #MAX_UPLOAD_ATTEMPTS} attempts, each on a
     * fresh session, and rethrows the last failure if all of them fail.
     *
     * @throws IllegalArgumentException if {@code f} is not an existing regular file
     */
    private void upload(File f, String remoteFile) throws Exception {
        // Fail fast: retrying cannot fix a missing local file or a directory.
        if (!f.isFile())
            throw new IllegalArgumentException("Not a regular file: " + f);

        Exception lastFailure = null;
        for (int attempt = 1; attempt <= MAX_UPLOAD_ATTEMPTS; attempt++) {
            Session session = getSession();
            try {
                ChannelSftp sftp = openSftp(session);
                BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f));
                try {
                    if (fileExists(remoteFile, sftp))
                        sftp.rm(remoteFile);
                    sftp.put(bis, remoteFile);
                } finally {
                    bis.close();
                }
                sftp.exit();
                return;
            } catch (Exception e) {
                lastFailure = e;
                log.warn("Upload of " + f + " to " + remoteFile + " failed (attempt "
                        + attempt + " of " + MAX_UPLOAD_ATTEMPTS + ")", e);
            } finally {
                session.disconnect();
            }
        }
        throw lastFailure;
    }

    /** Returns the configured password, which doubles as the key passphrase. */
    @Override
    public String getPassphrase() {
        return this.password;
    }

    /** Returns the configured password. */
    @Override
    public String getPassword() {
        return this.password;
    }

    /** Always reports that a passphrase is available. */
    @Override
    public boolean promptPassphrase(String arg0) {
        return true;
    }

    /** Always reports that a password is available. */
    @Override
    public boolean promptPassword(String arg0) {
        return true;
    }

    /**
     * Answers "yes" to every yes/no prompt.
     *
     * <p>NOTE(review): JSch uses this prompt to confirm unknown host keys, so this
     * presumably accepts any host key without verification. Confirm that is acceptable.
     */
    @Override
    public boolean promptYesNo(String arg0) {
        return true;
    }

    /** Writes messages from JSch to the log at info level. */
    @Override
    public void showMessage(String arg0) {
        log.info(arg0);
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy