All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.ps.client.MasterClient Maven / Gradle / Ivy

The newest version!
/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */


package com.tencent.angel.ps.client;

import com.google.protobuf.ServiceException;
import com.tencent.angel.common.location.Location;
import com.tencent.angel.ipc.TConnection;
import com.tencent.angel.ipc.TConnectionManager;
import com.tencent.angel.master.MasterProtocol;
import com.tencent.angel.master.matrix.committer.SaveResult;
import com.tencent.angel.ml.matrix.MatrixMeta;
import com.tencent.angel.ml.matrix.PartitionLocation;
import com.tencent.angel.model.PSMatricesLoadResult;
import com.tencent.angel.model.PSMatricesSaveResult;
import com.tencent.angel.protobuf.ProtobufUtil;
import com.tencent.angel.protobuf.generated.MLProtos;
import com.tencent.angel.protobuf.generated.MLProtos.MatrixClock;
import com.tencent.angel.protobuf.generated.PSMasterServiceProtos.*;
import com.tencent.angel.ps.PSContext;
import com.tencent.angel.ps.ParameterServer;
import com.tencent.angel.ps.ParameterServerId;
import com.tencent.angel.ps.server.data.PSLocation;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.List;

/**
 * RPC client a parameter server (PS) uses to talk to the master.
 *
 * <p>{@link #init()} must be called before any RPC method, because it is the only place where
 * the connection and the master proxy are obtained. Each RPC method forwards to the master proxy
 * and converts between protobuf messages and their in-memory counterparts.
 */
public class MasterClient {
  // Bound to this class so that log lines are attributed to MasterClient.
  private static final Log LOG = LogFactory.getLog(MasterClient.class);

  /**
   * Master rpc protocol; assigned in {@link #init()}
   */
  private volatile MasterProtocol masterProxy;

  /**
   * Connection the master proxy is obtained from; assigned in {@link #init()} and closed in
   * {@link #stop()}
   */
  private volatile TConnection connection;

  /**
   * PS context
   */
  private final PSContext context;

  /**
   * Create MasterClient
   *
   * @param context PS context
   */
  public MasterClient(PSContext context) {
    this.context = context;
  }

  /**
   * Init: get the connection for the PS configuration and create the master proxy for the master
   * location known to the location manager. If creating the proxy fails, the failure is logged
   * and reported to the PS through {@code context.getPs().failed(...)}; the proxy stays unset.
   */
  public void init() {
    connection = TConnectionManager.getConnection(context.getConf());
    Location masterLoc = context.getLocationManager().getMasterLocation();
    try {
      masterProxy = connection.getMasterService(masterLoc.getIp(), masterLoc.getPort());
    } catch (Throwable e) {
      LOG.fatal("Connect to master failed! PS is to exit now!", e);
      context.getPs().failed("Connect to master failed!");
    }
  }

  /**
   * Start; nothing to do for this client
   */
  public void start() {

  }

  /**
   * Stop: close the connection if one was obtained. A close failure is logged, not rethrown.
   */
  public void stop() {
    if (connection != null) {
      try {
        connection.close();
      } catch (IOException e) {
        LOG.error("close connection falied ", e);
      }
    }
  }

  /**
   * Get task clocks for all matrices from Master
   *
   * @return map from task index to a map from matrix id to clock
   * @throws ServiceException
   */
  public Int2ObjectOpenHashMap<Int2IntOpenHashMap> getTaskMatrixClocks() throws ServiceException {
    GetTaskMatrixClockResponse response =
      masterProxy.getTaskMatrixClocks(null, GetTaskMatrixClockRequest.newBuilder().build());
    Int2ObjectOpenHashMap<Int2IntOpenHashMap> taskIdToMatrixClocksMap =
      new Int2ObjectOpenHashMap<>(response.getTaskMatrixClocksCount());

    for (TaskMatrixClock taskMatrixClock : response.getTaskMatrixClocksList()) {
      Int2IntOpenHashMap matrixIdToClockMap =
        new Int2IntOpenHashMap(taskMatrixClock.getMatrixClocksCount());
      for (MatrixClock matrixClock : taskMatrixClock.getMatrixClocksList()) {
        matrixIdToClockMap.put(matrixClock.getMatrixId(), matrixClock.getClock());
      }
      taskIdToMatrixClocksMap.put(taskMatrixClock.getTaskId().getTaskIndex(), matrixIdToClockMap);
    }

    return taskIdToMatrixClocksMap;
  }

  /**
   * Report PS run over successfully to Master
   *
   * @throws ServiceException
   */
  public void done() throws ServiceException {
    masterProxy.psDone(null, PSDoneRequest.newBuilder()
      .setPsAttemptId(ProtobufUtil.convertToIdProto(context.getPSAttemptId())).build());
  }

  /**
   * Report PS run failed to Master
   *
   * @param errorLog failed message
   * @throws ServiceException
   */
  public void failed(String errorLog) throws ServiceException {
    masterProxy.psError(null, PSErrorRequest.newBuilder()
      .setPsAttemptId(ProtobufUtil.convertToIdProto(context.getPSAttemptId())).setMsg(errorLog)
      .build());
  }

  /**
   * Register to Master. The registered location is the local host address together with the port
   * of the PS service.
   *
   * @throws IOException      if the local host address cannot be resolved
   * @throws ServiceException
   */
  public void register() throws IOException, ServiceException {
    PSRegisterRequest.Builder regBuilder = PSRegisterRequest.newBuilder();
    regBuilder.setPsAttemptId(ProtobufUtil.convertToIdProto(context.getPSAttemptId()));
    try {
      Location location =
        new Location(InetAddress.getLocalHost().getHostAddress(), context.getPsService().getPort());
      regBuilder.setLocation(ProtobufUtil.convertLocation(location));
    } catch (UnknownHostException eop) {
      LOG.error("UnknownHostException: " + eop);
      throw new IOException(eop);
    }

    masterProxy.psRegister(null, regBuilder.build());
  }

  /**
   * Heartbeat to Master
   *
   * @param request heartbeat message
   * @return heartbeat response
   * @throws ServiceException
   */
  public PSReportResponse psReport(PSReportRequest request) throws ServiceException {
    return masterProxy.psReport(null, request);
  }

  /**
   * Get a ps location from master
   *
   * @param serverId server id
   * @return PS location
   * @throws ServiceException
   */
  public Location getPsLocation(ParameterServerId serverId) throws ServiceException {
    MLProtos.GetPSLocationReponse response = masterProxy.getPSLocation(null,
      MLProtos.GetPSLocationRequest.newBuilder().setPsId(ProtobufUtil.convertToIdProto(serverId))
        .build());
    return ProtobufUtil.convertToLocation(response.getPsLocation());
  }

  /**
   * Get the stored pss and the locations for a matrix partition
   *
   * @param matrixId matrix id
   * @param partId   partition id
   * @return the stored pss and the locations
   * @throws ServiceException
   */
  public PartitionLocation getPartLocation(int matrixId, int partId) throws ServiceException {
    MLProtos.GetPartLocationResponse response = masterProxy.getPartLocation(null,
      MLProtos.GetPartLocationRequest.newBuilder().setMatrixId(matrixId).setPartId(partId).build());
    List<MLProtos.PSLocationProto> psLocsProto = response.getLocationsList();

    List<PSLocation> psLocs = new ArrayList<>(psLocsProto.size());
    for (MLProtos.PSLocationProto psLocProto : psLocsProto) {
      psLocs.add(new PSLocation(ProtobufUtil.convertToId(psLocProto.getPsId()),
        ProtobufUtil.convertToLocation(psLocProto)));
    }
    return new PartitionLocation(psLocs);
  }

  /**
   * Get the stored pss for a matrix partition
   *
   * @param matrixId    matrix id
   * @param partitionId partition id
   * @return the stored pss
   * @throws ServiceException
   */
  public List<ParameterServerId> getStoredPss(int matrixId, int partitionId)
    throws ServiceException {
    List<MLProtos.PSIdProto> psIdProtos = masterProxy.getStoredPss(null,
      MLProtos.GetStoredPssRequest.newBuilder().setMatrixId(matrixId).setPartId(partitionId)
        .build()).getPsIdsList();

    List<ParameterServerId> psIds = new ArrayList<>(psIdProtos.size());
    for (MLProtos.PSIdProto psIdProto : psIdProtos) {
      psIds.add(ProtobufUtil.convertToId(psIdProto));
    }
    return psIds;
  }

  /**
   * Get current iteration
   *
   * @return current iteration
   * @throws ServiceException
   */
  public int getIteration() throws ServiceException {
    return masterProxy.getIteration(null, GetIterationRequest.newBuilder().build()).getIteration();
  }

  /**
   * Get matrices meta for this ps
   *
   * @return the matrices meta the master returns for this ps, in response order
   * @throws ServiceException
   * @throws ClassNotFoundException presumably raised by the protobuf to MatrixMeta conversion;
   *                                confirm against ProtobufUtil
   */
  public List<MatrixMeta> getMatricesMeta() throws ServiceException, ClassNotFoundException {
    GetPSMatricesResponse response = masterProxy.getPSMatricesMeta(null,
      GetPSMatricesMetaRequest.newBuilder()
        .setPsId(ProtobufUtil.convertToIdProto(context.getPSAttemptId().getPsId())).build());
    List<MLProtos.MatrixMetaProto> matricesMetaProto = response.getMatricesMetaList();

    List<MatrixMeta> matricesMeta = new ArrayList<>(matricesMetaProto.size());
    for (MLProtos.MatrixMetaProto matrixMetaProto : matricesMetaProto) {
      matricesMeta.add(ProtobufUtil.convertToMatrixMeta(matrixMetaProto));
    }

    return matricesMeta;
  }

  /**
   * Notify save result
   *
   * @param result save result
   * @throws ServiceException
   */
  public void saveFinish(PSMatricesSaveResult result) throws ServiceException {
    masterProxy.saveFinish(null, SaveFinishRequest.newBuilder()
      .setPsAttemptId(ProtobufUtil.convertToIdProto(context.getPSAttemptId()))
      .setResult(ProtobufUtil.convert(result)).build());
  }

  /**
   * Notify load result
   *
   * @param result load result
   * @throws ServiceException
   */
  public void loadFinish(PSMatricesLoadResult result) throws ServiceException {
    masterProxy.loadFinish(null, LoadFinishRequest.newBuilder()
      .setPsAttemptId(ProtobufUtil.convertToIdProto(context.getPSAttemptId()))
      .setResult(ProtobufUtil.convert(result)).build());
  }

  /**
   * Notify master save start
   *
   * @param requestId    save request id
   * @param subRequestId save sub-request id
   * @throws ServiceException
   */
  public void saveStart(int requestId, int subRequestId) throws ServiceException {
    masterProxy.saveStart(null, SaveStartRequest.newBuilder()
      .setPsAttemptId(ProtobufUtil.convertToIdProto(context.getPSAttemptId()))
      .setRequestId(requestId).setSubRequestId(subRequestId).build());
  }

  /**
   * Notify master load start
   *
   * @param requestId    load request id
   * @param subRequestId load sub-request id
   * @throws ServiceException
   */
  public void loadStart(int requestId, int subRequestId) throws ServiceException {
    masterProxy.loadStart(null, LoadStartRequest.newBuilder()
      .setPsAttemptId(ProtobufUtil.convertToIdProto(context.getPSAttemptId()))
      .setRequestId(requestId).setSubRequestId(subRequestId).build());
  }

  /**
   * Get matrix save result contexts
   *
   * @param matrixId matrix id
   * @return matrix save result contexts; an empty list if the master returns none
   * @throws ServiceException
   */
  public List<SaveResult> getSaveResult(int matrixId) throws ServiceException {
    GetSaveResultsResponse response = masterProxy
        .getSaveResults(null, GetSaveResultsRequest.newBuilder().setMatrixId(matrixId).build());
    return convertSaveResults(response.getSaveResultsList());
  }

  /**
   * Get matrix checkpoint result contexts
   *
   * @param matrixId matrix id
   * @return matrix checkpoint result contexts; an empty list if the master returns none
   * @throws ServiceException
   */
  public List<SaveResult> getCheckpoints(int matrixId) throws ServiceException {
    GetCheckpointsResponse response = masterProxy
        .getCheckpointResults(null, GetCheckpointsRequest.newBuilder().setMatrixId(matrixId).build());
    return convertSaveResults(response.getSaveResultsList());
  }

  /**
   * Convert protobuf save results to {@link SaveResult}s, keeping their order.
   *
   * @param resultProtos protobuf save results, may be null or empty
   * @return a new mutable list with the converted results; empty for null or empty input
   */
  private static List<SaveResult> convertSaveResults(List<SaveResultProto> resultProtos) {
    if (resultProtos == null || resultProtos.isEmpty()) {
      return new ArrayList<>(0);
    }

    List<SaveResult> results = new ArrayList<>(resultProtos.size());
    for (SaveResultProto resultProto : resultProtos) {
      results.add(ProtobufUtil.convert(resultProto));
    }

    return results;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy