com.tencent.angel.master.ps.ParameterServerManager Maven / Gradle / Ivy
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.master.ps;
import com.tencent.angel.common.location.Location;
import com.tencent.angel.conf.AngelConf;
import com.tencent.angel.master.app.AMContext;
import com.tencent.angel.master.app.AppEvent;
import com.tencent.angel.master.app.AppEventType;
import com.tencent.angel.master.app.InternalErrorEvent;
import com.tencent.angel.master.ps.attempt.PSAttemptDiagnosticsUpdateEvent;
import com.tencent.angel.master.ps.attempt.PSAttemptEvent;
import com.tencent.angel.master.ps.attempt.PSAttemptEventType;
import com.tencent.angel.master.ps.ps.AMParameterServer;
import com.tencent.angel.master.ps.ps.AMParameterServerEvent;
import com.tencent.angel.master.ps.ps.AMParameterServerEventType;
import com.tencent.angel.ps.PSAttemptId;
import com.tencent.angel.ps.ParameterServerId;
import com.tencent.angel.ps.server.data.PSLocation;
import com.tencent.angel.utils.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.service.AbstractService;
import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.hadoop.yarn.factory.providers.RecordFactoryProvider;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* Parameter server manager, it managers a group of
* {@link com.tencent.angel.master.ps.ps.AMParameterServer}. It is responsible for starting all the
* parameter servers. If all the workers have completed the model training, it will inform the
* parameter servers to write models to the corresponding file, and finally merge them.
*/
public class ParameterServerManager extends AbstractService
implements EventHandler {
private static final Log LOG = LogFactory.getLog(ParameterServerManager.class);
private final AMContext context;
/**
* parameter server number
*/
private final int psNumber;
/**
* If we need to suggest physical machines which the parameter servers will running on, we can
* specify the physical machine ip list
*/
private final String[] ips;
/**
* the amount of resources requested for each parameter server
*/
private final Resource psResource;
/**
* the resource priority for parameter servers
*/
private final Priority priority;
/**
* parameter server id to parameter server management unit map
*/
private final Map psMap;
/**
* The parameter server collection that has completed the commit operation
*/
private final Set committedPs;
/**
* whether you can start commit operation
*/
private final AtomicBoolean canCommit;
/**
* need commit matrices
*/
private volatile List needCommitMatrixIds;
/**
* parameter server id to attempt index map, it use to master recover
*/
private final Map psIdToAttemptIndexMap;
private final AMPSFailedReport report;
private final AtomicBoolean stopped = new AtomicBoolean(false);
/**
* parameter server attempt id to last heartbeat timestamp map
*/
private final ConcurrentHashMap psLastHeartbeatTS = new ConcurrentHashMap<>();
/**
* parameter server heartbeat timeout value in millisecond
*/
private final long psTimeOutMS;
public ParameterServerManager(AMContext context,
Map psIdToAttemptIndexMap) {
super("PS Manager");
this.context = context;
this.psIdToAttemptIndexMap = psIdToAttemptIndexMap;
Configuration conf = context.getConf();
String ipListStr = conf.get(AngelConf.ANGEL_PS_IP_LIST);
if (ipListStr != null) {
ips = ipListStr.split(",");
psNumber = ips.length;
} else {
ips = null;
psNumber = conf.getInt(AngelConf.ANGEL_PS_NUMBER, AngelConf.DEFAULT_ANGEL_PS_NUMBER);
}
int psServerMemory =
conf.getInt(AngelConf.ANGEL_PS_MEMORY_GB, AngelConf.DEFAULT_ANGEL_PS_MEMORY_GB) * 1024;
int psServerVcores =
conf.getInt(AngelConf.ANGEL_PS_CPU_VCORES, AngelConf.DEFAULT_ANGEL_PS_CPU_VCORES);
int psPriority = conf.getInt(AngelConf.ANGEL_PS_PRIORITY, AngelConf.DEFAULT_ANGEL_PS_PRIORITY);
psTimeOutMS = conf.getLong(AngelConf.ANGEL_PS_HEARTBEAT_TIMEOUT_MS,
AngelConf.DEFAULT_ANGEL_PS_HEARTBEAT_TIMEOUT_MS);
psResource = Resource.newInstance(psServerMemory, psServerVcores);
priority = RecordFactoryProvider.getRecordFactory(null).newRecordInstance(Priority.class);
priority.setPriority(psPriority);
psMap = new HashMap<>();
committedPs = new HashSet<>();
canCommit = new AtomicBoolean(false);
report = new AMPSFailedReport();
}
@Override protected void serviceStart() throws Exception {
}
@Override protected void serviceStop() throws Exception {
if (stopped.getAndSet(true)) {
return;
}
super.serviceStop();
LOG.info("ParameterServerManager stopped");
}
/**
* Init all PS
*/
public void init() {
for (int i = 0; i < psNumber; i++) {
ParameterServerId id = new ParameterServerId(i);
AMParameterServer server = null;
if (ips != null) {
server = new AMParameterServer(ips[i], id, context);
} else {
server = new AMParameterServer(id, context);
}
if (psIdToAttemptIndexMap != null && psIdToAttemptIndexMap.containsKey(id)) {
server.setNextAttemptNumber(psIdToAttemptIndexMap.get(id));
}
psMap.put(id, server);
}
}
/**
* Start all PS
*/
public void startAllPS() {
for (Map.Entry entry : psMap.entrySet()) {
entry.getValue()
.handle(new AMParameterServerEvent(AMParameterServerEventType.PS_SCHEDULE, entry.getKey()));
}
}
/**
* get parameter servers map
*
* @return Map parameter servers map
*/
public Map getParameterServerMap() {
return psMap;
}
@Override public void handle(ParameterServerManagerEvent event) {
LOG.debug("Processing event type " + event.getType());
switch (event.getType()) {
case COMMIT: {
LOG.info("set canCommit to true.");
canCommit.set(true);
needCommitMatrixIds = ((CommitEvent) event).getNeedCommitMatrixIds();
break;
}
case PARAMETERSERVER_DONE: {
commitSuccess(event);
break;
}
case PARAMETERSERVER_FAILED: {
psFailed(event);
break;
}
case PARAMETERSERVER_KILLED: {
psKilled(event);
break;
}
default:
break;
}
}
/**
* Check whether we can start the commit operation
*
* @return boolean true indicates that a commit operation can be performed
*/
public boolean psCanCommit() {
return canCommit.get();
}
/**
* get parameter server manager unit use id
*
* @param id parameter server id
* @return AMParameterServer parameter server manager unit
*/
public AMParameterServer getParameterServer(ParameterServerId id) {
return psMap.get(id);
}
/**
* get parameter server number
*
* @return int parameter server number
*/
public int getPsNumber() {
return psNumber;
}
/**
* get parameter server resource allocation
*
* @return Resource parameter server resource allocation
*/
public Resource getPsResource() {
return psResource;
}
/**
* get parameter server resource priority
*
* @return Priority parameter server resource priority
*/
public Priority getPriority() {
return priority;
}
@SuppressWarnings("unchecked") private void psKilled(ParameterServerManagerEvent event) {
context.getEventHandler().handle(new AppEvent(context.getApplicationId(), AppEventType.KILL));
}
@SuppressWarnings("unchecked") private void psFailed(ParameterServerManagerEvent event) {
List diagnostics =
context.getParameterServerManager().getParameterServer(event.getPsId()).getDiagnostics();
StringBuilder sb = new StringBuilder();
sb.append(StringUtils.join("\n", diagnostics));
context.getEventHandler()
.handle(new InternalErrorEvent(context.getApplicationId(), sb.toString()));
}
private void commitSuccess(ParameterServerManagerEvent event) {
committedPs.add(event.getPsId());
//if all parameter server complete commit, master can commit now
if (committedPs.size() == psNumber) {
commit();
}
}
private void commit() {
}
/**
* Get the matrices that need commit.
*
* @return List matrices that need commit.
*/
public List getNeedCommitMatrixIds() {
return needCommitMatrixIds;
}
/**
* Update ps failed counters
*
* @param counters ps failed counters
*/
public void psFailedReports(Map counters) {
report.psFailedReports(counters);
}
public void psFailedReport(PSLocation psLoc) {
//restartPS(psLoc);
}
private void restartPS(PSLocation psLoc) {
getParameterServer(psLoc.psId).restart(psLoc);
}
public void checkHBTimeOut() {
//check whether parameter server heartbeat timeout
Iterator> psIt = psLastHeartbeatTS.entrySet().iterator();
Map.Entry psEntry;
long currentTs = System.currentTimeMillis();
while (psIt.hasNext()) {
psEntry = psIt.next();
if (currentTs - psEntry.getValue() > psTimeOutMS) {
LOG.error(psEntry.getKey() + " heartbeat timeout!!!");
context.getEventHandler()
.handle(new PSAttemptDiagnosticsUpdateEvent("heartbeat timeout", psEntry.getKey()));
context.getEventHandler()
.handle(new PSAttemptEvent(PSAttemptEventType.PA_FAILMSG, psEntry.getKey()));
psIt.remove();
}
}
}
/**
* PS attempt register
*
* @param psAttemptId PS attempt id
*/
public void register(PSAttemptId psAttemptId) {
LOG.info("PS " + psAttemptId + " is registered in monitor!");
psLastHeartbeatTS.put(psAttemptId, System.currentTimeMillis());
}
/**
* PS attempt unregister
*
* @param psAttemptId PS attempt id
*/
public void unRegister(PSAttemptId psAttemptId) {
LOG.info("PS " + psAttemptId + " is finished, delete it in monitor!");
psLastHeartbeatTS.remove(psAttemptId);
}
/**
* Is PS attempt alive
*
* @param psAttemptId PS attempt id
* @return true mean alive
*/
public boolean isAlive(PSAttemptId psAttemptId) {
return psLastHeartbeatTS.containsKey(psAttemptId);
}
/**
* Update PS attempt latest heartbeat timestamp
*
* @param psAttemptId PS attempt id
*/
public void alive(PSAttemptId psAttemptId) {
psLastHeartbeatTS.put(psAttemptId, System.currentTimeMillis());
}
/**
* Is PS attempt failed
*
* @param psLoc PS id and PS location
* @return true means failed
*/
public boolean checkFailed(PSLocation psLoc) {
Location loc = context.getLocationManager().getPsLocation(psLoc.psId);
if (loc == null || !loc.equals(psLoc.loc)) {
return true;
} else {
return false;
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy