All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.master.task.AMTask Maven / Gradle / Ivy

The newest version!
/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */


package com.tencent.angel.master.task;

import com.tencent.angel.protobuf.generated.MLProtos.MatrixClock;
import com.tencent.angel.protobuf.generated.MLProtos.Pair;
import com.tencent.angel.protobuf.generated.WorkerMasterServiceProtos.TaskStateProto;
import com.tencent.angel.worker.task.TaskId;
import it.unimi.dsi.fastutil.ints.Int2IntMap.Entry;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * Manager for a single task.
 */
public class AMTask {
  private static final Log LOG = LogFactory.getLog(AMTask.class);
  /**
   * task id
   */
  private final TaskId taskId;

  /**
   * task state
   */
  private AMTaskState state;

  /**
   * task iteration
   */
  private int iteration;

  /**
   * matrix id to clock value map
   */
  private final Int2IntOpenHashMap matrixIdToClockMap;

  /**
   * task run progress in current iteration
   */
  private float progress;

  /**
   * task metrics
   */
  private final Map metrics;

  /**
   * task startup time
   */
  private long startTime;

  /**
   * task finish time
   */
  private long finishTime;

  private final Lock readLock;
  private final Lock writeLock;

  public AMTask(TaskId id, AMTask amTask) {
    state = AMTaskState.NEW;
    taskId = id;
    metrics = new HashMap();
    startTime = -1;
    finishTime = -1;

    matrixIdToClockMap = new Int2IntOpenHashMap();
    // if amTask is not null, we should clone task state from it
    if (amTask == null) {
      iteration = 0;
      progress = 0.0f;
    } else {
      iteration = amTask.getIteration();
      progress = amTask.getProgress();
      matrixIdToClockMap.putAll(amTask.matrixIdToClockMap);
    }

    ReadWriteLock readWriteLock = new ReentrantReadWriteLock();
    readLock = readWriteLock.readLock();
    writeLock = readWriteLock.writeLock();
  }

  /**
   * get task id
   *
   * @return TaskId task id
   */
  public TaskId getTaskId() {
    return taskId;
  }

  /**
   * get task state
   *
   * @return AMTaskState task state
   */
  public AMTaskState getState() {
    try {
      readLock.lock();
      return state;
    } finally {
      readLock.unlock();
    }
  }

  private void setState(AMTaskState newState) {
    try {
      writeLock.lock();
      switch (state) {
        case NEW: {
          state = newState;
          break;
        }
        case INITED: {
          if (newState != AMTaskState.NEW) {
            state = newState;
          }
          break;
        }
        case RUNNING:
        case WAITING: {
          if (newState != AMTaskState.NEW && newState != AMTaskState.INITED) {
            state = newState;
            break;
          }
        }
        case COMMITING: {
          if (newState == AMTaskState.SUCCESS || newState == AMTaskState.FAILED
            || newState == AMTaskState.KILLED) {
            state = newState;
          }
          break;
        }

        default: {
          break;
        }
      }
    } finally {
      writeLock.unlock();
    }
  }

  /**
   * get the task progress in current iteration
   *
   * @return float the task progress in current iteration
   */
  public float getProgress() {
    try {
      readLock.lock();
      return progress;
    } finally {
      readLock.unlock();
    }
  }

  /**
   * update task information: iteration, clocks, counters, progress, startup time and finish time etc.
   *
   * @param taskStateProto task information from worker
   */
  public void updateTaskState(TaskStateProto taskStateProto) {
    try {
      writeLock.lock();
      LOG.debug("taskStateProto=" + taskStateProto);
      setState(transformState(taskStateProto.getState()));
      this.progress = taskStateProto.getProgress();
      if (taskStateProto.hasIteration()) {
        this.iteration = taskStateProto.getIteration();
      }
      updateMatrixClocks(taskStateProto.getMatrixClocksList());
      updateMetrics(taskStateProto.getCountersList());

      if (taskStateProto.hasFinishTime()) {
        this.finishTime = taskStateProto.getFinishTime();
      }
      if (taskStateProto.hasStartTime()) {
        this.startTime = taskStateProto.getStartTime();
      }
    } finally {
      writeLock.unlock();
    }
  }

  public void updateCounters(List counters) {
    try {
      writeLock.lock();
      updateMetrics(counters);
    } finally {
      writeLock.unlock();
    }
  }

  private void updateMetrics(List counters) {
    int size = counters.size();
    for (int i = 0; i < size; i++) {
      metrics.put(counters.get(i).getKey(), counters.get(i).getValue());
    }
  }

  private void updateMatrixClocks(List matrixClocks) {
    int size = matrixClocks.size();
    for (int i = 0; i < size; i++) {
      matrixIdToClockMap.put(matrixClocks.get(i).getMatrixId(), matrixClocks.get(i).getClock());
    }
  }

  private AMTaskState transformState(String state) {
    switch (state) {
      case "NEW":
        return AMTaskState.NEW;
      case "INITED":
        return AMTaskState.INITED;
      case "RUNNING":
        return AMTaskState.RUNNING;
      case "WAITING":
        return AMTaskState.WAITING;
      case "COMMITING":
        return AMTaskState.COMMITING;
      case "SUCCESS":
        return AMTaskState.SUCCESS;
      case "FAILED":
        return AMTaskState.FAILED;
      case "KILLED":
        return AMTaskState.KILLED;
      default:
        return AMTaskState.NEW;
    }
  }

  /**
   * get task startup time
   *
   * @return long task startup time
   */
  public long getStartTime() {
    try {
      readLock.lock();
      return startTime;
    } finally {
      readLock.unlock();
    }
  }

  /**
   * get task finish time
   *
   * @return long task finish time
   */
  public long getFinishTime() {
    try {
      readLock.lock();
      return finishTime;
    } finally {
      readLock.unlock();
    }
  }

  /**
   * get the current iteration value
   *
   * @return int the current iteration value
   */
  public int getIteration() {
    try {
      readLock.lock();
      return iteration;
    } finally {
      readLock.unlock();
    }
  }

  @Override public String toString() {
    try {
      readLock.lock();
      return "AMTask [taskId=" + taskId + ", state=" + state + ", iteration=" + iteration
        + ", matrixIdToClockMap=" + matrixIdToClockMap.toString() + ", progress=" + progress + "]";
    } finally {
      readLock.unlock();
    }
  }

  /**
   * get task metrics
   *
   * @return Map task metrics
   */
  public Map getMetrics() {
    try {
      readLock.lock();
      Map cloneMetrics = new HashMap();
      cloneMetrics.putAll(metrics);
      return cloneMetrics;
    } finally {
      readLock.unlock();
    }
  }

  /**
   * update the clock value of the specified matrix
   *
   * @param matrixId matrix id
   * @param clock    clock value
   */
  public void clock(int matrixId, int clock) {
    try {
      writeLock.lock();
      matrixIdToClockMap.put(matrixId, clock);
    } finally {
      writeLock.unlock();
    }
  }

  /**
   * update the task current iteration value
   *
   * @param iteration the task current iteration value
   */
  public void iteration(int iteration) {
    try {
      writeLock.lock();
      this.iteration = iteration;
    } finally {
      writeLock.unlock();
    }
  }

  /**
   * get all matrix clocks
   *
   * @return Int2IntOpenHashMap  all matrix clocks
   */
  public Int2IntOpenHashMap getMatrixClocks() {
    try {
      readLock.lock();
      return matrixIdToClockMap.clone();
    } finally {
      readLock.unlock();
    }
  }

  /**
   * get the clock of the specified matrix
   *
   * @param matrixId the matrix id
   * @return int the clock of the matrix
   */
  public int getMatrixClock(int matrixId) {
    try {
      readLock.lock();
      return matrixIdToClockMap.get(matrixId);
    } finally {
      readLock.unlock();
    }
  }

  /**
   * write the task state to a output stream
   *
   * @param output the output stream
   */
  public void serialize(DataOutputStream output) throws IOException {
    try {
      readLock.lock();
      output.writeUTF(state.toString());
      output.writeInt(iteration);
      output.writeInt(matrixIdToClockMap.size());
      for (Entry clockEntry : matrixIdToClockMap.int2IntEntrySet()) {
        output.writeInt(clockEntry.getIntKey());
        output.writeInt(clockEntry.getIntValue());
      }
    } finally {
      readLock.unlock();
    }
  }

  /**
   * read the task state from a input stream
   *
   * @param input the input stream
   */
  public void deserialize(DataInputStream input) throws IOException {
    try {
      writeLock.lock();
      state = transformState(input.readUTF());
      iteration = input.readInt();
      int clockNum = input.readInt();
      for (int i = 0; i < clockNum; i++) {
        matrixIdToClockMap.put(input.readInt(), input.readInt());
      }
    } finally {
      writeLock.unlock();
    }
  }

  /**
   * Reset some profiling counters
   */
  public void resetCounters() {
    metrics.put(TaskCounter.TOTAL_CALCULATE_SAMPLES, "0");
    metrics.put(TaskCounter.TOTAL_CALCULATE_TIME_MS, "0");
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy