// com.tencent.angel.client.AngelClient (Maven / Gradle / Ivy)
// The newest version!
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.client;
import com.google.protobuf.ServiceException;
import com.tencent.angel.RunningMode;
import com.tencent.angel.common.location.Location;
import com.tencent.angel.conf.AngelConf;
import com.tencent.angel.conf.MatrixConf;
import com.tencent.angel.exception.AngelException;
import com.tencent.angel.exception.InvalidParameterException;
import com.tencent.angel.ipc.TConnectionManager;
import com.tencent.angel.master.MasterProtocol;
import com.tencent.angel.ml.matrix.MatrixContext;
import com.tencent.angel.ml.model.MLModel;
import com.tencent.angel.ml.model.PSModel;
import com.tencent.angel.model.LoadState;
import com.tencent.angel.model.MatrixLoadContext;
import com.tencent.angel.model.MatrixSaveContext;
import com.tencent.angel.model.ModelLoadContext;
import com.tencent.angel.model.ModelSaveContext;
import com.tencent.angel.model.SaveState;
import com.tencent.angel.model.output.format.ModelFilesConstent;
import com.tencent.angel.model.output.format.RowIdColIdValueTextRowFormat;
import com.tencent.angel.protobuf.ProtobufUtil;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.CheckModelLoadedRequest;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.CheckModelLoadedResponse;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.CheckModelSavedRequest;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.CheckModelSavedResponse;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.ClientRegisterRequest;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.GetClientIdRequest;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.GetJobReportRequest
;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.GetJobReportResponse;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.JobReportProto;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.JobStateProto;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.KeepAliveRequest;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.LoadRequest;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.SaveRequest;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.SetParamsRequest;
import com.tencent.angel.protobuf.generated.ClientMasterServiceProtos.StartRequest;
import com.tencent.angel.protobuf.generated.MLProtos.CheckMatricesCreatedRequest;
import com.tencent.angel.protobuf.generated.MLProtos.CheckMatricesCreatedResponse;
import com.tencent.angel.protobuf.generated.MLProtos.GetAllPSLocationRequest;
import com.tencent.angel.protobuf.generated.MLProtos.GetAllPSLocationResponse;
import com.tencent.angel.protobuf.generated.MLProtos.PSLocationProto;
import com.tencent.angel.protobuf.generated.MLProtos.PSStatus;
import com.tencent.angel.protobuf.generated.MLProtos.Pair;
import com.tencent.angel.utils.HdfsUtil;
import com.tencent.angel.utils.UGITools;
import com.tencent.angel.worker.WorkerGroupId;
import com.tencent.angel.worker.WorkerId;
import com.tencent.angel.worker.task.BaseTask;
import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
/**
* Angel application client. It provides the control interfaces for the application.
*/
public abstract class AngelClient implements AngelClientInterface {
private static final Log LOG = LogFactory.getLog(AngelClient.class);
/**
* application configuration
*/
protected final Configuration conf;
/**
* matrices used in the application
*/
private final Map nameToMatrixMap;
/**
* rpc client to master
*/
protected volatile MasterProtocol master;
private GetJobReportRequest getJobReportReq;
private GetJobReportRequest.Builder getJobReportReqBuilder;
private GetJobReportResponse lastReport;
private boolean isExecuteFinished;
private boolean isFinished;
private String appFailedMessage;
/**
* temporary file use to store application state
*/
protected Path internalStateFile;
/**
* the application submitting user
*/
protected String userName;
/**
* master location
*/
protected Location masterLocation;
private static final DecimalFormat df = new DecimalFormat("#0.000000");
private volatile int clientId = -1;
private volatile Thread hbThread;
private final int hbIntervalMS;
private final int hbTimeoutMS;
private final AtomicBoolean stopped = new AtomicBoolean(false);
/**
* Create a new AngelClient.
*
* @param conf application configuration
*/
public AngelClient(Configuration conf) {
this.conf = conf;
nameToMatrixMap = new LinkedHashMap<>();
isExecuteFinished = false;
isFinished = false;
hbIntervalMS = conf.getInt(AngelConf.ANGEL_CLIENT_HEARTBEAT_INTERVAL_MS,
AngelConf.DEFAULT_ANGEL_CLIENT_HEARTBEAT_INTERVAL_MS);
hbTimeoutMS = conf.getInt(AngelConf.ANGEL_CLIENT_HEARTBEAT_INTERVAL_TIMEOUT_MS,
AngelConf.DEFAULT_ANGEL_CLIENT_HEARTBEAT_INTERVAL_TIMEOUT_MS);
}
@SuppressWarnings("rawtypes")
@Override
public void runTask(Class extends BaseTask> taskClass)
throws AngelException {
if (master == null) {
throw new AngelException(
"parameter servers are not started, you must execute startPSServer first!!");
}
try {
master.setParams(null, SetParamsRequest.newBuilder().addKvs(
Pair.newBuilder().setKey(AngelConf.ANGEL_TASK_USER_TASKCLASS)
.setValue(taskClass.getName())
.build()).build());
master.start(null, StartRequest.newBuilder().build());
} catch (ServiceException e) {
LOG.error("start application failed.", e);
throw new AngelException(e);
}
}
public void runTask(String taskClassName) throws AngelException {
if (master == null) {
throw new AngelException(
"parameter servers are not started, you must execute startPSServer first!!");
}
try {
master.setParams(null, SetParamsRequest.newBuilder().addKvs(
Pair.newBuilder().setKey(AngelConf.ANGEL_TASK_USER_TASKCLASS).setValue(taskClassName)
.build()).build());
master.start(null, StartRequest.newBuilder().build());
} catch (ServiceException e) {
LOG.error("start application failed.", e);
throw new AngelException(e);
}
}
protected void startHeartbeat() throws ServiceException {
if (master == null) {
LOG.error("Master has not been connected");
return;
}
clientId = master.getClientId(null, GetClientIdRequest.getDefaultInstance()).getClientId();
master.clientRegister(null, ClientRegisterRequest.newBuilder().setClientId(clientId).build());
LOG.info("clientId=" + clientId);
hbThread = new Thread(() -> {
long lastHbTs = System.currentTimeMillis();
while (!stopped.get() && !Thread.interrupted()) {
try {
if (System.currentTimeMillis() - lastHbTs > hbTimeoutMS) {
LOG.fatal("can not connect to master in " + hbTimeoutMS
+ " ms. the client will be killed by itself");
System.exit(-1);
}
Thread.sleep(hbIntervalMS);
master.keepAlive(null, KeepAliveRequest.newBuilder().setClientId(clientId).build());
lastHbTs = System.currentTimeMillis();
} catch (Throwable e) {
if (!stopped.get()) {
LOG.error("AngelClient " + clientId + " send heartbeat to Master failed ", e);
}
}
}
});
hbThread.setName("client-heartbeat");
hbThread.setDaemon(true);
hbThread.start();
}
public void keepAlive() throws ServiceException {
master.keepAlive(null, KeepAliveRequest.newBuilder().setClientId(clientId).build());
}
@Override
public void run() throws AngelException {
if (master == null) {
throw new AngelException(
"parameter servers are not started, you must execute startPSServer first!!");
}
createMatrices();
try {
master.start(null, StartRequest.newBuilder().build());
} catch (ServiceException e) {
LOG.error("start application failed.", e);
throw new AngelException(e);
}
}
@Override
public void addMatrix(MatrixContext mContext) throws AngelException {
if (nameToMatrixMap.containsKey(mContext.getName())) {
throw new AngelException(
"Matrix \"" + mContext.getName() + "\" already exist, please check it");
}
try {
nameToMatrixMap.put(mContext.getName(), mContext);
} catch (Throwable x) {
throw new AngelException(x);
}
}
@SuppressWarnings("rawtypes")
@Override
public void loadModel(MLModel model)
throws AngelException {
if (master == null) {
throw new AngelException(
"parameter servers are not started, you must execute startPSServer first!!");
}
Map psModels = model.getPSModels();
for (Map.Entry entry : psModels.entrySet()) {
addMatrix(entry.getValue().getContext());
}
createMatrices();
load(psModels.keySet());
}
@SuppressWarnings("rawtypes")
@Override
public void saveModel(MLModel model)
throws AngelException {
if (master == null) {
throw new AngelException(
"parameter servers are not started, you must execute startPSServer first!!");
}
Map psModels = model.getPSModels();
ModelSaveContext saveContext = new ModelSaveContext();
for (Map.Entry entry : psModels.entrySet()) {
MatrixContext context = entry.getValue().getContext();
String savePath = context.getAttributes().get(MatrixConf.MATRIX_SAVE_PATH);
if (savePath != null) {
saveContext.addMatrix(new MatrixSaveContext(context.getName(),conf.get("OUT_FORMAT_CLASS", RowIdColIdValueTextRowFormat.class.getName())));
}
}
saveContext.setSavePath(conf.get(AngelConf.ANGEL_JOB_OUTPUT_PATH));
save(saveContext);
LOG.info("save is finish");
}
/**
* Save matrices to files.
*
* @param matrixNames need save matrix name list
*/
public void saveMatrices(List matrixNames) {
ModelSaveContext saveContext = new ModelSaveContext();
saveContext.setSavePath(conf.get(AngelConf.ANGEL_JOB_OUTPUT_PATH));
for (String name : matrixNames) {
saveContext.addMatrix(new MatrixSaveContext(name));
}
save(saveContext);
}
@Override
public void save(ModelSaveContext saveContext) throws AngelException {
if (saveContext.getMatricesContext().size() == 0) {
throw new AngelException("Need save matrix name is empty, you should check it");
}
if (saveContext.getSavePath() == null
|| saveContext.getSavePath().isEmpty()) {
throw new AngelException("Save path is null, you should check it");
}
try {
/*UserGroupInformation ugi = UGITools.getCurrentUser(conf);
ugi.doAs(new PrivilegedExceptionAction() {
@Override public String run() throws Exception {
Path savePath = new Path(saveContext.getSavePath());
FileSystem fs = savePath.getFileSystem(conf);
if(fs.exists(savePath)) {
if(conf.getBoolean(AngelConf.ANGEL_JOB_OUTPUT_PATH_DELETEONEXIST,
AngelConf.DEFAULT_ANGEL_JOB_OUTPUT_PATH_DELETEONEXIST)) {
fs.delete(savePath, true);
} else {
throw new AngelException("Save path " + savePath + " already exist, you can set another save path or set angel.job.output.path.deleteonexist be true");
}
}
return "OK";
}
});
*/
Path savePath = new Path(saveContext.getSavePath());
FileSystem fs = savePath.getFileSystem(conf);
if (fs.exists(savePath)) {
if (conf.getBoolean(AngelConf.ANGEL_JOB_OUTPUT_PATH_DELETEONEXIST,
AngelConf.DEFAULT_ANGEL_JOB_OUTPUT_PATH_DELETEONEXIST)) {
fs.delete(savePath, true);
if (fs.exists(savePath)) {
throw new AngelException(
"Save path " + savePath + " already exist, remove it failed!!!");
}
} else {
throw new AngelException("Save path " + savePath
+ " already exist, you can set another save path or set angel.job.output.path.deleteonexist be true");
}
}
} catch (Throwable x) {
throw new AngelException(x);
}
SaveRequest.Builder builder = SaveRequest.newBuilder();
int requestId;
try {
requestId =
master.save(null, builder.setSaveContext(ProtobufUtil.convert(saveContext)).build())
.getRequestId();
} catch (ServiceException e) {
LOG.error("save model failed.", e);
throw new AngelException(e);
}
while (!isSaveCompleted(requestId)) {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
throw new AngelException("Interrupted as waiting app complete");
}
}
if (appFailedMessage != null) {
throw new AngelException("app run failed, " + appFailedMessage);
}
}
@Override
public void checkpoint(int checkpointId, ModelSaveContext saveContext) throws AngelException {
String tmpPath = conf.get(AngelConf.ANGEL_JOB_TMP_OUTPUT_PATH);
String snapshotDir = new Path(tmpPath, ModelFilesConstent.snapshotDirName).toString();
String checkpointItemPath = new Path(snapshotDir, "" + checkpointId).toString();
LOG.info("===Checkpoint path = " + checkpointItemPath);
saveContext.setSavePath(checkpointItemPath);
saveContext.setCheckpoint(true);
save(saveContext);
}
@Override
public void load(ModelLoadContext loadContext) throws AngelException {
if (loadContext.getMatricesContext().size() == 0) {
throw new AngelException("Need load matrix names is empty, you should check it");
}
if (loadContext.getLoadPath() == null || loadContext.getLoadPath().isEmpty()) {
throw new AngelException("Load path is null, you should check it");
}
try {
/* UserGroupInformation ugi = UGITools.getCurrentUser(conf);
ugi.doAs(new PrivilegedExceptionAction() {
@Override public String run() throws Exception {
Path loadPath = new Path(loadContext.getLoadPath());
FileSystem fs = loadPath.getFileSystem(conf);
if(!fs.exists(loadPath)) {
throw new AngelException("Load path " + loadPath + " does not exist, please check");
}
return "OK";
}
});
*/
Path loadPath = new Path(loadContext.getLoadPath());
FileSystem fs = loadPath.getFileSystem(conf);
if (!fs.exists(loadPath)) {
throw new AngelException("Load path " + loadPath + " does not exist, please check");
}
} catch (Throwable e) {
throw new AngelException(e);
}
LoadRequest.Builder builder = LoadRequest.newBuilder();
int requestId;
try {
requestId =
master.load(null, builder.setLoadContext(ProtobufUtil.convert(loadContext)).build())
.getRequestId();
} catch (ServiceException e) {
LOG.error("save model failed.", e);
throw new AngelException(e);
}
while (!isLoadCompleted(requestId)) {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
throw new AngelException("Interrupted as waiting app complete");
}
}
LOG.info("load model complete from path " + loadContext.getLoadPath());
if (appFailedMessage != null) {
throw new AngelException("app run failed, " + appFailedMessage);
}
}
@Override
public void recover(int checkpointId, ModelLoadContext loadContext) throws AngelException {
String tmpPath = conf.get(AngelConf.ANGEL_JOB_TMP_OUTPUT_PATH);
String snapshotDir = new Path(tmpPath, ModelFilesConstent.snapshotDirName).toString();
String checkpointItemPath = new Path(snapshotDir, "" + checkpointId).toString();
loadContext.setLoadPath(checkpointItemPath);
load(loadContext);
}
@Override
public void stop(int stateCode) throws AngelException {
LOG.info("stop the application");
stopService();
if (master != null) {
try {
LOG.info("master is not null, send stop command to Master, stateCode=" + stateCode);
master.stop(null,
ClientMasterServiceProtos.StopRequest.newBuilder().setExitStatus(stateCode).build());
} catch (Throwable e) {
LOG.error("send stop command to Master failed ", e);
kill();
//throw new AngelException(e);
}
} else {
LOG.info("master is null, just kill the application");
kill();
}
close();
}
@Override
public void stop() throws AngelException {
stop(0);
}
private void stopService() {
nameToMatrixMap.clear();
isExecuteFinished = false;
isFinished = false;
if (!stopped.getAndSet(true)) {
if (hbThread != null) {
hbThread.interrupt();
try {
hbThread.join();
} catch (Throwable e) {
}
}
}
stopped.set(false);
}
@Override
public void waitForCompletion() throws AngelException {
if (master == null) {
throw new AngelException(
"parameter servers are not started, you must execute startPSServer first!!");
}
RunningMode mode = RunningMode
.valueOf(conf.get(AngelConf.ANGEL_RUNNING_MODE, AngelConf.DEFAULT_ANGEL_RUNNING_MODE));
switch (mode) {
case ANGEL_PS: {
while (!isCompleted()) {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
throw new AngelException("Interrupted as waiting app complete");
}
}
if (appFailedMessage != null) {
throw new AngelException("app run failed, " + appFailedMessage);
}
break;
}
case ANGEL_PS_WORKER: {
printWorkerLogUrl(new WorkerId(new WorkerGroupId(0), 0));
while (!isExecuteCompleted()) {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
throw new AngelException("Interrupted as waiting app complete");
}
}
if (appFailedMessage != null) {
throw new AngelException("app run failed, " + appFailedMessage);
}
String actionType =
conf.get(AngelConf.ANGEL_ACTION_TYPE, AngelConf.DEFAULT_ANGEL_ACTION_TYPE);
LOG.info("action type " + actionType);
if (actionType.matches("predict")) {
try {
movePredictResult();
} catch (IOException e) {
throw new AngelException("move predict result failed." + e.getMessage());
}
}
break;
}
default:
break;
}
}
private void movePredictResult() throws IOException {
String outPathStr = conf.get(AngelConf.ANGEL_JOB_OUTPUT_PATH);
String tmpPathStr = conf.get(AngelConf.ANGEL_JOB_TMP_OUTPUT_PATH);
Path outPath = new Path(outPathStr);
Path tmpOutPath = new Path(tmpPathStr, "predict");
FileSystem fs = outPath.getFileSystem(conf);
Path tmpCombinePath = HdfsUtil.toTmpPath(outPath);
HdfsUtil.copyFilesInSameHdfs(tmpOutPath, tmpCombinePath, fs);
LOG.info("copy files from " + tmpOutPath + " to " + tmpCombinePath);
HdfsUtil.rename(tmpCombinePath, outPath, fs);
LOG.info("rename " + tmpCombinePath + " to " + outPath);
}
public Path getInternalStateFile() {
return internalStateFile;
}
/**
* Get applicaiton configuration.
*
* @return Configuration the applicaiton configuration
*/
public Configuration getConf() {
return conf;
}
/**
* Get the location(ip:port) of the application master.
*
* @return Location the location(ip:port) of the application master
*/
public Location getMasterLocation() {
return masterLocation;
}
/**
* Check the application is over or not. If the application state is J_SUCCESSED, J_FAILED or
* J_KILLED, means it is over.
*
* @return boolean true means the application is over
*/
private boolean isCompleted() {
if (isFinished) {
return true;
}
updateJobReport();
if (lastReport == null) {
try {
lastReport = tryGetResponseFromFile(true);
LOG.info("app from file is " + lastReport);
} catch (IOException e) {
LOG.error("get app from file failed ", e);
}
}
if (lastReport == null || lastReport.getJobReport().getJobState() == JobStateProto.J_NEW) {
appFailedMessage = " detail is killed";
return true;
}
JobStateProto jobState = lastReport.getJobReport().getJobState();
if (LOG.isDebugEnabled()) {
LOG.debug("job stat = " + jobState.name());
}
if (jobState == JobStateProto.J_SUCCEEDED || jobState == JobStateProto.J_FAILED
|| jobState == JobStateProto.J_KILLED) {
isFinished = true;
LOG.info("job is finished! status: " + jobState);
if (jobState == JobStateProto.J_FAILED || jobState == JobStateProto.J_KILLED) {
appFailedMessage = " detail is " + lastReport.getJobReport().getDiagnostics();
LOG.error(appFailedMessage);
}
return true;
} else {
return false;
}
}
/**
* Check a save request complete or not
*
* @param requestId save request id
* @return true means complete
*/
private boolean isSaveCompleted(int requestId) throws AngelException {
CheckModelSavedResponse response;
try {
response = master
.checkModelSaved(null,
CheckModelSavedRequest.newBuilder().setRequestId(requestId).build());
} catch (Throwable x) {
throw new AngelException("Check model save request failed ", x);
}
SaveState state = SaveState.valueOf(response.getStatus());
if (state == SaveState.FAILED) {
throw new AngelException("Model save falied " + response.getLog());
} else if (state == SaveState.SUCCESS) {
return true;
} else {
return false;
}
}
/**
* Check a load request complete or not
*
* @param requestId load request id
* @return true means complete
*/
private boolean isLoadCompleted(int requestId) throws AngelException {
CheckModelLoadedResponse response;
try {
response = master.checkModelLoaded(null,
CheckModelLoadedRequest.newBuilder().setRequestId(requestId).build());
} catch (Throwable x) {
throw new AngelException("Check model save request failed ", x);
}
LoadState state = LoadState.valueOf(response.getStatus());
if (state == LoadState.FAILED) {
throw new AngelException("Model load falied " + response.getLog());
} else if (state == LoadState.SUCCESS) {
return true;
} else {
return false;
}
}
/**
* Check the application calculation phase is over or not. If the application state is
* J_SUCCESSED, J_FAILED J_KILLED, J_EXECUTE_SUCCESSED or J_COMMITTING means it is over.
*
* @return boolean true means the application is over
*/
private boolean isExecuteCompleted() {
if (isExecuteFinished) {
return true;
}
updateJobReport();
if (lastReport == null) {
try {
lastReport = tryGetResponseFromFile(true);
LOG.info("app from file is " + lastReport);
} catch (IOException e) {
LOG.error("get app from file failed ", e);
}
}
if (lastReport == null || lastReport.getJobReport().getJobState() == JobStateProto.J_NEW) {
appFailedMessage = " detail is killed";
return true;
}
JobStateProto jobState = lastReport.getJobReport().getJobState();
if (LOG.isDebugEnabled()) {
LOG.debug("job stat = " + jobState.name());
}
if (jobState != JobStateProto.J_INITED && jobState != JobStateProto.J_NEW
&& jobState != JobStateProto.J_PREPARE_WORKERS && jobState != JobStateProto.J_RUNNING) {
isExecuteFinished = true;
LOG.info("job is finished! status: " + jobState);
if (jobState == JobStateProto.J_FAILED || jobState == JobStateProto.J_KILLED) {
appFailedMessage = " detail is " + lastReport.getJobReport().getDiagnostics();
LOG.error(appFailedMessage);
}
return true;
} else {
return false;
}
}
protected void printWorkerLogUrl(WorkerId workerId) {
}
private void updateJobReport() {
GetJobReportRequest getJobRequest = getGetJobReportRequest();
GetJobReportResponse response = null;
try {
response = master.getJobReport(null, getJobRequest);
} catch (Exception e) {
LOG.error("getJobReport from master failed. " + e.getMessage());
try {
updateMaster(60);
if (master != null) {
response = master.getJobReport(null, getJobRequest);
}
} catch (Exception e1) {
LOG.error("update master failed.", e1);
}
}
if (response == null) {
isFinished = true;
lastReport = null;
return;
}
JobReportProto report = response.getJobReport();
// JobStateProto jobState = report.getJobState();
if (lastReport == null || (report.hasCurIteration() && report.getCurIteration() != lastReport
.getJobReport().getCurIteration())) {
LOG.info(
"Epoch: " + report.getCurIteration() + ". Metrics=" + toString(report.getMetricsList()));
if (report.hasLoss()) {
LOG.info("loss/success: " + report.getLoss() + "/" + report.getSuccess());
}
}
lastReport = response;
}
private String toString(List metrics) {
StringBuilder sb = new StringBuilder("{");
int size = metrics.size();
for (int i = 0; i < size; i++) {
sb.append("\"" + metrics.get(i).getKey() + "\":" + df
.format(Double.valueOf(metrics.get(i).getValue())));
if (i < size - 1) {
sb.append(",");
}
}
sb.append("}");
return sb.toString();
}
private GetJobReportRequest getGetJobReportRequest() {
if (getJobReportReq != null) {
return getJobReportReq;
}
if (getJobReportReqBuilder == null) {
getJobReportReqBuilder = GetJobReportRequest.newBuilder();
}
getJobReportReqBuilder.setAppId(getAppId());
return getJobReportReqBuilder.build();
}
private GetJobReportResponse tryGetResponseFromFile(boolean deleteOnExist) throws IOException {
GetJobReportResponse response = null;
FileSystem fs = internalStateFile.getFileSystem(conf);
if (fs.exists(internalStateFile)) {
LOG.info(internalStateFile + " exist, parse app report from it");
FSDataInputStream in = fs.open(internalStateFile);
response = GetJobReportResponse.parseFrom(in);
if (deleteOnExist) {
fs.delete(internalStateFile.getParent(), true);
}
}
return response;
}
@Override
public void createMatrices() throws AngelException {
try {
for (MatrixContext context : nameToMatrixMap.values()) {
context.init(conf);
}
master.createMatrices(null,
ProtobufUtil.buildCreateMatricesRequest(new ArrayList<>(nameToMatrixMap.values())));
List matrixNames = new ArrayList<>(nameToMatrixMap.keySet());
waitForMatricesCreated(matrixNames);
} catch (Throwable x) {
throw new AngelException(x);
}
}
@Override
public void createMatrices(List matrixContexts) throws AngelException {
try {
for (MatrixContext context : matrixContexts) {
context.init(conf);
}
master.createMatrices(null, ProtobufUtil.buildCreateMatricesRequest(matrixContexts));
List matrixNames = new ArrayList<>(matrixContexts.size());
for (MatrixContext context : matrixContexts) {
matrixNames.add(context.getName());
}
waitForMatricesCreated(matrixNames);
} catch (Throwable x) {
throw new AngelException(x);
}
}
public void load() {
load(nameToMatrixMap.keySet());
}
public void load(Set matrixNames) {
// Check need load matrices
String loadPath = conf.get(AngelConf.ANGEL_LOAD_MODEL_PATH);
if (loadPath != null && !loadPath.isEmpty()) {
ModelLoadContext loadContext = new ModelLoadContext(loadPath);
int needLoadMatrixCount = 0;
for (String name : matrixNames) {
MatrixContext matrix = nameToMatrixMap.get(name);
if (matrix.getAttributes().get(MatrixConf.MATRIX_LOAD_PATH) != null) {
loadContext.addMatrix(new MatrixLoadContext(name));
needLoadMatrixCount++;
}
}
if (needLoadMatrixCount > 0) {
load(loadContext);
}
}
}
private void waitForMatricesCreated(List matrixNames) throws ServiceException {
CheckMatricesCreatedRequest request =
CheckMatricesCreatedRequest.newBuilder().addAllMatrixNames(matrixNames).build();
int size = matrixNames.size();
while (true) {
CheckMatricesCreatedResponse response = master.checkMatricesCreated(null, request);
if (response.getStatus() == 0) {
return;
}
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
LOG.warn("waitForMatricesCreated is interrupted.");
}
}
}
protected void setInputDirectory() throws IOException {
boolean isUseDummy = conf.getBoolean(AngelConf.ANGEL_AM_USE_DUMMY_DATASPLITER,
AngelConf.DEFAULT_ANGEL_AM_USE_DUMMY_DATASPLITER);
if (isUseDummy) {
return;
}
String actionType = conf.get(AngelConf.ANGEL_ACTION_TYPE, AngelConf.DEFAULT_ANGEL_ACTION_TYPE);
RunningMode runningMode = RunningMode
.valueOf(conf.get(AngelConf.ANGEL_RUNNING_MODE, AngelConf.DEFAULT_ANGEL_RUNNING_MODE));
String path;
if (actionType.matches("predict")) {
path = conf.get(AngelConf.ANGEL_PREDICT_DATA_PATH);
} else {
path = conf.get(AngelConf.ANGEL_TRAIN_DATA_PATH);
}
if (runningMode == RunningMode.ANGEL_PS_WORKER) {
if (path == null) {
throw new IOException("input data directory is empty, you should set it");
} else {
conf.set(AngelConf.ANGEL_JOB_INPUT_PATH, path);
}
}
}
protected void setOutputDirectory() throws IOException {
String actionType = conf.get(AngelConf.ANGEL_ACTION_TYPE, AngelConf.DEFAULT_ANGEL_ACTION_TYPE);
RunningMode runningMode = RunningMode
.valueOf(conf.get(AngelConf.ANGEL_RUNNING_MODE, AngelConf.DEFAULT_ANGEL_RUNNING_MODE));
LOG.info("running mode = " + runningMode);
boolean deleteOnExist = conf.getBoolean(AngelConf.ANGEL_JOB_OUTPUT_PATH_DELETEONEXIST,
AngelConf.DEFAULT_ANGEL_JOB_OUTPUT_PATH_DELETEONEXIST);
String path;
if (actionType.matches("train") || actionType.matches("inctrain")) {
path = conf.get(AngelConf.ANGEL_SAVE_MODEL_PATH);
} else if (actionType.matches("predict")) {
path = conf.get(AngelConf.ANGEL_PREDICT_PATH);
} else if (actionType.matches("serving")) {
path = conf.get(AngelConf.ANGEL_SERVING_TEMP_PATH);
} else {
path = null;
}
if (path == null) {
throw new IOException(
"output directory is null. you must set " + AngelConf.ANGEL_SAVE_MODEL_PATH
+ " at training mode or set " + AngelConf.ANGEL_PREDICT_PATH + " at predict mode"
+ AngelConf.ANGEL_SERVING_TEMP_PATH + "at serving mode");
}
conf.set(AngelConf.ANGEL_JOB_OUTPUT_PATH, path);
Path outputPath = new Path(path);
FileSystem outFs = outputPath.getFileSystem(conf);
if (outFs.exists(outputPath)) {
if (deleteOnExist) {
outFs.delete(outputPath, true);
if (outFs.exists(outputPath)) {
throw new IOException(
"output path " + outputPath + " already exist, remove it failed!!!");
}
} else {
throw new IOException("output path " + outputPath + " already exist, please check");
}
}
Path outputParentPath = outputPath.getParent();
if (!outFs.exists(outputParentPath)) {
LOG.info("Make dir for model output parent path: " + outputParentPath);
if (!outFs.mkdirs(outputParentPath)) {
throw new IOException(
"Failed to make dir for model output parent path: " + outputParentPath);
}
}
if (runningMode == RunningMode.ANGEL_PS_WORKER) {
String logPathStr = conf.get(AngelConf.ANGEL_LOG_PATH);
if (logPathStr != null) {
Path logPath = new Path(logPathStr);
FileSystem logFs = logPath.getFileSystem(conf);
if (logFs.exists(logPath)) {
if (deleteOnExist) {
logFs.delete(logPath, true);
} else {
throw new IOException("log path " + logPath + " already exist, please check");
}
}
}
}
Path tmpOutputPath = HdfsUtil.generateTmpDirectory(conf, getAppId(), outputPath);
internalStateFile =
new Path(HdfsUtil.generateTmpDirectory(conf, getAppId(), outputPath), "state");
conf.set(AngelConf.ANGEL_JOB_TMP_OUTPUT_PATH, tmpOutputPath.toString());
LOG.info(AngelConf.ANGEL_JOB_TMP_OUTPUT_PATH + "=" + tmpOutputPath.toString());
LOG.info("internal state file is " + internalStateFile);
conf.set(AngelConf.ANGEL_APP_SERILIZE_STATE_FILE, internalStateFile.toString());
}
protected void setUser()
throws ClassNotFoundException, NoSuchFieldException, SecurityException, InstantiationException,
IllegalAccessException, IOException {
userName = UGITools.getCurrentUser(conf).getShortUserName();
conf.set(AngelConf.USER_NAME, userName);
}
abstract public void startPSServer() throws AngelException;
protected void setLocalAddr() throws UnknownHostException {
InetAddress ip = InetAddress.getLocalHost();
if (ip != null) {
String submitHostAddress = ip.getHostAddress();
String submitHostName = ip.getHostName();
conf.set(AngelConf.JOB_SUBMITHOST, submitHostName);
conf.set(AngelConf.JOB_SUBMITHOSTADDR, submitHostAddress);
}
}
protected void handleDeprecatedParameters(Configuration conf) {
String memoryMBStr = conf.get(AngelConf.ANGEL_AM_MEMORY_MB);
String memoryGBStr = conf.get(AngelConf.ANGEL_AM_MEMORY_GB);
if (memoryGBStr == null && memoryMBStr != null) {
LOG.warn("use deprecated parameter " + AngelConf.ANGEL_AM_MEMORY_MB + ", you can use "
+ AngelConf.ANGEL_AM_MEMORY_GB + " instead.");
try {
int memoryMB = Integer.valueOf(memoryMBStr);
conf.setInt(AngelConf.ANGEL_AM_MEMORY_GB, (int) Math.ceil((float) memoryMB / 1024));
} catch (Exception x) {
LOG.error("invalid value for " + AngelConf.ANGEL_AM_MEMORY_MB + " " + memoryMBStr);
}
}
memoryMBStr = conf.get(AngelConf.ANGEL_WORKER_MEMORY_MB);
memoryGBStr = conf.get(AngelConf.ANGEL_WORKER_MEMORY_GB);
if (memoryGBStr == null && memoryMBStr != null) {
LOG.warn("use deprecated parameter " + AngelConf.ANGEL_WORKER_MEMORY_MB + ", you can use "
+ AngelConf.ANGEL_WORKER_MEMORY_GB + " instead.");
try {
int memoryMB = Integer.valueOf(memoryMBStr);
conf.setInt(AngelConf.ANGEL_WORKER_MEMORY_GB, (int) Math.ceil((float) memoryMB / 1024));
} catch (Exception x) {
LOG.error("invalid value for " + AngelConf.ANGEL_WORKER_MEMORY_MB + " " + memoryGBStr);
}
}
memoryMBStr = conf.get(AngelConf.ANGEL_PS_MEMORY_MB);
memoryGBStr = conf.get(AngelConf.ANGEL_PS_MEMORY_GB);
if (memoryGBStr == null && memoryMBStr != null) {
LOG.warn("use deprecated parameter " + AngelConf.ANGEL_PS_MEMORY_MB + ", you can use "
+ AngelConf.ANGEL_PS_MEMORY_GB + " instead.");
try {
int memoryMB = Integer.valueOf(memoryMBStr);
conf.setInt(AngelConf.ANGEL_PS_MEMORY_GB, (int) Math.ceil((float) memoryMB / 1024));
} catch (Exception x) {
LOG.error("invalid value for " + AngelConf.ANGEL_PS_MEMORY_MB + " " + memoryGBStr);
}
}
}
protected void checkParameters(Configuration conf) throws InvalidParameterException {
StringBuilder sb = new StringBuilder();
int coreNum = conf.getInt(AngelConf.ANGEL_AM_CPU_VCORES, AngelConf.DEFAULT_ANGEL_AM_CPU_VCORES);
if (coreNum <= 0) {
sb.append(AngelConf.ANGEL_AM_CPU_VCORES + " should > 0");
sb.append("\n");
}
int memNum = conf.getInt(AngelConf.ANGEL_AM_MEMORY_GB, AngelConf.DEFAULT_ANGEL_AM_MEMORY_GB);
if (memNum <= 0) {
sb.append(AngelConf.ANGEL_AM_MEMORY_GB + " should > 0");
sb.append("\n");
}
coreNum =
conf.getInt(AngelConf.ANGEL_WORKER_CPU_VCORES, AngelConf.DEFAULT_ANGEL_WORKER_CPU_VCORES);
if (coreNum <= 0) {
sb.append(AngelConf.ANGEL_WORKER_CPU_VCORES + " should > 0");
sb.append("\n");
}
memNum =
conf.getInt(AngelConf.ANGEL_WORKER_MEMORY_GB, AngelConf.DEFAULT_ANGEL_WORKER_MEMORY_GB);
if (memNum <= 0) {
sb.append(AngelConf.ANGEL_WORKER_MEMORY_GB + " should > 0");
sb.append("\n");
}
coreNum = conf.getInt(AngelConf.ANGEL_PS_CPU_VCORES, AngelConf.DEFAULT_ANGEL_PS_CPU_VCORES);
if (coreNum <= 0) {
sb.append(AngelConf.ANGEL_PS_CPU_VCORES + " should > 0");
sb.append("\n");
}
memNum = conf.getInt(AngelConf.ANGEL_PS_MEMORY_GB, AngelConf.DEFAULT_ANGEL_PS_MEMORY_GB);
if (memNum <= 0) {
sb.append(AngelConf.ANGEL_PS_MEMORY_GB + " should > 0");
sb.append("\n");
}
if (sb.length() > 0) {
throw new InvalidParameterException(sb.toString());
}
}
protected void waitForAllPS(int psNumber) throws ServiceException, InterruptedException {
boolean isAllPSReady = true;
while (true) {
GetAllPSLocationResponse response =
master.getAllPSLocation(null, GetAllPSLocationRequest.newBuilder().build());
List psLocs = response.getPsLocationsList();
int size = psLocs.size();
if (size == psNumber) {
isAllPSReady = true;
for (int i = 0; i < size; i++) {
if (psLocs.get(i).getPsStatus() == PSStatus.PS_NOTREADY) {
isAllPSReady = false;
break;
}
}
if (isAllPSReady) {
return;
}
}
Thread.sleep(100);
}
}
@Override
public void close() {
TConnectionManager.deleteAllConnections(true);
TConnectionManager.shutDown();
}
protected abstract void updateMaster(int maxWaitTimeInSec) throws Exception;
protected abstract String getAppId();
}