All downloads are free. Search and download functionality uses the official Maven repository.

com.tencent.angel.psagent.task.TaskContext Maven / Gradle / Ivy

There is a newer version: 3.2.0
Show newest version
/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */


package com.tencent.angel.psagent.task;

import com.google.protobuf.ServiceException;
import com.tencent.angel.PartitionKey;
import com.tencent.angel.conf.AngelConf;
import com.tencent.angel.master.task.TaskCounter;
import com.tencent.angel.ml.metric.Metric;
import com.tencent.angel.psagent.PSAgentContext;
import com.tencent.angel.psagent.clock.ClockCache;
import com.tencent.angel.psagent.matrix.storage.MatrixStorageManager;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Task running context. Task is executing unit for consistency control.
 *
 * <p>Keeps the per-task state visible in this class: the local clock of every matrix the task
 * has touched, the current epoch, the progress inside the epoch, the system counters and the
 * algorithm metrics. Clocks and counters are stored in concurrent maps of atomic values.
 */
public class TaskContext {
  private static final Log LOG = LogFactory.getLog(TaskContext.class);

  /**
   * Matrix id to the local clock of this task for that matrix
   */
  private final ConcurrentHashMap<Integer, AtomicInteger> matrixIdToClockMap;

  /**
   * Task index, it must be unique for whole application
   */
  private final int index;

  /**
   * Matrix storage for task
   */
  private final MatrixStorageManager matrixStorage;

  /**
   * Current epoch number
   */
  private final AtomicInteger epoch;

  /**
   * When true, {@link #increaseEpoch()} reports the algorithm metrics and the new epoch number
   * to the Master. Read once from the PSAgent configuration in the constructor.
   */
  private final boolean syncClockEnable;

  /**
   * Task current epoch running progress
   */
  private volatile float progress;

  /**
   * Task system metrics: counter name to counter value
   */
  private final Map<String, AtomicLong> metrics;

  /**
   * Task algorithm metrics: metric name to metric
   */
  private final Map<String, Metric> algoMetrics;

  /**
   * Create a TaskContext
   *
   * @param index task index
   */
  public TaskContext(int index) {
    this.index = index;
    this.matrixIdToClockMap = new ConcurrentHashMap<>();
    this.matrixStorage = new MatrixStorageManager();
    this.epoch = new AtomicInteger(0);
    this.metrics = new ConcurrentHashMap<>();
    this.algoMetrics = new ConcurrentHashMap<>();

    syncClockEnable = PSAgentContext.get().getConf()
      .getBoolean(AngelConf.ANGEL_PSAGENT_SYNC_CLOCK_ENABLE,
        AngelConf.DEFAULT_ANGEL_PSAGENT_SYNC_CLOCK_ENABLE);
  }

  /**
   * Get task index
   *
   * @return task index
   */
  public int getIndex() {
    return index;
  }

  /**
   * Get the local clock value of a matrix. If the matrix has no clock yet, a clock with value 0
   * is registered for it first.
   *
   * @param matrixId matrix id
   * @return clock value
   */
  public int getMatrixClock(int matrixId) {
    return clockOf(matrixId).get();
  }

  /**
   * Get the minimum clock over all partitions of a matrix, as recorded in the PSAgent clock
   * cache.
   *
   * @param matrixId matrix id
   * @return the smallest partition clock, or {@link Integer#MAX_VALUE} if the matrix has no
   * partitions
   */
  public int getPSMatrixClock(int matrixId) {
    ClockCache clockCache = PSAgentContext.get().getClockCache();
    List<PartitionKey> pkeys =
      PSAgentContext.get().getMatrixMetaManager().getPartitions(matrixId);

    int clock = Integer.MAX_VALUE;
    for (PartitionKey pkey : pkeys) {
      int partClock = clockCache.getClock(matrixId, pkey);
      if (partClock < clock) {
        clock = partClock;
      }
    }

    return clock;
  }

  /**
   * Global sync with special matrix: block until the cached clock of every partition of the
   * matrix is not smaller than this task's local clock for the matrix. The clock cache is polled
   * with the configured sync time interval between two checks.
   *
   * @param matrixId the matrix id
   * @throws InterruptedException if the thread is interrupted while sleeping between checks
   */
  public void globalSync(int matrixId) throws InterruptedException {
    ClockCache clockCache = PSAgentContext.get().getClockCache();
    List<PartitionKey> pkeys =
      PSAgentContext.get().getMatrixMetaManager().getPartitions(matrixId);

    int syncTimeIntervalMS = PSAgentContext.get().getConf()
      .getInt(AngelConf.ANGEL_PSAGENT_CACHE_SYNC_TIMEINTERVAL_MS,
        AngelConf.DEFAULT_ANGEL_PSAGENT_CACHE_SYNC_TIMEINTERVAL_MS);

    while (!isSynchronized(clockCache, matrixId, pkeys)) {
      Thread.sleep(syncTimeIntervalMS);
    }
  }

  /**
   * Check whether every given partition has caught up with this task's clock for the matrix.
   */
  private boolean isSynchronized(ClockCache clockCache, int matrixId, List<PartitionKey> pkeys) {
    for (PartitionKey pkey : pkeys) {
      // The local clock is re-read for each partition, exactly as the polling loop always did.
      if (clockCache.getClock(matrixId, pkey) < getMatrixClock(matrixId)) {
        return false;
      }
    }
    return true;
  }

  /**
   * Global sync with all matrix: runs {@link #globalSync(int)} for every matrix that currently
   * has a local clock in this context.
   *
   * @throws InterruptedException if the thread is interrupted while waiting
   */
  public void globalSync() throws InterruptedException {
    for (Integer matId : getMatrixClocks().keySet()) {
      globalSync(matId);
    }
  }

  @Override public String toString() {
    return super.toString() + "TaskContext [index=" + index + ", matrix clocks="
      + printMatrixClocks() + "]";
  }

  /**
   * Render all local matrix clocks as a sequence of "(matrixId=..,clock=..)" items.
   */
  private String printMatrixClocks() {
    StringBuilder sb = new StringBuilder();

    for (Entry<Integer, AtomicInteger> entry : matrixIdToClockMap.entrySet()) {
      sb.append("(matrixId=");
      sb.append(entry.getKey());
      sb.append(",");
      sb.append("clock=");
      sb.append(entry.getValue().get());
      sb.append(")");
    }

    return sb.toString();
  }

  /**
   * Increase the clock value of the matrix by one. If the matrix has no clock yet, a clock with
   * value 0 is registered for it first.
   *
   * @param matrixId matrix id
   */
  public void increaseMatrixClock(int matrixId) {
    clockOf(matrixId).incrementAndGet();
  }

  /**
   * Get the matrix storage manager for this task
   *
   * @return matrix storage manager
   */
  public MatrixStorageManager getMatrixStorage() {
    return matrixStorage;
  }

  /**
   * Get Task current epoch number
   *
   * @return current epoch number
   */
  public int getEpoch() {
    return epoch.get();
  }

  /**
   * Increase epoch number by one. When clock synchronization is enabled, the algorithm metrics
   * and the new epoch number are then sent to the Master.
   *
   * @throws ServiceException declared by the Master client calls
   */
  public void increaseEpoch() throws ServiceException {
    int iterationValue = epoch.incrementAndGet();
    if (syncClockEnable) {
      PSAgentContext.get().getMasterClient().setAlgoMetrics(index, algoMetrics);
      PSAgentContext.get().getMasterClient().taskIteration(index, iterationValue);
    }
  }

  /**
   * Set the epoch number
   *
   * @param iterationValue epoch number
   */
  public void setEpoch(int iterationValue) {
    epoch.set(iterationValue);
  }

  /**
   * Set matrix clock value
   *
   * @param matrixId   matrix id
   * @param clockValue clock value
   */
  public void setMatrixClock(int matrixId, int clockValue) {
    // A missing clock is registered already holding clockValue, so it is never visible as 0.
    AtomicInteger existing =
      matrixIdToClockMap.putIfAbsent(matrixId, new AtomicInteger(clockValue));
    if (existing != null) {
      existing.set(clockValue);
    }
  }

  /**
   * Get matrix clocks
   *
   * @return the live map from matrix id to local clock (not a copy)
   */
  public Map<Integer, AtomicInteger> getMatrixClocks() {
    return matrixIdToClockMap;
  }

  /**
   * Set task running progress in current epoch
   *
   * @param progress task running progress in current epoch
   */
  public void setProgress(float progress) {
    this.progress = progress;
  }

  /**
   * Get task running progress in current epoch
   *
   * @return progress
   */
  public float getProgress() {
    return progress;
  }

  /**
   * Update calculate profiling counters
   *
   * @param sampleNum calculate sample number
   * @param useTimeMs the time use to calculate the samples
   */
  public void updateProfileCounter(int sampleNum, int useTimeMs) {
    updateCounter(TaskCounter.TOTAL_CALCULATE_SAMPLES, sampleNum);
    updateCounter(TaskCounter.TOTAL_CALCULATE_TIME_MS, useTimeMs);
  }

  /**
   * Increment the counter. A counter that does not exist yet starts from 0.
   *
   * @param counterName counter name
   * @param updateValue increment value
   */
  public void updateCounter(String counterName, int updateValue) {
    counterOf(counterName).addAndGet(updateValue);
  }

  /**
   * Update the counter to a new value, creating the counter if it does not exist yet.
   *
   * @param counterName counter name
   * @param updateValue new counter value
   */
  public void setCounter(String counterName, int updateValue) {
    counterOf(counterName).set(updateValue);
  }

  /**
   * Get task system metrics
   *
   * @return the live map from counter name to counter value (not a copy)
   */
  public Map<String, AtomicLong> getMetrics() {
    return metrics;
  }

  /**
   * Add a algorithm metric, replacing any metric previously stored under the same name
   *
   * @param name   metric name
   * @param metric metric dependency values
   */
  public void addAlgoMetric(String name, Metric metric) {
    algoMetrics.put(name, metric);
  }

  /**
   * Get the clock of a matrix, registering a clock with value 0 if none exists yet.
   */
  private AtomicInteger clockOf(int matrixId) {
    AtomicInteger clock = matrixIdToClockMap.get(matrixId);
    if (clock == null) {
      AtomicInteger created = new AtomicInteger(0);
      // putIfAbsent returns the winner of a concurrent registration, or null if ours was stored.
      clock = matrixIdToClockMap.putIfAbsent(matrixId, created);
      if (clock == null) {
        clock = created;
      }
    }
    return clock;
  }

  /**
   * Get the counter with the given name, registering a counter with value 0 if none exists yet.
   */
  private AtomicLong counterOf(String counterName) {
    AtomicLong counter = metrics.get(counterName);
    if (counter == null) {
      AtomicLong created = new AtomicLong(0);
      // putIfAbsent returns the existing counter, or null if ours was stored.
      counter = metrics.putIfAbsent(counterName, created);
      if (counter == null) {
        counter = created;
      }
    }
    return counter;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy