com.tencent.angel.psagent.client.MasterClient Maven / Gradle / Ivy
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.psagent.client;
import com.google.protobuf.ByteString;
import com.google.protobuf.ServiceException;
import com.tencent.angel.common.location.Location;
import com.tencent.angel.exception.TimeOutException;
import com.tencent.angel.master.MasterProtocol;
import com.tencent.angel.ml.matrix.MatrixContext;
import com.tencent.angel.ml.matrix.MatrixMeta;
import com.tencent.angel.ml.matrix.PartitionLocation;
import com.tencent.angel.ml.metric.Metric;
import com.tencent.angel.protobuf.ProtobufUtil;
import com.tencent.angel.protobuf.RequestConverter;
import com.tencent.angel.protobuf.generated.MLProtos.*;
import com.tencent.angel.protobuf.generated.PSAgentMasterServiceProtos.*;
import com.tencent.angel.protobuf.generated.WorkerMasterServiceProtos.*;
import com.tencent.angel.ps.ParameterServerId;
import com.tencent.angel.ps.server.data.PSLocation;
import com.tencent.angel.psagent.PSAgentContext;
import com.tencent.angel.split.SplitClassification;
import com.tencent.angel.utils.KryoUtils;
import com.tencent.angel.utils.Time;
import com.tencent.angel.worker.WorkerContext;
import com.tencent.angel.worker.WorkerGroup;
import com.tencent.angel.worker.WorkerRef;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.io.IOException;
import java.util.*;
/**
* The RPC client to master use protobuf codec protocol
*/
public class MasterClient {
private static final Log LOG = LogFactory.getLog(MasterClient.class);
/**
* protobuf RPC client
*/
private volatile MasterProtocol master;
public MasterClient() {
}
/**
* Init protobuf rpc client to master
*
* @throws IOException connect to master failed
*/
public void init() throws IOException {
this.master = getOrCreateMasterClient(PSAgentContext.get().getPsAgent().getMasterLocation());
}
private MasterProtocol getOrCreateMasterClient(Location loc) throws IOException {
return PSAgentContext.get().getPsAgent().getControlConnectManager()
.getMasterService(loc.getIp(), loc.getPort());
}
/**
* Get the ps location from master
*
* @param psId ps id
* @return Location ps location
* @throws ServiceException rpc failed
*/
public Location getPSLocation(ParameterServerId psId) throws ServiceException {
GetPSLocationRequest request =
GetPSLocationRequest.newBuilder().setPsId(ProtobufUtil.convertToIdProto(psId)).build();
return ProtobufUtil.convertToLocation(master.getPSLocation(null, request).getPsLocation());
}
/**
* Get the locations of all parameter servers
*
* @return Map ps id to location map
* @throws ServiceException rpc failed
*/
public Map getPSLocations() throws ServiceException {
GetAllPSLocationRequest request = GetAllPSLocationRequest.newBuilder().build();
HashMap routingMap = new HashMap<>();
try {
GetAllPSLocationResponse response = master.getAllPSLocation(null, request);
List psLocs = response.getPsLocationsList();
int size = psLocs.size();
for (int i = 0; i < size; i++) {
routingMap.put(ProtobufUtil.convertToId(psLocs.get(i).getPsId()),
ProtobufUtil.convertToLocation(psLocs.get(i)));
}
} catch (com.google.protobuf.ServiceException e) {
LOG.error("get all ps locations from master failed.", e);
}
return routingMap;
}
/**
* Get the meta data and partitions for all matrices, it will wait until the matrices are ready
*
* @return GetAllMatrixInfoResponse the meta data and partitions for all matrices
* @throws InterruptedException interrupted when sleep for next try
* @throws ServiceException rpc failed
*/
public List getMatrices()
throws InterruptedException, ServiceException, ClassNotFoundException {
GetAllMatrixMetaResponse response =
master.getAllMatrixMeta(null, GetAllMatrixMetaRequest.newBuilder().build());
List matrixMetaProtos = response.getMatrixMetasList();
int size = matrixMetaProtos.size();
List matrixMetas = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
matrixMetas.add(ProtobufUtil.convertToMatrixMeta(matrixMetaProtos.get(i)));
}
return matrixMetas;
}
/**
* Get a matrix meta
*
* @param matrixName matrix name
* @return matrix meta
* @throws ServiceException
* @throws ClassNotFoundException
*/
public MatrixMeta getMatrix(String matrixName) throws ServiceException, ClassNotFoundException {
GetMatricesResponse response =
master.getMatrices(null, GetMatricesRequest.newBuilder().addMatrixNames(matrixName).build());
return ProtobufUtil.convertToMatrixMeta(response.getMatrixMetas(0));
}
/**
* Get matrix metas
*
* @param matrixNames matrix names
* @return matrix metas
* @throws ServiceException
* @throws ClassNotFoundException
*/
public List getMatrices(List matrixNames)
throws ServiceException, ClassNotFoundException {
GetMatricesResponse response = master
.getMatrices(null, GetMatricesRequest.newBuilder().addAllMatrixNames(matrixNames).build());
List matrixMetaProtos = response.getMatrixMetasList();
int size = matrixMetaProtos.size();
List matrixMetas = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
matrixMetas.add(ProtobufUtil.convertToMatrixMeta(matrixMetaProtos.get(i)));
}
return matrixMetas;
}
/**
* PSAgent register to master
*
* @return PSAgentRegisterResponse register response
* @throws ServiceException rpc failed
*/
public PSAgentRegisterResponse psAgentRegister() throws ServiceException {
PSAgentRegisterRequest request =
PSAgentRegisterRequest.newBuilder().setPsAgentId(PSAgentContext.get().getPsAgent().getId())
.setLocation(ProtobufUtil.convertToLocationProto(PSAgentContext.get().getLocation()))
.build();
return master.psAgentRegister(null, request);
}
/**
* Report ps agent state to master
*
* @return PSAgentReportResponse report response
* @throws ServiceException rpc failed
*/
public PSAgentReportResponse psAgentReport() throws ServiceException {
PSAgentReportRequest request =
PSAgentReportRequest.newBuilder().setPsAgentId(PSAgentContext.get().getPsAgent().getId())
.build();
return master.psAgentReport(null, request);
}
/**
* Create a new matrix
*
* @param matrixContext matrix configuration
* @param timeOutMS maximun wait time in milliseconds
* @throws Exception rpc failed
*/
public void createMatrix(MatrixContext matrixContext, long timeOutMS) throws Exception {
matrixContext.init(PSAgentContext.get().getConf());
List matrixContexts = new ArrayList<>(1);
matrixContexts.add(matrixContext);
createMatrices(matrixContexts, timeOutMS);
}
/**
* Create a new matrix
*
* @param matrixContexts matrices configuration
* @param timeOutMS maximun wait time in milliseconds
* @throws Exception rpc failed
*/
public void createMatrices(List matrixContexts, long timeOutMS) throws Exception {
CreateMatricesRequest.Builder createBuilder = CreateMatricesRequest.newBuilder();
CheckMatricesCreatedRequest.Builder checkBuilder = CheckMatricesCreatedRequest.newBuilder();
List matrixNames = new ArrayList<>(matrixContexts.size());
int size = matrixContexts.size();
for (int i = 0; i < size; i++) {
matrixContexts.get(i).init(PSAgentContext.get().getConf());
matrixNames.add(matrixContexts.get(i).getName());
checkBuilder.addMatrixNames(matrixContexts.get(i).getName());
createBuilder.addMatrices(ProtobufUtil.convertToMatrixContextProto(matrixContexts.get(i)));
}
LOG.info("start to create matrices " + String.join(",", matrixNames));
master.createMatrices(null, createBuilder.build());
CheckMatricesCreatedRequest checkRequest = checkBuilder.build();
CheckMatricesCreatedResponse checkResponse = null;
while (true) {
long startTs = Time.now();
checkResponse = master.checkMatricesCreated(null, checkRequest);
if (checkResponse.getStatus() == 0) {
LOG.info("create matrices " + String.join(",", matrixNames) + " success");
List metaProtos = master
.getMatrices(null, GetMatricesRequest.newBuilder().addAllMatrixNames(matrixNames).build())
.getMatrixMetasList();
for (int i = 0; i < size; i++) {
PSAgentContext.get().getMatrixMetaManager()
.addMatrix(ProtobufUtil.convertToMatrixMeta(metaProtos.get(i)));
}
return;
} else {
if (Time.now() - startTs > timeOutMS) {
throw new TimeOutException("create matrix time out ", (Time.now() - startTs), timeOutMS);
}
Thread.sleep(1000);
}
}
}
/**
* Release a matrix
*
* @param matrixName matrix name
* @throws ServiceException exception come from master
* @throws InterruptedException interrupted when wait
*/
public void releaseMatrix(String matrixName) throws ServiceException {
List matrixNames = new ArrayList<>(1);
matrixNames.add(matrixName);
releaseMatrices(matrixNames);
}
public void releaseMatrices(List matrixNames) throws ServiceException {
master.releaseMatrices(null,
ReleaseMatricesRequest.newBuilder().addAllMatrixNames(matrixNames).build());
}
/**
* Get worker group information:workers and data splits, it will wait until the worker group is ready
*
* @return WorkerGroup worker group information
* @throws ClassNotFoundException split class not found
* @throws IOException deserialize data splits meta failed
* @throws ServiceException rpc failed
* @throws InterruptedException interrupted when wait for next try
*/
public WorkerGroup getWorkerGroupMetaInfo()
throws ClassNotFoundException, IOException, ServiceException, InterruptedException {
GetWorkerGroupMetaInfoRequest request = GetWorkerGroupMetaInfoRequest.newBuilder()
.setWorkerAttemptId(WorkerContext.get().getWorkerAttemptIdProto()).build();
while (true) {
GetWorkerGroupMetaInfoResponse response = master.getWorkerGroupMetaInfo(null, request);
assert (response.getWorkerGroupStatus()
!= GetWorkerGroupMetaInfoResponse.WorkerGroupStatus.WORKERGROUP_EXITED);
LOG.debug("GetWorkerGroupMetaInfoResponse response=" + response);
if (response.getWorkerGroupStatus()
== GetWorkerGroupMetaInfoResponse.WorkerGroupStatus.WORKERGROUP_OK) {
// Deserialize data splits meta
SplitClassification splits = null;
if (response.getWorkerGroupMeta().getSplitsCount() > 0) {
splits = ProtobufUtil
.getSplitClassification(response.getWorkerGroupMeta().getSplitsList(),
WorkerContext.get().getConf());
}
// Get workers
WorkerGroup group = new WorkerGroup(WorkerContext.get().getWorkerGroupId(), splits);
for (WorkerMetaInfoProto worker : response.getWorkerGroupMeta().getWorkersList()) {
WorkerRef workerRef = new WorkerRef(worker.getWorkerLocation().getWorkerAttemptId(),
worker.getWorkerLocation().getLocation(), worker.getTasksList());
group.addWorkerRef(workerRef);
}
return group;
} else {
Thread.sleep(WorkerContext.get().getRequestSleepTimeMS());
}
}
}
/**
* Register to master, report the listening port
*
* @return WorkerRegisterResponse worker register response
* @throws ServiceException rpc falied
*/
public WorkerRegisterResponse workerRegister() throws ServiceException {
Location location = WorkerContext.get().getLocation();
WorkerRegisterRequest request = WorkerRegisterRequest.newBuilder()
.setWorkerAttemptId(WorkerContext.get().getWorkerAttemptIdProto()).setLocation(
LocationProto.newBuilder().setIp(location.getIp()).setPort(location.getPort()).build())
.setPsAgentId(WorkerContext.get().getPSAgent().getId()).build();
return master.workerRegister(null, request);
}
/**
* Report worker state to master
*
* @return WorkerReportResponse report response
* @throws ServiceException rpc failed
*/
public WorkerReportResponse workerReport() throws ServiceException {
WorkerReportRequest request =
RequestConverter.buildWorkerReportRequest(WorkerContext.get().getWorker());
return master.workerReport(null, request);
}
/**
* Notify ps agent failed message to master
*
* @param msg ps agent detail failed message
* @throws ServiceException rpc failed
*/
public void psAgentError(String msg) throws ServiceException {
PSAgentErrorRequest request =
PSAgentErrorRequest.newBuilder().setPsAgentId(PSAgentContext.get().getPsAgent().getId())
.setMsg(msg).build();
master.psAgentError(null, request);
}
/**
* Notify ps agent success message to master
*
* @throws ServiceException rpc failed
*/
public void psAgentDone() throws ServiceException {
PSAgentDoneRequest request =
PSAgentDoneRequest.newBuilder().setPsAgentId(PSAgentContext.get().getPsAgent().getId())
.build();
master.psAgentDone(null, request);
}
/**
* Notify worker failed message to master
*
* @param msg worker detail failed message
* @throws ServiceException rpc failed
*/
public void workerError(String msg) throws ServiceException {
WorkerErrorRequest request = WorkerErrorRequest.newBuilder()
.setWorkerAttemptId(WorkerContext.get().getWorkerAttemptIdProto()).setMsg(msg).build();
master.workerError(null, request);
}
/**
* Notify worker success message to master
*
* @throws ServiceException rpc failed
*/
public void workerDone() throws ServiceException {
WorkerDoneRequest request = WorkerDoneRequest.newBuilder()
.setWorkerAttemptId(WorkerContext.get().getWorkerAttemptIdProto()).build();
master.workerDone(null, request);
}
/**
* Task update clock value of a matrix
*
* @param taskIndex task index
* @param matrixId matrix id
* @param clock clock value
* @throws ServiceException
*/
public void updateClock(int taskIndex, int matrixId, int clock) throws ServiceException {
TaskClockRequest request = TaskClockRequest.newBuilder()
.setTaskId(TaskIdProto.newBuilder().setTaskIndex(taskIndex).build())
.setMatrixClock(MatrixClock.newBuilder().setMatrixId(matrixId).setClock(clock).build())
.build();
master.taskClock(null, request);
}
/**
* Task update iteration number
*
* @param taskIndex task index
* @param iteration iteration number
* @throws ServiceException rpc failed
*/
public void taskIteration(int taskIndex, int iteration) throws ServiceException {
TaskIterationRequest request = TaskIterationRequest.newBuilder()
.setTaskId(TaskIdProto.newBuilder().setTaskIndex(taskIndex).build()).setIteration(iteration)
.build();
master.taskIteration(null, request);
}
/**
* Task update iteration number
*
* @param taskIndex task index
* @param counters task counters
* @throws ServiceException rpc failed
*/
public void taskCountersUpdate(Map counters, int taskIndex)
throws ServiceException {
TaskCounterUpdateRequest.Builder builder = TaskCounterUpdateRequest.newBuilder();
builder.setTaskId(TaskIdProto.newBuilder().setTaskIndex(taskIndex).build());
Pair.Builder kvBuilder = Pair.newBuilder();
for (Map.Entry kv : counters.entrySet()) {
builder.addCounters(kvBuilder.setKey(kv.getKey()).setValue(kv.getValue()).build());
}
master.taskCountersUpdate(null, builder.build());
}
/**
* Set Task algorithm metrics
*
* @param taskIndex task index
* @param algoMetrics algorithm metrics
*/
public void setAlgoMetrics(int taskIndex, Map algoMetrics)
throws ServiceException {
SetAlgoMetricsRequest.Builder builder = SetAlgoMetricsRequest.newBuilder();
AlgoMetric.Builder metricBuilder = AlgoMetric.newBuilder();
builder.setTaskId(TaskIdProto.newBuilder().setTaskIndex(taskIndex).build());
for (Map.Entry metricEntry : algoMetrics.entrySet()) {
builder.addAlgoMetrics(metricBuilder.setName(metricEntry.getKey()).setSerializedMetric(
ByteString.copyFrom(KryoUtils.serializeAlgoMetric(metricEntry.getValue()))).build());
}
master.setAlgoMetrics(null, builder.build());
}
/**
* Get the pss that stored the partition
*
* @param matrixId matrix id
* @param partitionId partition id
* @return the pss that stored the partition
* @throws ServiceException
*/
public List getStoredPss(int matrixId, int partitionId)
throws ServiceException {
List psIdProtos = master.getStoredPss(null,
GetStoredPssRequest.newBuilder().setMatrixId(matrixId).setPartId(partitionId).build())
.getPsIdsList();
int size = psIdProtos.size();
List psIds = new ArrayList<>(psIdProtos.size());
for (int i = 0; i < size; i++) {
psIds.add(ProtobufUtil.convertToId(psIdProtos.get(i)));
}
return psIds;
}
/**
* Get the pss and their locations that stored the partition
*
* @param matrixId matrix id
* @param partId partition id
* @return the pss and their locations that stored the partition
* @throws ServiceException
*/
public PartitionLocation getPartLocation(int matrixId, int partId) throws ServiceException {
GetPartLocationResponse response = master.getPartLocation(null,
GetPartLocationRequest.newBuilder().setMatrixId(matrixId).setPartId(partId).build());
List psLocsProto = response.getLocationsList();
int size = psLocsProto.size();
List psLocs = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
psLocs.add(new PSLocation(ProtobufUtil.convertToId(psLocsProto.get(i).getPsId()),
ProtobufUtil.convertToLocation(psLocsProto.get(i))));
}
return new PartitionLocation(psLocs);
}
/**
* Report a ps failed to Master
*
* @param psLoc ps id and location
* @throws ServiceException
*/
public void psFailed(PSLocation psLoc) throws ServiceException {
master.psFailedReport(null,
PSFailedReportRequest.newBuilder().setClientId(PSAgentContext.get().getPSAgentId())
.setPsLoc(ProtobufUtil.convert(psLoc)).build());
}
/**
* Get the number of success worker group
*
* @return the number of success worker group
* @throws ServiceException
*/
public int getSuccessWorkerGroupNum() throws ServiceException {
return master.getWorkerGroupSuccessNum(null,
GetWorkerGroupSuccessNumRequest.getDefaultInstance().newBuilder().build()).getSuccessNum();
}
/**
* Get a psagent id
*
* @return psagent id
* @throws ServiceException
*/
public int getPSAgentId() throws ServiceException {
return master.getPSAgentId(null, GetPSAgentIdRequest.getDefaultInstance()).getPsAgentId();
}
/**
* Check PS exist or not
*
* @param psLoc ps id and location
* @return true means ps exited
*/
public boolean isPSExited(PSLocation psLoc) throws ServiceException {
return master.checkPSExited(null,
CheckPSExitRequest.newBuilder().setClientId(PSAgentContext.get().getPSAgentId())
.setPsLoc(ProtobufUtil.convert(psLoc)).build()).getExited() == 1;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy