com.tencent.angel.psagent.PSAgent Maven / Gradle / Ivy
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.psagent;
import com.google.protobuf.ServiceException;
import com.tencent.angel.RunningMode;
import com.tencent.angel.common.location.Location;
import com.tencent.angel.common.location.LocationManager;
import com.tencent.angel.conf.AngelConf;
import com.tencent.angel.exception.AngelException;
import com.tencent.angel.exception.InvalidParameterException;
import com.tencent.angel.ipc.TConnection;
import com.tencent.angel.ipc.TConnectionManager;
import com.tencent.angel.ml.matrix.MatrixContext;
import com.tencent.angel.ml.matrix.MatrixMeta;
import com.tencent.angel.ml.matrix.MatrixMetaManager;
import com.tencent.angel.protobuf.generated.PSAgentMasterServiceProtos.PSAgentCommandProto;
import com.tencent.angel.protobuf.generated.PSAgentMasterServiceProtos.PSAgentRegisterResponse;
import com.tencent.angel.protobuf.generated.PSAgentMasterServiceProtos.PSAgentReportResponse;
import com.tencent.angel.ps.ParameterServerId;
import com.tencent.angel.psagent.client.MasterClient;
import com.tencent.angel.psagent.client.PSControlClientManager;
import com.tencent.angel.psagent.clock.ClockCache;
import com.tencent.angel.psagent.consistency.ConsistencyController;
import com.tencent.angel.psagent.executor.Executor;
import com.tencent.angel.psagent.matrix.MatrixClient;
import com.tencent.angel.psagent.matrix.MatrixClientFactory;
import com.tencent.angel.psagent.matrix.PSAgentLocationManager;
import com.tencent.angel.psagent.matrix.PSAgentMatrixMetaManager;
import com.tencent.angel.psagent.matrix.cache.MatricesCache;
import com.tencent.angel.psagent.matrix.oplog.cache.MatrixOpLogCache;
import com.tencent.angel.psagent.matrix.storage.MatrixStorageManager;
import com.tencent.angel.psagent.matrix.transport.MatrixTransportClient;
import com.tencent.angel.psagent.matrix.transport.adapter.UserRequestAdapter;
import com.tencent.angel.utils.NetUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* Agent for master and parameter servers, It fetches matrix meta from master, fetches matrices from
* parameter servers and pushes updates to parameter servers.
*
* It use {@link MasterClient} to communicate with master. The meta data contains:
* matrix meta {@link MatrixMetaManager},
* server locations {@link LocationManager}.
*
*
* It use {@link MatrixTransportClient} to communicate with parameter servers. Because an
* application layer request generally needs to access multiple parameter servers, so we use
* {@link UserRequestAdapter} first to divide the application layer request to many underlying
* network requests in matrix partition.
*/
public class PSAgent {
private static final Log LOG = LogFactory.getLog(PSAgent.class);
/**
* application configuration
*/
private final Configuration conf;
/**
* application id
*/
private final ApplicationId appId;
/**
* the user that submit the application
*/
private final String user;
/**
* ps agent attempt id
*/
private volatile int id;
/**
* the connection manager for rpc to master
*/
private final TConnection connection;
/**
* the rpc client to master
*/
private volatile MasterClient masterClient;
private volatile PSControlClientManager psControlClientManager;
/**
* psagent location(ip and listening port)
*/
private volatile Location location;
/**
* psagent location(ip and listening port)
*/
private final Location masterLocation;
/**
* master location(ip and listening port)
*/
private volatile PSAgentLocationManager locationManager;
/**
* matrix meta manager
*/
private volatile PSAgentMatrixMetaManager matrixMetaManager;
/**
* matrix updates cache
*/
private volatile MatrixOpLogCache opLogCache;
/**
* the rpc client to parameter servers
*/
private volatile MatrixTransportClient matrixTransClient;
/**
* psagent initialization completion flag
*/
private final AtomicBoolean psAgentInitFinishedFlag;
/**
* psagent heartbeat thread
*/
private volatile Thread heartbeatThread;
/**
* psagent stop flag
*/
private final AtomicBoolean stopped;
/**
* psagent heartbeat interval in milliseconds
*/
private final int heartbeatIntervalMs;
/**
* psagent metrics
*/
private final Map metrics;
/**
* SSP consistency controller
*/
private volatile ConsistencyController consistencyController;
/**
* matrix partitions clock cache
*/
private volatile ClockCache clockCache;
/**
* matrices cache
*/
private volatile MatricesCache matricesCache;
/**
* matrix storage manager
*/
private volatile MatrixStorageManager matrixStorageManager;
/**
* application layer request adapter
*/
private volatile UserRequestAdapter userRequestAdapter;
/**
* if we need startup heartbeat thread
*/
private final boolean needHeartBeat;
/**
* application running mode(LOCAL,YARN)
*/
private final RunningMode runningMode;
/**
* the machine learning executor reference
*/
private final Executor executor;
/**
* psagent exited flag
*/
private final AtomicBoolean exitedFlag;
/**
* Control connection manager
*/
private volatile TConnection controlConnectManager;
/**
* Create a new PSAgent instance.
*
* @param masterIp master ip
* @param masterPort master listening port
* @param clientIndex psagent index
*/
public PSAgent(String masterIp, int masterPort, int clientIndex) {
this(new Configuration(), masterIp, masterPort, clientIndex, true, null);
}
/**
* Create a new PSAgent instance.
*
* @param conf application configuration
* @param masterIp master listening port
* @param masterPort master listening port
* @param clientIndex psagent index
*/
public PSAgent(Configuration conf, String masterIp, int masterPort, int clientIndex) {
this(conf, masterIp, masterPort, clientIndex, true, null);
}
/**
* Create a new PSAgent instance.
*
* @param conf application configuration
* @param appId application id
* @param user the user that submit this application
* @param masterIp master ip
* @param masterPort master port
* @param needHeartBeat true means need startup heartbeat thread
* @param executor the machine learning executor reference
*/
public PSAgent(Configuration conf, ApplicationId appId, String user, String masterIp,
int masterPort, boolean needHeartBeat, Executor executor) {
this.needHeartBeat = needHeartBeat;
this.conf = conf;
this.executor = executor;
this.heartbeatIntervalMs = conf.getInt(AngelConf.ANGEL_WORKER_HEARTBEAT_INTERVAL_MS,
AngelConf.DEFAULT_ANGEL_WORKER_HEARTBEAT_INTERVAL);
this.runningMode = initRunningMode(conf);
this.appId = appId;
this.user = user;
this.masterLocation = new Location(masterIp, masterPort);
this.connection = TConnectionManager.getConnection(conf);
this.psAgentInitFinishedFlag = new AtomicBoolean(false);
this.stopped = new AtomicBoolean(false);
this.exitedFlag = new AtomicBoolean(false);
this.metrics = new HashMap<>();
PSAgentContext.get().setPsAgent(this);
}
/**
* Create a new PSAgent instance.
*
* @param conf application configuration
* @param masterIp master ip
* @param masterPort master port
* @param clientIndex ps agent index
* @param needHeartBeat true means need startup heartbeat thread
* @param executor the machine learning executor reference
*/
public PSAgent(Configuration conf, String masterIp, int masterPort, int clientIndex,
boolean needHeartBeat, Executor executor) {
this.needHeartBeat = needHeartBeat;
this.conf = conf;
this.executor = executor;
this.heartbeatIntervalMs = conf.getInt(AngelConf.ANGEL_WORKER_HEARTBEAT_INTERVAL_MS,
AngelConf.DEFAULT_ANGEL_WORKER_HEARTBEAT_INTERVAL);
this.runningMode = initRunningMode(conf);
this.masterLocation = new Location(masterIp, masterPort);
this.appId = null;
this.user = null;
this.connection = TConnectionManager.getConnection(conf);
this.psAgentInitFinishedFlag = new AtomicBoolean(false);
this.stopped = new AtomicBoolean(false);
this.exitedFlag = new AtomicBoolean(false);
this.metrics = new HashMap<>();
PSAgentContext.get().setPsAgent(this);
}
private RunningMode initRunningMode(Configuration conf) {
String mode = conf.get(AngelConf.ANGEL_RUNNING_MODE, AngelConf.DEFAULT_ANGEL_RUNNING_MODE);
if (mode.equals(RunningMode.ANGEL_PS.toString())) {
return RunningMode.ANGEL_PS;
} else {
return RunningMode.ANGEL_PS_WORKER;
}
}
public void initAndStart() throws Exception {
// Init control connection manager
controlConnectManager = TConnectionManager.getConnection(conf);
// Get ps locations from master and put them to the location cache.
locationManager = new PSAgentLocationManager(PSAgentContext.get());
locationManager.setMasterLocation(masterLocation);
// Build and initialize rpc client to master
masterClient = new MasterClient();
masterClient.init();
// Get psagent id
id = masterClient.getPSAgentId();
// Build PS control rpc client manager
psControlClientManager = new PSControlClientManager();
// Build local location
String localIp = NetUtils.getRealLocalIP();
int port = NetUtils.chooseAListenPort(conf);
location = new Location(localIp, port);
register();
// Initialize matrix meta information
clockCache = new ClockCache();
List matrixMetas = masterClient.getMatrices();
LOG.info("PSAgent get matrices from master," + matrixMetas.size());
this.matrixMetaManager = new PSAgentMatrixMetaManager(clockCache);
matrixMetaManager.addMatrices(matrixMetas);
Map psIdToLocMap = masterClient.getPSLocations();
List psIds = new ArrayList<>(psIdToLocMap.keySet());
Collections.sort(psIds, new Comparator() {
@Override public int compare(ParameterServerId s1, ParameterServerId s2) {
return s1.getIndex() - s2.getIndex();
}
});
int size = psIds.size();
locationManager.setPsIds(psIds.toArray(new ParameterServerId[0]));
for (int i = 0; i < size; i++) {
if (psIdToLocMap.containsKey(psIds.get(i))) {
locationManager.setPsLocation(psIds.get(i), psIdToLocMap.get(psIds.get(i)));
}
}
matrixTransClient = new MatrixTransportClient();
userRequestAdapter = new UserRequestAdapter();
matricesCache = new MatricesCache();
if (runningMode == RunningMode.ANGEL_PS_WORKER) {
opLogCache = new MatrixOpLogCache();
matrixStorageManager = new MatrixStorageManager();
int staleness = conf.getInt(AngelConf.ANGEL_STALENESS, AngelConf.DEFAULT_ANGEL_STALENESS);
consistencyController = new ConsistencyController(staleness);
consistencyController.init();
}
psAgentInitFinishedFlag.set(true);
// Start all services
matrixTransClient.start();
userRequestAdapter.start();
matricesCache.start();
if (runningMode == RunningMode.ANGEL_PS_WORKER) {
clockCache.start();
opLogCache.start();
}
}
/**
* Get matrix meta from master and refresh the local cache.
*
* @throws ServiceException rpc failed
* @throws InterruptedException interrupted while wait for rpc results
*/
public void refreshMatrixInfo()
throws InterruptedException, ServiceException, ClassNotFoundException {
matrixMetaManager.addMatrices(masterClient.getMatrices());
}
/**
* Stop ps agent
*/
public void stop() {
if (!stopped.getAndSet(true)) {
LOG.info("stop heartbeat thread!");
if (heartbeatThread != null) {
heartbeatThread.interrupt();
try {
heartbeatThread.join();
} catch (InterruptedException ie) {
LOG.warn("InterruptedException while stopping heartbeatThread:", ie);
}
heartbeatThread = null;
}
LOG.info("stop op log merger");
if (opLogCache != null) {
opLogCache.stop();
opLogCache = null;
}
LOG.info("stop clock cache");
if (clockCache != null) {
clockCache.stop();
clockCache = null;
}
LOG.info("stop matrix cache");
if (matricesCache != null) {
matricesCache.stop();
matricesCache = null;
}
LOG.info("stop user request adapater");
if (userRequestAdapter != null) {
userRequestAdapter.stop();
userRequestAdapter = null;
}
LOG.info("stop rpc dispacher");
if (matrixTransClient != null) {
matrixTransClient.stop();
matrixTransClient = null;
}
if (PSAgentContext.get() != null) {
PSAgentContext.get().clear();
}
}
}
protected void heartbeat() throws ServiceException {
PSAgentReportResponse response = masterClient.psAgentReport();
switch (response.getCommand()) {
case PSAGENT_NEED_REGISTER:
try {
register();
} catch (Exception x) {
LOG.error("register failed: ", x);
stop();
}
break;
case PSAGENT_SHUTDOWN:
LOG.error("shutdown command come from appmaster, exit now!!");
stop();
break;
default:
break;
}
}
private void register() throws ServiceException {
PSAgentRegisterResponse response = masterClient.psAgentRegister();
if (response.getCommand() == PSAgentCommandProto.PSAGENT_SHUTDOWN) {
LOG.fatal("register to master, receive shutdown command");
stop();
}
}
/**
* Get application configuration
*
* @return Configuration application configuration
*/
public Configuration getConf() {
return conf;
}
/**
* Get master location
*
* @return Location master location
*/
public Location getMasterLocation() {
return locationManager.getMasterLocation();
}
/**
* Get application id
*
* @return ApplicationId application id
*/
public ApplicationId getAppId() {
return appId;
}
/**
* Get the user that submits the application
*
* @return String the user that submits the application
*/
public String getUser() {
return user;
}
/**
* Get the connection manager for rpc to master
*
* @return TConnection the connection manager for rpc to master
*/
public TConnection getConnection() {
return connection;
}
/**
* get the rpc client to master
*
* @return MasterClient the rpc client to master
*/
public MasterClient getMasterClient() {
return masterClient;
}
/**
* Get ps agent location ip
*
* @return String ps agent location ip
*/
public String getIp() {
return location.getIp();
}
/**
* Notify run success message to master
*/
public void done() {
if (!exitedFlag.getAndSet(true)) {
LOG.info("psagent success done");
RunningMode mode = PSAgentContext.get().getRunningMode();
// Stop all modules
if (executor != null) {
executor.done();
} else {
stop();
}
}
}
/**
* Notify run failed message to master
*
* @param errorMsg detail failed message
*/
public void error(String errorMsg) {
if (!exitedFlag.getAndSet(true)) {
LOG.info("psagent falied");
// Stop all modules
if (executor != null) {
executor.error(errorMsg);
} else {
stop();
}
}
}
/**
* Get ps agent location
*
* @return Location ps agent location
*/
public Location getLocation() {
return location;
}
/**
* Get ps location cache
*
* @return LocationCache ps location cache
*/
public PSAgentLocationManager getLocationManager() {
return locationManager;
}
/**
* Get matrix meta manager
*
* @return MatrixMetaManager matrix meta manager
*/
public PSAgentMatrixMetaManager getMatrixMetaManager() {
return matrixMetaManager;
}
/**
* Get matrix update cache
*
* @return MatrixOpLogCache matrix update cache
*/
public MatrixOpLogCache getOpLogCache() {
return opLogCache;
}
/**
* Get rpc client to ps
*
* @return MatrixTransportClient rpc client to ps
*/
public MatrixTransportClient getMatrixTransportClient() {
return matrixTransClient;
}
/**
* Get matrix client for rpc
*
* @param matrixId matrix id
* @return matrix client
*/
public MatrixClient getMatrixClient(int matrixId) throws AngelException {
return getMatrixClient(matrixId, -1);
}
/**
* Get matrix client for rpc
*
* @param matrixId matrix id
* @param taskIndex task id
* @return matrix client
*/
public MatrixClient getMatrixClient(int matrixId, int taskIndex) throws AngelException {
try {
return MatrixClientFactory.get(matrixId, taskIndex);
} catch (Throwable e) {
throw new AngelException(e);
}
}
/**
* Get matrix client for rpc
*
* @param matrixName matrix name
* @return matrix client
*/
public MatrixClient getMatrixClient(String matrixName) throws AngelException {
return getMatrixClient(matrixName, -1);
}
/**
* Get matrix client for rpc
*
* @param matrixName matrix name
* @param taskIndex task id
* @return matrix client
*/
public MatrixClient getMatrixClient(String matrixName, int taskIndex) throws AngelException {
try {
return MatrixClientFactory.get(matrixName, taskIndex);
} catch (Throwable e) {
throw new AngelException(e);
}
}
/**
* Get heartbeat interval in milliseconds
*
* @return int heartbeat interval in milliseconds
*/
public int getHeartbeatIntervalMs() {
return heartbeatIntervalMs;
}
/**
* Get ps agent metrics
*
* @return Map ps agent metrics
*/
public Map getMetrics() {
return metrics;
}
/**
* Get SSP consistency controller
*
* @return ConsistencyController SSP consistency controller
*/
public ConsistencyController getConsistencyController() {
return consistencyController;
}
/**
* Get matrix clock cache
*
* @return ClockCache matrix clock cache
*/
public ClockCache getClockCache() {
return clockCache;
}
/**
* Get matrix cache
*
* @return ClockCache matrix cache
*/
public MatricesCache getMatricesCache() {
return matricesCache;
}
/**
* Create a new matrix
*
* @param matrixContext matrix configuration
* @param timeOutMs maximun wait time in milliseconds
* @return MatrixMeta matrix meta
* @throws AngelException exception come from master
*/
public void createMatrix(MatrixContext matrixContext, long timeOutMs) throws AngelException {
try {
PSAgentContext.get().getMasterClient().createMatrix(matrixContext, timeOutMs);
} catch (Throwable x) {
throw new AngelException(x);
}
}
/**
* Get Matrix meta
*
* @param matrixName matrix name
* @return
*/
public MatrixMeta getMatrix(String matrixName) {
try {
return masterClient.getMatrix(matrixName);
} catch (Throwable e) {
throw new AngelException(e);
}
}
private void removeLocalMatrix(String matrixName) {
int matrixId = matrixMetaManager.getMatrixId(matrixName);
matrixMetaManager.removeMatrix(matrixId);
removeCacheData(matrixId);
}
private void removeCacheData(int matrixId) {
matricesCache.remove(matrixId);
if (runningMode == RunningMode.ANGEL_PS_WORKER) {
opLogCache.remove(matrixId);
matrixStorageManager.removeMatrix(matrixId);
}
}
/**
* Release a batch of matrices
*
* @param matrixNames matrix names
* @throws AngelException exception come from master
*/
public void releaseMatricesUseName(List matrixNames) throws AngelException {
int size = matrixNames.size();
for (int i = 0; i < size; i++) {
removeLocalMatrix(matrixNames.get(i));
}
try {
masterClient.releaseMatrices(matrixNames);
} catch (Throwable x) {
throw new AngelException(x);
}
}
/**
* Release a matrix
*
* @param matrixName matrix name
* @throws AngelException exception come from master
*/
public void releaseMatrix(String matrixName) throws AngelException {
try {
removeLocalMatrix(matrixName);
masterClient.releaseMatrix(matrixName);
} catch (Throwable x) {
throw new AngelException(x);
}
}
/**
* Release a batch of matrices
*
* @param matrixIds matrix ids
* @throws AngelException exception come from master
*/
public void releaseMatrices(List matrixIds) throws AngelException {
int size = matrixIds.size();
List matrixNames = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
MatrixMeta meta = matrixMetaManager.getMatrixMeta(matrixIds.get(i));
if (meta == null) {
continue;
}
matrixNames.add(meta.getName());
}
releaseMatricesUseName(matrixNames);
}
/**
* Release a matrix
*
* @param matrixId matrix id
* @throws AngelException exception come from master
*/
public void releaseMatrix(int matrixId) throws AngelException {
MatrixMeta meta = matrixMetaManager.getMatrixMeta(matrixId);
if (meta == null) {
return;
}
releaseMatrix(meta.getName());
}
/**
* Create a batch of matrices
*
* @param matrixContexts matrices configuration
* @param timeOutMs maximun wait time in milliseconds
* @throws AngelException exception come from master
*/
public void createMatrices(List matrixContexts, long timeOutMs)
throws AngelException {
try {
masterClient.createMatrices(matrixContexts, timeOutMs);
} catch (Throwable x) {
throw new AngelException(x);
}
}
public List getMatrices(List matrixNames) {
try {
return masterClient.getMatrices(matrixNames);
} catch (Throwable x) {
throw new AngelException(x);
}
}
/**
* Get matrix storage manager
*
* @return MatrixStorageManager matrix storage manager
*/
public MatrixStorageManager getMatrixStorageManager() {
return matrixStorageManager;
}
/**
* Get application layer request adapter
*
* @return UserRequestAdapter application layer request adapter
*/
public UserRequestAdapter getUserRequestAdapter() {
return userRequestAdapter;
}
/**
* Get application running mode
*
* @return RunningMode application running mode
*/
public RunningMode getRunningMode() {
return runningMode;
}
/**
* Get the machine learning executor reference
*
* @return Executor the machine learning executor reference
*/
public Executor getExecutor() {
return executor;
}
/**
* Get PSAgent ID
*
* @return PSAgent ID
*/
public int getId() {
return id;
}
/**
* Get control connection manager
*
* @return control connection manager
*/
public TConnection getControlConnectManager() {
return controlConnectManager;
}
/**
* Get PS control rpc client manager
*
* @return PS control rpc client manager
*/
public PSControlClientManager getPsControlClientManager() {
return psControlClientManager;
}
}