// ml.shifu.guagua.yarn.GuaguaYarnTask (artifact listing header: Maven / Gradle / Ivy)
/*
* Copyright [2013-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.guagua.yarn;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.PrivilegedAction;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Map.Entry;
import java.util.Properties;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import ml.shifu.guagua.GuaguaConstants;
import ml.shifu.guagua.GuaguaRuntimeException;
import ml.shifu.guagua.GuaguaService;
import ml.shifu.guagua.hadoop.io.GuaguaInputSplit;
import ml.shifu.guagua.io.Bytable;
import ml.shifu.guagua.io.GuaguaFileSplit;
import ml.shifu.guagua.master.GuaguaMasterService;
import ml.shifu.guagua.util.Progressable;
import ml.shifu.guagua.worker.GuaguaWorkerService;
import ml.shifu.guagua.yarn.util.GsonUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.serializer.Deserializer;
import org.apache.hadoop.io.serializer.SerializationFactory;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.zookeeper.common.IOUtils;
import org.jboss.netty.bootstrap.ClientBootstrap;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelEvent;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory;
import org.jboss.netty.handler.codec.serialization.ClassResolvers;
import org.jboss.netty.handler.codec.serialization.ObjectDecoder;
import org.jboss.netty.handler.codec.serialization.ObjectEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * {@link GuaguaYarnTask} is an entry point to run both master and workers.
 *
 * <p>
 * {@link #partition} should be passed as the last parameter in main, and it should not be changed if we retry the
 * task.
 *
 * <p>
 * Input splits are now stored in guagua-conf.xml. We read data from there and check whether this task is master or
 * worker.
 */
public class GuaguaYarnTask {
private static final Logger LOG = LoggerFactory.getLogger(GuaguaYarnTask.class);
static {
// pick up new conf XML file and populate it with stuff exported from client
Configuration.addDefaultResource(GuaguaYarnConstants.GUAGUA_CONF_FILE);
}
/**
* Partition is never changed, it is the index of all spits. If a fail-over task, it should keep the partition
* unchanged.
*/
private int partition;
/**
* Application attempt id
*/
private ApplicationAttemptId appAttemptId;
/**
* Container id
*/
private ContainerId containerId;
/**
* Application id
*/
private ApplicationId appId;
/**
* Yarn conf
*/
private Configuration yarnConf;
/**
* Whether this task is master
*/
private boolean isMaster;
/**
* Service instance to run master or worker service.
*/
private GuaguaService guaguaService;
/**
* Input split for worker tasks
*/
private GuaguaInputSplit inputSplit;
/**
* RPC port used to connect to RPC server.
*/
private int rpcPort = GuaguaYarnConstants.DEFAULT_STATUS_RPC_PORT;
/**
* RPC server host name
*/
private String rpcHostName;
/**
* Client channel used to connect to RPC server.
*/
private Channel rpcClientChannel;
/**
* Netty client instance.
*/
private ClientBootstrap rpcClient;
/**
* Constructor with yarn task related parameters.
*/
public GuaguaYarnTask(ApplicationAttemptId appAttemptId, ContainerId containerId, int partition,
String rpcHostName, String rpcPort, Configuration conf) {
this.appAttemptId = appAttemptId;
this.containerId = containerId;
this.partition = partition;
this.rpcHostName = rpcHostName;
this.rpcPort = Integer.parseInt(rpcPort);
LOG.info("current partition:{}", this.getPartition());
this.appId = this.getAppAttemptId().getApplicationId();
this.yarnConf = conf;
this.inputSplit = GsonUtils.fromJson(
this.getYarnConf().get(GuaguaYarnConstants.GUAGUA_YARN_INPUT_SPLIT_PREFIX + partition),
GuaguaInputSplit.class);
LOG.info("current input split:{}", this.getInputSplit());
}
/**
* Set up guagua service
*/
protected void setup() {
this.setMaster(this.getInputSplit().isMaster());
if(this.isMaster()) {
this.setGuaguaService(new GuaguaMasterService());
} else {
this.setGuaguaService(new GuaguaWorkerService());
List splits = new LinkedList();
for(FileSplit fileSplit: getInputSplit().getFileSplits()) {
splits.add(new GuaguaFileSplit(fileSplit.getPath().toString(), fileSplit.getStart(), fileSplit
.getLength()));
}
this.getGuaguaService().setSplits(splits);
}
Properties props = replaceConfToProps();
this.getGuaguaService().setAppId(this.getAppId().toString());
this.getGuaguaService().setContainerId(this.getPartition() + "");
this.getGuaguaService().init(props);
this.getGuaguaService().start();
initRPCClient();
}
/**
* Connect to app master for status RPC report.
*/
private void initRPCClient() {
this.rpcClient = new ClientBootstrap(new NioClientSocketChannelFactory(Executors.newSingleThreadExecutor(),
Executors.newSingleThreadExecutor()));
// Set up the pipeline factory.
this.rpcClient.setPipelineFactory(new ChannelPipelineFactory() {
public ChannelPipeline getPipeline() throws Exception {
return Channels.pipeline(new ObjectEncoder(),
new ObjectDecoder(ClassResolvers.cacheDisabled(getClass().getClassLoader())),
new ClientHandler());
}
});
// Start the connection attempt.
ChannelFuture future = this.rpcClient.connect(new InetSocketAddress(this.rpcHostName, this.rpcPort));
LOG.info("Connect to {}:{}", this.rpcHostName, this.rpcPort);
this.rpcClientChannel = future.awaitUninterruptibly().getChannel();
}
/**
* ClientHandeler used to update progress to RPC server (AppMaster).
*/
public static class ClientHandler extends SimpleChannelUpstreamHandler {
@Override
public void handleUpstream(ChannelHandlerContext ctx, ChannelEvent e) throws Exception {
super.handleUpstream(ctx, e);
}
@Override
public void channelConnected(ChannelHandlerContext ctx, ChannelStateEvent e) {
// Send the first message if this handler is a client-side handler.
LOG.info("Channel connected:{}", e.getValue());
}
@Override
public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) {
LOG.info("Receive status:{}", e.getMessage());
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) {
e.getChannel().close();
}
}
/**
* We have to replace {@link Configuration} to {@link Properties} because of no dependency on hadoop in guagua-core.
*/
private Properties replaceConfToProps() {
Properties properties = new Properties();
for(Entry entry: getYarnConf()) {
properties.put(entry.getKey(), entry.getValue());
if(LOG.isInfoEnabled()) {
if(entry.getKey().toString().startsWith(GuaguaConstants.GUAGUA)) {
LOG.debug("{}:{}", entry.getKey(), entry.getValue());
}
}
}
return properties;
}
@SuppressWarnings({ "unchecked", "unused" })
private T getSplitDetails(Path file, long offset) throws IOException {
FileSystem fs = file.getFileSystem(getYarnConf());
FSDataInputStream inFile = null;
T split = null;
try {
inFile = fs.open(file);
inFile.seek(offset);
String className = Text.readString(inFile);
Class cls;
try {
cls = (Class) getYarnConf().getClassByName(className);
} catch (ClassNotFoundException ce) {
IOException wrap = new IOException(String.format("Split class %s not found", className));
wrap.initCause(ce);
throw wrap;
}
SerializationFactory factory = new SerializationFactory(getYarnConf());
Deserializer deserializer = (Deserializer) factory.getDeserializer(cls);
deserializer.open(inFile);
split = deserializer.deserialize(null);
} finally {
IOUtils.closeStream(inFile);
}
return split;
}
/**
* Run master or worker service.
*/
public void run() {
try {
this.setup();
this.getGuaguaService().run(new Progressable() {
@Override
public void progress(int currentIteration, int totalIteration, String status, boolean isLastUpdate,
boolean isKill) {
// if is last update in current iteration, progress and status should be updated
if(isLastUpdate) {
LOG.info("Application progress: {}%.", (currentIteration * 100 / totalIteration));
GuaguaIterationStatus gi = new GuaguaIterationStatus(GuaguaYarnTask.this.partition,
currentIteration, totalIteration);
gi.setKillContainer(isKill);
LOG.info("Send GuaguaIterationStatus: {}.", gi);
ChannelFuture channelFuture = rpcClientChannel.write(GsonUtils.toJson(gi));
try {
channelFuture.await(10, TimeUnit.SECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}
});
} catch (Exception e) {
LOG.error("Error in guagua main run method.", e);
throw new GuaguaRuntimeException(e);
} finally {
// cleanup should be called in finally segment to make sure resources are cleaned up at last.
this.cleanup();
}
}
/**
* Clean up resources used
*/
protected void cleanup() {
if(this.rpcClient != null) {
this.rpcClient.shutdown();
this.rpcClient.releaseExternalResources();
}
if(this.rpcClientChannel != null) {
this.rpcClientChannel.close();
}
this.getGuaguaService().stop();
}
public GuaguaService getGuaguaService() {
return guaguaService;
}
public void setGuaguaService(GuaguaService guaguaService) {
this.guaguaService = guaguaService;
}
public int getPartition() {
return partition;
}
public void setPartition(int partition) {
this.partition = partition;
}
public ApplicationAttemptId getAppAttemptId() {
return appAttemptId;
}
public void setAppAttemptId(ApplicationAttemptId appAttemptId) {
this.appAttemptId = appAttemptId;
}
public ContainerId getContainerId() {
return containerId;
}
public void setContainerId(ContainerId containerId) {
this.containerId = containerId;
}
public boolean isMaster() {
return isMaster;
}
public void setMaster(boolean isMaster) {
this.isMaster = isMaster;
}
public Configuration getYarnConf() {
return yarnConf;
}
public void setYarnConf(YarnConfiguration yarnConf) {
this.yarnConf = yarnConf;
}
public ApplicationId getAppId() {
return appId;
}
public void setAppId(ApplicationId appId) {
this.appId = appId;
}
public GuaguaInputSplit getInputSplit() {
return inputSplit;
}
public void setInputSplit(GuaguaInputSplit inputSplit) {
this.inputSplit = inputSplit;
}
public static void main(String[] args) {
LOG.info("args:{}", Arrays.toString(args));
if(args.length != 7) {
throw new IllegalStateException(String.format(
"GuaguaYarnTask could not construct a TaskAttemptID for the Guagua job from args: %s",
Arrays.toString(args)));
}
String containerIdString = System.getenv().get(Environment.CONTAINER_ID.name());
if(containerIdString == null) {
// container id should always be set in the env by the framework
throw new IllegalArgumentException("ContainerId not found in env vars.");
}
ContainerId containerId = ConverterUtils.toContainerId(containerIdString);
ApplicationAttemptId appAttemptId = containerId.getApplicationAttemptId();
try {
Configuration conf = new YarnConfiguration();
String jobUserName = System.getenv(ApplicationConstants.Environment.USER.name());
conf.set(MRJobConfig.USER_NAME, jobUserName);
UserGroupInformation.setConfiguration(conf);
// Security framework already loaded the tokens into current UGI, just use them
Credentials credentials = UserGroupInformation.getCurrentUser().getCredentials();
LOG.info("Executing with tokens:");
for(Token> token: credentials.getAllTokens()) {
LOG.info(token.toString());
}
UserGroupInformation appTaskUGI = UserGroupInformation.createRemoteUser(jobUserName);
appTaskUGI.addCredentials(credentials);
@SuppressWarnings("rawtypes")
final GuaguaYarnTask, ?> guaguaYarnTask = new GuaguaYarnTask(appAttemptId, containerId,
Integer.parseInt(args[args.length - 3]), args[args.length - 2], args[args.length - 1], conf);
appTaskUGI.doAs(new PrivilegedAction() {
@Override
public Void run() {
guaguaYarnTask.run();
return null;
}
});
} catch (Throwable t) {
LOG.error("GuaguaYarnTask threw a top-level exception, failing task", t);
System.exit(2);
}
System.exit(0);
}
}