// org.deeplearning4j.aws.emr.SparkEMRClient
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.emr;
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce;
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder;
import com.amazonaws.services.elasticmapreduce.model.*;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import com.amazonaws.services.s3.AmazonS3URI;
import com.amazonaws.services.s3.model.PutObjectRequest;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.RandomStringUtils;
import org.apache.spark.api.java.function.Function;
import java.io.File;
import java.util.*;
/**
 * Configuration for a Spark EMR cluster.
 * <p>
 * Instances are created through the nested {@link Builder}; the private
 * all-args constructor below is what the builder invokes.
 */
@Data
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@NoArgsConstructor
@Slf4j
public class SparkEMRClient {
    // Random default name so independently-built clients don't collide on cluster name
    protected String sparkClusterName = RandomStringUtils.randomAlphanumeric(12);
    protected String sparkAwsRegion = "us-east-1";
    protected String sparkEmrRelease = "emr-5.9.0";
    protected String sparkEmrServiceRole = "EMR_DefaultRole";
    protected List<EmrConfig> sparkEmrConfigs = Collections.emptyList();
    protected String sparkSubnetId = null;
    protected List<String> sparkSecurityGroupIds = Collections.emptyList();
    protected int sparkInstanceCount = 1;
    protected String sparkInstanceType = "m3.xlarge";
    // Bid price for spot instances; empty means on-demand market
    protected Optional<Float> sparkInstanceBidPrice = Optional.empty();
    protected String sparkInstanceRole = "EMR_EC2_DefaultRole";
    // S3 folder the uber-jar is uploaded to; "changeme" is a sentinel default
    protected String sparkS3JarFolder = "changeme";
    protected int sparkTimeoutDurationMinutes = 90;

    //underlying configs
    protected AmazonElasticMapReduceClientBuilder sparkEmrClientBuilder;
    protected AmazonS3ClientBuilder sparkS3ClientBuilder;
    protected JobFlowInstancesConfig sparkJobFlowInstancesConfig;
    protected RunJobFlowRequest sparkRunJobFlowRequest;
    // Hook allowing callers to decorate the jar-upload request (e.g. add metadata/ACLs)
    protected Function<PutObjectRequest, PutObjectRequest> sparkS3PutObjectDecorator;
    protected Map<String, String> sparkSubmitConfs;

    // Cluster states considered "active" when listing or looking up clusters
    private static ClusterState[] activeClusterStates = new ClusterState[]{
            ClusterState.RUNNING,
            ClusterState.STARTING,
            ClusterState.WAITING,
            ClusterState.BOOTSTRAPPING};
private Optional findClusterWithName(AmazonElasticMapReduce emr, String name) {
List csrl = emr.listClusters((new ListClustersRequest()).withClusterStates(activeClusterStates)).getClusters();
for (ClusterSummary csr : csrl) {
if (csr.getName().equals(name)) return Optional.of(csr);
}
return Optional.empty();
}
/**
* Creates the current cluster
*/
public void createCluster() {
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
Optional csr = findClusterWithName(emr, sparkClusterName);
if (csr.isPresent()) {
String msg = String.format("A cluster with name %s and id %s is already deployed", sparkClusterName, csr.get().getId());
log.error(msg);
throw new IllegalStateException(msg);
} else {
RunJobFlowResult res = emr.runJobFlow(sparkRunJobFlowRequest);
String msg = String.format("Your cluster is launched with name %s and id %s.", sparkClusterName, res.getJobFlowId());
log.info(msg);
}
}
private void logClusters(List csrl) {
if (csrl.isEmpty()) log.info("No cluster found.");
else {
log.info(String.format("%d clusters found.", csrl.size()));
for (ClusterSummary csr : csrl) {
log.info(String.format("Name: %s | Id: %s", csr.getName(), csr.getId()));
}
}
}
/**
* Lists existing active clusters Names
*
* @return cluster names
*/
public List listActiveClusterNames() {
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
List csrl =
emr.listClusters(new ListClustersRequest().withClusterStates(activeClusterStates)).getClusters();
logClusters(csrl);
List res = new ArrayList<>(csrl.size());
for (ClusterSummary csr : csrl) res.add(csr.getName());
return res;
}
/**
* List existing active cluster IDs
*
* @return cluster IDs
*/
public List listActiveClusterIds() {
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
List csrl =
emr.listClusters(new ListClustersRequest().withClusterStates(activeClusterStates)).getClusters();
logClusters(csrl);
List res = new ArrayList<>(csrl.size());
for (ClusterSummary csr : csrl) res.add(csr.getId());
return res;
}
/**
* Terminates a cluster
*/
public void terminateCluster() {
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
Optional optClusterSum = findClusterWithName(emr, sparkClusterName);
if (!optClusterSum.isPresent()) {
log.error(String.format("The cluster with name %s , requested for deletion, does not exist.", sparkClusterName));
} else {
String id = optClusterSum.get().getId();
emr.terminateJobFlows((new TerminateJobFlowsRequest()).withJobFlowIds(id));
log.info(String.format("The cluster with id %s is terminating.", id));
}
}
// The actual job-sumission logic
private void submitJob(AmazonElasticMapReduce emr, String mainClass, List args, Map sparkConfs, File uberJar) throws Exception {
AmazonS3URI s3Jar = new AmazonS3URI(sparkS3JarFolder + "/" + uberJar.getName());
log.info(String.format("Placing uberJar %s to %s", uberJar.getPath(), s3Jar.toString()));
PutObjectRequest putRequest = sparkS3PutObjectDecorator.call(
new PutObjectRequest(s3Jar.getBucket(), s3Jar.getKey(), uberJar)
);
sparkS3ClientBuilder.build().putObject(putRequest);
// The order of these matters
List sparkSubmitArgs = Arrays.asList(
"spark-submit",
"--deploy-mode",
"cluster",
"--class",
mainClass
);
for (Map.Entry e : sparkConfs.entrySet()) {
sparkSubmitArgs.add(String.format("--conf %s = %s ", e.getKey(), e.getValue()));
}
sparkSubmitArgs.add(s3Jar.toString());
sparkSubmitArgs.addAll(args);
StepConfig step = new StepConfig()
.withActionOnFailure(ActionOnFailure.CONTINUE)
.withName("Spark step")
.withHadoopJarStep(
new HadoopJarStepConfig()
.withJar("command-runner.jar")
.withArgs(sparkSubmitArgs)
);
Optional optCsr = findClusterWithName(emr, sparkClusterName);
if (optCsr.isPresent()) {
ClusterSummary csr = optCsr.get();
emr.addJobFlowSteps(
new AddJobFlowStepsRequest()
.withJobFlowId(csr.getId())
.withSteps(step));
log.info(
String.format("Your job is added to the cluster with id %s.", csr.getId())
);
} else {
// If the cluster wasn't started, it's assumed ot be throwaway
List steps = sparkRunJobFlowRequest.getSteps();
steps.add(step);
RunJobFlowRequest jobFlowRequest = sparkRunJobFlowRequest
.withSteps(steps)
.withInstances(sparkJobFlowInstancesConfig.withKeepJobFlowAliveWhenNoSteps(false));
RunJobFlowResult res = emr.runJobFlow(jobFlowRequest);
log.info("Your new cluster's id is %s.", res.getJobFlowId());
}
}
/**
* Submit a Spark Job with a specified main class
*/
public void sparkSubmitJobWithMain(String[] args, String mainClass, File uberJar) throws Exception {
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
submitJob(emr, mainClass, Arrays.asList(args), sparkSubmitConfs, uberJar);
}
private void checkStatus(AmazonElasticMapReduce emr, String clusterId) throws InterruptedException {
log.info(".");
com.amazonaws.services.elasticmapreduce.model.Cluster dcr =
emr.describeCluster((new DescribeClusterRequest()).withClusterId(clusterId)).getCluster();
String state = dcr.getStatus().getState();
long timeOutTime = System.currentTimeMillis() + ((long) sparkTimeoutDurationMinutes * 60 * 1000);
Boolean activated = Arrays.asList(activeClusterStates).contains(ClusterState.fromValue(state));
Boolean timedOut = System.currentTimeMillis() > timeOutTime;
if (activated && timedOut) {
emr.terminateJobFlows(
new TerminateJobFlowsRequest().withJobFlowIds(clusterId)
);
log.error("Timeout. Cluster terminated.");
} else if (!activated) {
Boolean hasAbnormalStep = false;
StepSummary stepS = null;
List steps = emr.listSteps(new ListStepsRequest().withClusterId(clusterId)).getSteps();
for (StepSummary step : steps) {
if (step.getStatus().getState() != StepState.COMPLETED.toString()) {
hasAbnormalStep = true;
stepS = step;
}
}
if (hasAbnormalStep && stepS != null)
log.error(String.format("Cluster %s terminated with an abnormal step, name %s, id %s", clusterId, stepS.getName(), stepS.getId()));
else
log.info("Cluster %s terminated without error.", clusterId);
} else {
Thread.sleep(5000);
checkStatus(emr, clusterId);
}
}
/**
* Monitor the cluster and terminates when it times out
*/
public void sparkMonitor() throws InterruptedException {
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
Optional optCsr = findClusterWithName(emr, sparkClusterName);
if (!optCsr.isPresent()) {
log.error(String.format("The cluster with name %s does not exist.", sparkClusterName));
} else {
ClusterSummary csr = optCsr.get();
log.info(String.format("found cluster with id %s, starting monitoring", csr.getId()));
checkStatus(emr, csr.getId());
}
}
@Data
public static class Builder {
protected String sparkClusterName = RandomStringUtils.randomAlphanumeric(12);
protected String sparkAwsRegion = "us-east-1";
protected String sparkEmrRelease = "emr-5.9.0";
protected String sparkEmrServiceRole = "EMR_DefaultRole";
protected List sparkEmrConfigs = Collections.emptyList();
protected String sparkSubNetid = null;
protected List sparkSecurityGroupIds = Collections.emptyList();
protected int sparkInstanceCount = 1;
protected String sparkInstanceType = "m3.xlarge";
protected Optional sparkInstanceBidPrice = Optional.empty();
protected String sparkInstanceRole = "EMR_EC2_DefaultRole";
protected String sparkS3JarFolder = "changeme";
protected int sparkTimeoutDurationMinutes = 90;
protected AmazonElasticMapReduceClientBuilder sparkEmrClientBuilder;
protected AmazonS3ClientBuilder sparkS3ClientBuilder;
protected JobFlowInstancesConfig sparkJobFlowInstancesConfig;
protected RunJobFlowRequest sparkRunJobFlowRequest;
// This should allow the user to decorate the put call to add metadata to the jar put command, such as security groups,
protected Function sparkS3PutObjectDecorator = new Function() {
@Override
public PutObjectRequest call(PutObjectRequest putObjectRequest) throws Exception {
return putObjectRequest;
}
};
protected Map sparkSubmitConfs;
/**
* Defines the EMR cluster's name
*
* @param clusterName the EMR cluster's name
* @return an EMR cluster builder
*/
public Builder clusterName(String clusterName) {
this.sparkClusterName = clusterName;
return this;
}
/**
* Defines the EMR cluster's region
* See https://docs.aws.amazon.com/general/latest/gr/rande.html
*
* @param region the EMR cluster's region
* @return an EMR cluster builder
*/
public Builder awsRegion(String region) {
this.sparkAwsRegion = region;
return this;
}
/**
* Defines the EMR release version to be used in this cluster
* uses a release label
* See https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-4.2.0/emr-release-differences.html#emr-release-label
*
* @param releaseLabel the EMR release label
* @return an EM cluster Builder
*/
public Builder emrRelease(String releaseLabel) {
this.sparkEmrRelease = releaseLabel;
return this;
}
/**
* Defines the IAM role to be assumed by the EMR service
*
* https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_create_for-service.html
*
* @param serviceRole the service role
* @return an EM cluster Builder
*/
public Builder emrServiceRole(String serviceRole) {
this.sparkEmrServiceRole = serviceRole;
return this;
}
/**
* A list of configuration parameters to apply to EMR instances.
*
* @param configs the EMR configurations to apply to this cluster
* @return an EMR cluster builder
*/
public Builder emrConfigs(List configs) {
this.sparkEmrConfigs = configs;
return this;
}
/**
* The id of the EC2 subnet to be used for this Spark EMR service
* see https://docs.aws.amazon.com/AmazonVPC/latest/UserGuide/VPC_Subnets.html
*
* @param id the subnet ID
* @return an EMR cluster builder
*/
public Builder subnetId(String id) {
this.sparkSubNetid = id;
return this;
}
/**
* The id of additional security groups this deployment should adopt for both master and slaves
*
* @param securityGroups
* @return an EMR cluster builder
*/
public Builder securityGroupIDs(List securityGroups) {
this.sparkSecurityGroupIds = securityGroups;
return this;
}
/**
* The number of instances this deployment should comprise of
*
* @param count the number of instances for this cluster
* @rturn an EMR cluster buidler
*/
public Builder instanceCount(int count) {
this.sparkInstanceCount = count;
return this;
}
/**
* The type of instance this cluster should comprise of
* See https://aws.amazon.com/ec2/instance-types/
*
* @param instanceType the type of instance for this cluster
* @return an EMR cluster builder
*/
public Builder instanceType(String instanceType) {
this.sparkInstanceType = instanceType;
return this;
}
/**
* The optional bid value for this cluster's spot instances
* see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/how-spot-instances-work.html
* Uses the on-demand market if empty.
*
* @param optBid the Optional bid price for this cluster's instnces
* @return an EMR cluster Builder
*/
public Builder instanceBidPrice(Optional optBid) {
this.sparkInstanceBidPrice = optBid;
return this;
}
/**
* The EC2 instance role that this cluster's instances should assume
* see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html
*
* @param role the intended instance role
* @return an EMR cluster builder
*/
public Builder instanceRole(String role) {
this.sparkInstanceRole = role;
return this;
}
/**
* the S3 folder in which to find the application jar
*
* @param jarfolder the S3 folder in which to find a jar
* @return an EMR cluster builder
*/
public Builder s3JarFolder(String jarfolder) {
this.sparkS3JarFolder = jarfolder;
return this;
}
/**
* The timeout duration for this Spark EMR cluster, in minutes
*
* @param timeoutMinutes
* @return an EMR cluster builder
*/
public Builder sparkTimeOutDurationMinutes(int timeoutMinutes) {
this.sparkTimeoutDurationMinutes = timeoutMinutes;
return this;
}
/**
* Creates an EMR Spark cluster deployment
*
* @return a SparkEMRClient
*/
public SparkEMRClient build() {
this.sparkEmrClientBuilder = AmazonElasticMapReduceClientBuilder.standard().withRegion(sparkAwsRegion);
this.sparkS3ClientBuilder = AmazonS3ClientBuilder.standard().withRegion(sparkAwsRegion);
// note this will be kept alive without steps, an arbitrary choice to avoid rapid test-teardown-restart cycles
this.sparkJobFlowInstancesConfig = (new JobFlowInstancesConfig()).withKeepJobFlowAliveWhenNoSteps(true);
if (this.sparkSubNetid != null)
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withEc2SubnetId(this.sparkSubNetid);
if (!this.sparkSecurityGroupIds.isEmpty()) {
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withAdditionalMasterSecurityGroups(this.sparkSecurityGroupIds);
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withAdditionalSlaveSecurityGroups(this.sparkSecurityGroupIds);
}
InstanceGroupConfig masterConfig =
(new InstanceGroupConfig()).withInstanceCount(1).withInstanceRole(InstanceRoleType.MASTER).withInstanceType(sparkInstanceType);
if (sparkInstanceBidPrice.isPresent()) {
masterConfig = masterConfig.withMarket(MarketType.SPOT).withBidPrice(sparkInstanceBidPrice.get().toString());
} else {
masterConfig = masterConfig.withMarket(MarketType.ON_DEMAND);
}
int slaveCount = sparkInstanceCount - 1;
InstanceGroupConfig slaveConfig =
(new InstanceGroupConfig()).withInstanceCount(slaveCount).withInstanceRole(InstanceRoleType.CORE).withInstanceRole(sparkInstanceType);
if (sparkInstanceBidPrice.isPresent()) {
slaveConfig = slaveConfig.withMarket(MarketType.SPOT).withBidPrice(sparkInstanceBidPrice.get().toString());
} else {
slaveConfig = slaveConfig.withMarket(MarketType.ON_DEMAND);
}
if (slaveCount > 0) {
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withInstanceGroups(Arrays.asList(masterConfig, slaveConfig));
} else {
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withInstanceGroups(slaveConfig);
}
this.sparkRunJobFlowRequest = new RunJobFlowRequest();
if (!sparkEmrConfigs.isEmpty()) {
List emrConfigs = new ArrayList<>();
for (EmrConfig config : sparkEmrConfigs) {
emrConfigs.add(config.toAwsConfig());
}
this.sparkRunJobFlowRequest = this.sparkRunJobFlowRequest.withConfigurations(emrConfigs);
}
this.sparkRunJobFlowRequest =
this.sparkRunJobFlowRequest.withName(sparkClusterName).withApplications((new Application()).withName("Spark"))
.withReleaseLabel(sparkEmrRelease)
.withServiceRole(sparkEmrServiceRole)
.withJobFlowRole(sparkInstanceRole)
.withInstances(this.sparkJobFlowInstancesConfig);
return new SparkEMRClient(
this.sparkClusterName,
this.sparkAwsRegion,
this.sparkEmrRelease,
this.sparkEmrServiceRole,
this.sparkEmrConfigs,
this.sparkSubNetid,
this.sparkSecurityGroupIds,
this.sparkInstanceCount,
this.sparkInstanceType,
this.sparkInstanceBidPrice,
this.sparkInstanceRole,
this.sparkS3JarFolder,
this.sparkTimeoutDurationMinutes,
this.sparkEmrClientBuilder,
this.sparkS3ClientBuilder,
this.sparkJobFlowInstancesConfig,
this.sparkRunJobFlowRequest,
this.sparkS3PutObjectDecorator,
this.sparkSubmitConfs
);
}
}
}