All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.psagent.matrix.PSAgentMatrixMetaManager Maven / Gradle / Ivy

There is a newer version: 3.2.0
Show newest version
/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */


package com.tencent.angel.psagent.matrix;

import com.google.protobuf.ServiceException;
import com.tencent.angel.PartitionKey;
import com.tencent.angel.ml.matrix.MatrixMeta;
import com.tencent.angel.ml.matrix.MatrixMetaManager;
import com.tencent.angel.ml.matrix.PartitionLocation;
import com.tencent.angel.ml.matrix.PartitionMeta;
import com.tencent.angel.ps.ParameterServerId;
import com.tencent.angel.ps.server.data.PSLocation;
import com.tencent.angel.psagent.PSAgentContext;
import com.tencent.angel.psagent.clock.ClockCache;
import com.tencent.angel.psagent.matrix.transport.adapter.RowIndex;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * PSAgent matrix meta manager
 */
public class PSAgentMatrixMetaManager {
  private static final Log LOG = LogFactory.getLog(PSAgentMatrixMetaManager.class);
  /**
   * Matrix meta manager
   */
  private final MatrixMetaManager matrixMetaManager;

  /**
   * Matrix id -> the partitions of the matrix map
   */
  private final Map> matrixIdToPartsMap;

  /**
   * Matrix id -> (row index -> contain this row partitions map) map
   */
  private final Map>> rowIndexToPartsMap;

  /**
   * Local clock cache
   */
  private final ClockCache partClockCache;

  /**
   * Create PSAgentMatrixMetaManager
   *
   * @param partClockCache clock cache
   */
  public PSAgentMatrixMetaManager(ClockCache partClockCache) {
    this.partClockCache = partClockCache;
    this.matrixMetaManager = new MatrixMetaManager();
    this.matrixIdToPartsMap = new ConcurrentHashMap<>();
    this.rowIndexToPartsMap = new ConcurrentHashMap<>();
  }

  /**
   * Add matrices meta
   *
   * @param matrixMetas matrices meta
   */
  public void addMatrices(List matrixMetas) {
    int size = matrixMetas.size();
    for (int i = 0; i < size; i++) {
      addMatrix(matrixMetas.get(i));
    }
  }

  /**
   * Add matrix meta
   *
   * @param matrixMeta matrix meta
   */
  public void addMatrix(MatrixMeta matrixMeta) {
    matrixMetaManager.addMatrix(matrixMeta);
    List partitions = getPartitions(matrixMeta.getId());
    partClockCache.addMatrix(matrixMeta.getId(), partitions);
  }

  /**
   * Remove matrix meta
   *
   * @param matrixId matrix id
   */
  public void removeMatrix(int matrixId) {
    matrixMetaManager.removeMatrix(matrixId);
    partClockCache.removeMatrix(matrixId);
  }

  /**
   * Get the server that hold the partition.
   *
   * @param partitionKey matrix partition key
   * @param sync         true means get from master, false means get from local cache
   * @return ParameterServerId server id
   */
  public List getPss(PartitionKey partitionKey, boolean sync)
    throws ServiceException {
    if (!sync) {
      return getPss(partitionKey);
    } else {
      List psIds = PSAgentContext.get().getMasterClient()
        .getStoredPss(partitionKey.getMatrixId(), partitionKey.getPartitionId());
      matrixMetaManager.setPss(partitionKey.getMatrixId(), partitionKey.getPartitionId(), psIds);
      return psIds;
    }
  }

  /**
   * Get partition location: includes stored pss and the location of the pss
   *
   * @param partitionKey partition information
   * @param sync         true means get from master; false means get from cache
   * @return partition location
   * @throws ServiceException
   */
  public PartitionLocation getPartLocation(PartitionKey partitionKey, boolean sync)
    throws ServiceException {
    if (!sync) {
      return getPartLocation(partitionKey);
    } else {
      return PSAgentContext.get().getMasterClient()
        .getPartLocation(partitionKey.getMatrixId(), partitionKey.getPartitionId());
    }
  }

  /**
   * Get partition location: includes stored pss and the location of the pss
   *
   * @param partitionKey partition information
   * @return partition location
   * @throws ServiceException
   */
  public PartitionLocation getPartLocation(PartitionKey partitionKey) {
    List psIds = getPss(partitionKey);
    if (psIds == null) {
      return new PartitionLocation(new ArrayList<>());
    }

    int size = psIds.size();
    List psLocs = new ArrayList<>(size);
    for (int i = 0; i < size; i++) {
      psLocs.add(new PSLocation(psIds.get(i),
        PSAgentContext.get().getLocationManager().getPsLocation(psIds.get(i))));
    }
    return new PartitionLocation(psLocs);
  }

  /**
   * Get the server that hold the partition.
   *
   * @param partitionKey matrix partition key
   * @return ParameterServerId server id
   */
  public List getPss(PartitionKey partitionKey) {
    return matrixMetaManager.getPss(partitionKey.getMatrixId(), partitionKey.getPartitionId());
  }

  /**
   * Get list of partitionKeys for matrixId and rowId.
   *
   * @param matrixId, matrixId
   * @param rowIndex, rowId
   * @return list of partitionKeys
   */
  public List getPartitions(int matrixId, int rowIndex) {
    Map> rowPartKeysCache = rowIndexToPartsMap.get(matrixId);
    if (rowPartKeysCache == null) {
      rowPartKeysCache = new ConcurrentHashMap<>();
      rowIndexToPartsMap.put(matrixId, rowPartKeysCache);
    }
    List rowParts = rowPartKeysCache.get(rowIndex);
    if (rowParts == null) {
      rowParts = getPartitionsFromMeta(matrixId, rowIndex);

      rowPartKeysCache.put(rowIndex, rowParts);
    }

    return rowParts;
  }

  private List getPartitionsFromMeta(int matrixId, int rowIndex) {
    List partitionKeys = new ArrayList<>();
    Iterator iter =
      matrixMetaManager.getMatrixMeta(matrixId).getPartitionMetas().values().iterator();
    while (iter.hasNext()) {
      PartitionKey partitionKey = iter.next().getPartitionKey();
      if (partitionKey.getMatrixId() == matrixId && partitionKey.getStartRow() <= rowIndex
        && partitionKey.getEndRow() > rowIndex)
        partitionKeys.add(partitionKey);
    }
    partitionKeys.sort(new Comparator() {
      @Override public int compare(PartitionKey p1, PartitionKey p2) {
        if (p1.getStartCol() < p2.getStartCol()) {
          return -1;
        } else if (p1.getStartCol() > p2.getStartCol()) {
          return 1;
        } else {
          return 0;
        }
      }
    });
    return partitionKeys;
  }

  /**
   * Get list of partitionkeys belong to matrixId.
   *
   * @param matrixId, identity number of matrix requested
   * @return list of partitionKeys
   */
  public List getPartitions(int matrixId) {
    List partitions = matrixIdToPartsMap.get(matrixId);
    if (partitions == null) {
      partitions = getPartitionsFromMeta(matrixId);
      matrixIdToPartsMap.put(matrixId, partitions);
    }

    return partitions;
  }

  private List getPartitionsFromMeta(int matrixId) {
    List partitionKeys = new ArrayList<>();
    Iterator iter =
      matrixMetaManager.getMatrixMeta(matrixId).getPartitionMetas().values().iterator();
    while (iter.hasNext()) {
      partitionKeys.add(iter.next().getPartitionKey());
    }
    partitionKeys.sort((PartitionKey p1, PartitionKey p2) -> {
      if (p1.getStartCol() < p2.getStartCol()) {
        return -1;
      } else if (p1.getStartCol() > p2.getStartCol()) {
        return 1;
      } else {
        return 0;
      }
    });
    return partitionKeys;
  }

  /**
   * Get the partitions the rows in.
   *
   * @param matrixId   matrix id
   * @param rowIndexes row indexes
   * @return Map> partitions to the rows contained in the partition map
   */
  public Map> getPartitionToRowsMap(int matrixId,
    List rowIndexes) {
    Map> partToRowsMap = new HashMap>();

    int rowNum = rowIndexes.size();
    int partNum;
    for (int i = 0; i < rowNum; i++) {
      List partKeys = getPartitions(matrixId, rowIndexes.get(i));
      partNum = partKeys.size();
      for (int j = 0; j < partNum; j++) {
        List rows = partToRowsMap.get(partKeys.get(j));
        if (rows == null) {
          rows = new ArrayList<>();
          partToRowsMap.put(partKeys.get(j), rows);
        }
        rows.add(rowIndexes.get(i));
      }
    }
    return partToRowsMap;
  }

  /**
   * Get the partitions the rows in.
   *
   * @param matrixId matrix id
   * @param rowIds   row indexes
   * @return Map> partitions to the rows contained in the partition map
   */
  public Map> getPartitionToRowsMap(int matrixId, int[] rowIds) {
    Map> partToRowsMap = new HashMap>();

    int rowNum = rowIds.length;
    int partNum;
    for (int i = 0; i < rowNum; i++) {
      List partKeys = getPartitions(matrixId, rowIds[i]);
      partNum = partKeys.size();
      for (int j = 0; j < partNum; j++) {
        List rows = partToRowsMap.get(partKeys.get(j));
        if (rows == null) {
          rows = new ArrayList<>();
          partToRowsMap.put(partKeys.get(j), rows);
        }
        rows.add(rowIds[i]);
      }
    }
    return partToRowsMap;
  }

  /**
   * Get the partitions the rows in.
   *
   * @param rowIndex    matrix id and row indexes
   * @param batchNumber the split batch size.
   * @return Map> partitions to the rows contained in the partition map
   */
  public Map> getPartitionToRowIndexMap(RowIndex rowIndex,
    int batchNumber) {
    Map> partToRowIndexMap =
      new TreeMap>();
    if (rowIndex.getRowIds() == null) {
      return partToRowIndexMap;
    }

    IntOpenHashSet rowIdSet = rowIndex.getRowIds();
    IntOpenHashSet filtedRowIdSet = rowIndex.getFiltedIdSet();
    int[] rowIds = new int[rowIdSet.size() - filtedRowIdSet.size()];
    int count = 0;
    for (int rowId : rowIdSet) {
      if (!filtedRowIdSet.contains(rowId)) {
        rowIds[count++] = rowId;
      }
    }

    Arrays.sort(rowIds);

    int partNum = 0;
    for (int i = 0; i < rowIds.length; i++) {
      List partKeys = getPartitions(rowIndex.getMatrixId(), rowIds[i]);
      partNum = partKeys.size();
      for (int j = 0; j < partNum; j++) {
        PartitionKey partitionKey = partKeys.get(j);
        List indexList = partToRowIndexMap.get(partitionKey);
        if (indexList == null) {
          indexList = new ArrayList();
          partToRowIndexMap.put(partitionKey, indexList);
          indexList.add(new RowIndex(rowIndex.getMatrixId(), rowIndex));
        }

        RowIndex index = indexList.get(indexList.size() - 1);
        if (index.getRowsNumber() >= batchNumber) {
          index = new RowIndex(rowIndex.getMatrixId(), rowIndex);
          indexList.add(index);
        }
        index.addRowId(rowIds[i]);
      }
    }

    return partToRowIndexMap;
  }

  /**
   * Get the number of the partitions that contains the row.
   *
   * @param matrixId matrix id
   * @param rowIndex row index
   * @return the number of the partitions that contains the row
   */
  public int getRowPartitionSize(int matrixId, int rowIndex) {
    return getPartitions(matrixId, rowIndex).size();
  }

  /**
   * Clear the router table.
   */
  public void clear() {
    matrixIdToPartsMap.clear();
    rowIndexToPartsMap.clear();
    matrixMetaManager.clear();
  }

  /**
   * Get matrix id
   *
   * @param matrixName matrix name
   * @return matrix id
   */
  public int getMatrixId(String matrixName) {
    if (matrixMetaManager.exists(matrixName)) {
      return matrixMetaManager.getMatrixId(matrixName);
    } else {
      synchronized (this) {
        if (!matrixMetaManager.exists(matrixName)) {
          try {
            PSAgentContext.get().getPsAgent().refreshMatrixInfo();
          } catch (Throwable e) {
            LOG.error("get matrix information from master failed. ", e);
          }
        }
        if (matrixMetaManager.exists(matrixName)) {
          return matrixMetaManager.getMatrixId(matrixName);
        } else {
          return -1;
        }
      }
    }
  }

  /**
   * Get matrix meta
   *
   * @param matrixId matrix id
   * @return matrix meta
   */
  public MatrixMeta getMatrixMeta(int matrixId) {
    return matrixMetaManager.getMatrixMeta(matrixId);
  }

  /**
   * Get matrix meta
   *
   * @param matrixName matrix name
   * @return matrix meta
   */
  public MatrixMeta getMatrixMeta(String matrixName) {
    return matrixMetaManager.getMatrixMeta(matrixName);
  }

  /**
   * Get the master ps of a partition
   *
   * @param partKey partition key
   * @return the master ps of a partition
   */
  public ParameterServerId getMasterPS(PartitionKey partKey) {
    return matrixMetaManager.getMasterPs(partKey.getMatrixId(), partKey.getPartitionId());
  }

  /**
   * Get all matrices meta
   *
   * @return all matrices meta
   */
  public Map getMatrixMetas() {
    return matrixMetaManager.getMatrixMetas();
  }

  /**
   * Is a matrix exist
   *
   * @param matrixId matrix id
   * @return true means exist
   */
  public boolean exist(int matrixId) {
    if (!matrixMetaManager.exists(matrixId)) {
      synchronized (this) {
        if (!matrixMetaManager.exists(matrixId)) {
          try {
            PSAgentContext.get().getPsAgent().refreshMatrixInfo();
          } catch (Throwable e) {
            LOG.error("get matrix information from master failed. ", e);
          }
        }
      }
    }
    return matrixMetaManager.exists(matrixId);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy