// Source listing: com.tencent.angel.psagent.consistency.ConsistencyController (Maven / Gradle / Ivy)
/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */
package com.tencent.angel.psagent.consistency;
import com.tencent.angel.conf.AngelConf;
import com.tencent.angel.conf.MatrixConf;
import com.tencent.angel.psagent.PSAgentContext;
import com.tencent.angel.psagent.clock.ClockCache;
import com.tencent.angel.psagent.matrix.ResponseType;
import com.tencent.angel.psagent.matrix.transport.adapter.GetRowsResult;
import com.tencent.angel.psagent.matrix.transport.adapter.RowIndex;
import com.tencent.angel.psagent.task.TaskContext;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
 * Angel task consistency controller. Angel supports three consistency protocols, BSP, SSP and
 * ASYNC, selected by the staleness value: {@code staleness > 0} means SSP, {@code staleness = 0}
 * means BSP, and {@code staleness < 0} means ASYNC.
 */
public class ConsistencyController {
/** Logger for this class. */
private static final Log LOG = LogFactory.getLog(ConsistencyController.class);
/** Global staleness value: {@code > 0} means SSP, {@code 0} means BSP, {@code < 0} means ASYNC. */
private final int globalStaleness; // fallback for matrices without a MATRIX_STALENESS attribute
/**
 * Create a new ConsistencyController.
 *
 * @param staleness global staleness value: {@code > 0} selects SSP, {@code 0} selects BSP and
 *     {@code < 0} selects ASYNC. Individual matrices may override it through the
 *     {@code MatrixConf.MATRIX_STALENESS} attribute.
 */
public ConsistencyController(int staleness) {
  this.globalStaleness = staleness;
}
* Init.
public void init() {
public Vector getRow(TaskContext taskContext, int matrixId, int rowIndex)
throws ExecutionException, InterruptedException, IOException {
int staleness = getStaleness(matrixId);
if (staleness >= 0) {
// Use simple flow, do not use any cache
if (staleness == 0 && PSAgentContext.get().getLocalTaskNum() == 1) {
waitForClock(matrixId, rowIndex, taskContext.getMatrixClock(matrixId));
return ((GetRowResult) PSAgentContext.get().getUserRequestAdapter()
.get(new GetRow(new GetRowParam(matrixId, rowIndex)))).getRow();
// Get row from cache.
Vector row = PSAgentContext.get().getMatrixStorageManager().getRow(matrixId, rowIndex);
// if row clock is satisfy ssp staleness limit, just return.
if (row != null && (taskContext.getPSMatrixClock(matrixId) <= row.getClock()) && (
taskContext.getMatrixClock(matrixId) - row.getClock() <= staleness)) {
LOG.debug("task " + taskContext.getIndex() + " matrix " + matrixId + " clock " + taskContext
.getMatrixClock(matrixId) + ", row clock " + row.getClock() + ", staleness " + staleness
+ ", just get from global storage");
return cloneRow(matrixId, rowIndex, row);
// Get row from ps.
// Wait until the clock value of this row is greater than or equal to the value
int stalenessClock = taskContext.getMatrixClock(matrixId) - staleness;
waitForClock(matrixId, rowIndex, stalenessClock);
row = PSAgentContext.get().getUserRequestAdapter().getRow(matrixId, rowIndex, stalenessClock);
PSAgentContext.get().getMatrixStorageManager().addRow(matrixId, rowIndex, row);
return cloneRow(matrixId, rowIndex, row);
} else {
// For ASYNC mode, just get from pss.
GetRow func = new GetRow(new GetRowParam(matrixId, rowIndex));
GetRowResult result = ((GetRowResult) PSAgentContext.get().getUserRequestAdapter().get(func));
if (result.getResponseType() == ResponseType.FAILED) {
throw new IOException("get row from ps failed.");
} else {
return result.getRow();
/**
 * Get a batch of rows from the local storage/cache or from the parameter servers.
 *
 * @param taskContext task context
 * @param rowIndex row indexes
 * @param rpcBatchSize number of rows fetched in one RPC request
 * @return GetRowsResult rows
 * @throws Exception if fetching the rows from the parameter servers fails
 */
public GetRowsResult getRowsFlow(TaskContext taskContext, RowIndex rowIndex, int rpcBatchSize)
throws Exception {
// NOTE(review): this copy of the method has lost its closing braces and several statements
// (marked below); reconcile with upstream before compiling.
GetRowsResult result = new GetRowsResult();
// Nothing requested: hand back the empty result immediately.
// NOTE(review): an empty request is logged at ERROR level; consider DEBUG/WARN.
if (rowIndex.getRowsNumber() == 0) {
LOG.error("need get rowId set is empty, just return");
return result;
int staleness = getStaleness(rowIndex.getMatrixId());
if (staleness >= 0) {
// For BSP/SSP, get rows from storage/cache first
int stalenessClock = taskContext.getMatrixClock(rowIndex.getMatrixId()) - staleness;
findRowsInStorage(taskContext, result, rowIndex, stalenessClock);
// Only go to the parameter servers if the local storage could not satisfy every row.
if (!result.isFetchOver()) {
LOG.debug("need fetch from parameterserver");
// Get from ps.
// Wait until the matrix-level clock (rowIndex -1) is greater than or equal to stalenessClock.
waitForClock(rowIndex.getMatrixId(), -1, stalenessClock);
// NOTE(review): the receiver of this call is missing in this copy; presumably
// PSAgentContext.get().getUserRequestAdapter() — confirm against upstream.
.getRowsFlow(result, rowIndex, rpcBatchSize, stalenessClock);
return result;
} else {
// For ASYNC, just get rows from the parameter servers.
IntOpenHashSet rowIdSet = rowIndex.getRowIds();
// NOTE(review): the declared type of funcResult and the rest of the cast/call expression on the
// next line are missing; presumably a cast of
// PSAgentContext.get().getUserRequestAdapter().get(func) to the GetRows result type — confirm.
GetRows func = new GetRows(new GetRowsParam(rowIndex.getMatrixId(), rowIdSet.toIntArray())); funcResult =
(( PSAgentContext.get()
if (funcResult.getResponseType() == ResponseType.FAILED) {
throw new IOException("get rows from ps failed.");
} else {
Map rows = funcResult.getRows();
// NOTE(review): loop body missing — presumably copies each returned row into result.
for (Entry rowEntry : rows.entrySet()) {
return result;
* Update clock for a matrix.
* @param taskContext task context
* @param matrixId matrix id
* @param flushFirst flush matrix oplog first or not
* @return Future clock result future
public Future clock(TaskContext taskContext, int matrixId, boolean flushFirst) {
return PSAgentContext.get().getOpLogCache().clock(taskContext, matrixId, flushFirst);
* Get staleness value.
* @return int staleness value
public int getStaleness() {
return globalStaleness;
* Get staleness value for the matrix.
* @param matrixId matrix id
* @return int staleness value
public int getStaleness(int matrixId) {
MatrixMeta meta = PSAgentContext.get().getMatrixMetaManager().getMatrixMeta(matrixId);
if (meta == null || meta.getAttribute(MatrixConf.MATRIX_STALENESS) == null) {
return globalStaleness;
} else {
try {
return Integer.valueOf(meta.getAttribute(MatrixConf.MATRIX_STALENESS));
} catch (Exception x) {
LOG.warn("parse matrix staleness value failed for matrix " + matrixId, x);
return globalStaleness;
/**
 * Look up the requested rows of {@code rowIndexes} in the local matrix storage. A locally stored
 * row is accepted only if its clock is at least the task's PS matrix clock and at least
 * {@code stalenessClock}.
 *
 * NOTE(review): several statements and closing braces are missing from this copy (marked below).
 */
private void findRowsInStorage(TaskContext taskContext, GetRowsResult result, RowIndex rowIndexes,
int stalenessClock) throws InterruptedException {
// NOTE(review): right-hand side missing; presumably the local storage of
// rowIndexes.getMatrixId() obtained from the matrix storage manager — confirm.
MatrixStorage storage =
for (int rowIndex : rowIndexes.getRowIds()) {
Vector processRow = storage.getRow(rowIndex);
// Usable only if the row clock is >= the task's PS matrix clock and >= the staleness bound.
if (processRow != null && (taskContext.getPSMatrixClock(rowIndexes.getMatrixId())
<= processRow.getClock()) && (processRow.getClock() >= stalenessClock)) {
// NOTE(review): body missing — presumably records processRow in result.
// Once every requested row has been collected locally...
// NOTE(review): body missing — presumably marks the result as complete (cf. isFetchOver()).
if (result.getRowsNumber() == rowIndexes.getRowsNumber()) {
/**
 * Return {@code row} itself, or a copy made with {@code row.copy()} when
 * {@code isNeedClone(matrixId)} is true. A {@code null} row is returned as {@code null}.
 *
 * NOTE(review): {@code rowIndex} is not used by the statements visible in this copy; it may feed
 * the missing lock lookup below.
 */
private Vector cloneRow(int matrixId, int rowIndex, Vector row) {
if (row == null) {
return null;
if (isNeedClone(matrixId)) {
// NOTE(review): right-hand side missing, and so are the lock acquisition before `try` and the
// release inside `finally`; presumably a read lock guarding the row while it is copied — confirm.
ReentrantReadWriteLock globalStorage =
try {
return row.copy();
} finally {
} else {
return row;
private boolean isNeedClone(int matrixId) {
MatrixMeta matrixMeta = PSAgentContext.get().getMatrixMetaManager().getMatrixMeta(matrixId);
int localTaskNum = PSAgentContext.get().getLocalTaskNum();
return !matrixMeta.isHogwild() && localTaskNum > 1;
* Get row use index
* @param taskContext task context
* @param func index get psf
* @return the need row
* @throws Exception
public Vector getRow(TaskContext taskContext, IndexGet func) throws Exception {
int matrixId = func.getParam().getMatrixId();
int rowIndex = ((IndexGetParam) func.getParam()).getRowId();
int staleness = getStaleness(matrixId);
if (staleness >= 0) {
waitForClock(matrixId, rowIndex, taskContext.getMatrixClock(matrixId) - staleness);
return ((GetRowResult) PSAgentContext.get().getUserRequestAdapter().get(func)).getRow();
* Get row use index
* @param taskContext task context
* @param func index get psf
* @return the need row
* @throws Exception
public Vector getRow(TaskContext taskContext, LongIndexGet func) throws Exception {
int matrixId = func.getParam().getMatrixId();
int rowIndex = ((LongIndexGetParam) func.getParam()).getRowId();
int staleness = getStaleness(matrixId);
if (staleness >= 0) {
waitForClock(matrixId, rowIndex, taskContext.getMatrixClock(matrixId) - staleness);
return ((GetRowResult) PSAgentContext.get().getUserRequestAdapter().get(func)).getRow();
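/**
 * Wait until the clock of the given row of the matrix (or of the whole matrix) reaches a value.
 *
 * <p>TODO: check for successful tasks instead.
 *
 * @param matrixId matrix id
 * @param rowIndex row index, or -1 to wait on the matrix-level clock
 * @param clock clock value to wait for
 */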
public void waitForClock(int matrixId, int rowIndex, int clock) {"wait for clock " + clock);
// NOTE(review): the call that opens the log statement above is missing from this copy, as are
// several other statements and closing braces below; reconcile with upstream before compiling.
ClockCache clockCache = PSAgentContext.get().getClockCache();
// NOTE(review): the rest of this expression (the configuration key lookup) is missing.
int clockUpdateIntervalMs = PSAgentContext.get().getConf()
// Elapsed time after which the master is consulted: twice the clock-update interval.
int checkMasterIntervalMs = clockUpdateIntervalMs * 2;
long startTs = System.currentTimeMillis();
// Poll the local clock cache until the requested clock is reached.
while (true) {
int cachedClock;
// rowIndex == -1 selects the matrix-level clock instead of a single row's clock.
if (rowIndex == -1) {
cachedClock = clockCache.getClock(matrixId);
} else {
cachedClock = clockCache.getClock(matrixId, rowIndex);
// NOTE(review): the log call and the loop exit for this branch are garbled/missing here.
if (cachedClock >= clock) {"wait for clock " + clock + " over");
// NOTE(review): the body of this try (presumably a short sleep between polls) is missing.
try {
} catch (InterruptedException e) {
// NOTE(review): as visible here, the interrupt is only logged and the thread's interrupt
// status is not restored; consider Thread.currentThread().interrupt().
LOG.warn("waitForClock is interrupted " + e.getMessage());
// Periodically ask the master whether some worker group already finished successfully; per
// the log message below, no further waiting is needed in that case.
if (System.currentTimeMillis() - startTs > checkMasterIntervalMs) {
try {
if (PSAgentContext.get().getMasterClient().getSuccessWorkerGroupNum() >= 1) {"Some Worker run success, do not need wait");
} catch (ServiceException e) {
// NOTE(review): "falied" in the log message below is a typo for "failed".
LOG.error("getSuccessWorkerGroupNum from Master falied ", e);
// Restart the master-check interval.
startTs = System.currentTimeMillis();
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy