All downloads are free. The search and download features use the official Maven repository.

com.aliyun.odps.graph.local.worker.Worker Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.aliyun.odps.graph.local.worker;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import org.apache.commons.io.FileUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.aliyun.odps.counter.Counter;
import com.aliyun.odps.counter.Counters;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.Combiner;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.JobConf;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.VertexResolver;
import com.aliyun.odps.graph.WorkerComputer;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.graph.local.BaseRecordReader;
import com.aliyun.odps.graph.local.COUNTER;
import com.aliyun.odps.graph.local.EmptyRecordReader;
import com.aliyun.odps.graph.local.GraphTaskAttemptID;
import com.aliyun.odps.graph.local.InputSplit;
import com.aliyun.odps.graph.local.LocalRecordReader;
import com.aliyun.odps.graph.local.LocalRecordWriter;
import com.aliyun.odps.graph.local.LocalVertexMutations;
import com.aliyun.odps.graph.local.RuntimeContext;
import com.aliyun.odps.graph.local.SQLRecord;
import com.aliyun.odps.graph.local.TaskContextImpl;
import com.aliyun.odps.graph.local.master.Master;
import com.aliyun.odps.graph.local.message.MsgManager;
import com.aliyun.odps.graph.local.utils.LocalGraphRunUtils;
import com.aliyun.odps.graph.utils.VerifyUtils;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableComparable;
import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.io.WritableUtils;
import com.aliyun.odps.utils.ReflectionUtils;

@SuppressWarnings("rawtypes")
public class Worker, VERTEX_VALUE extends Writable, EDGE_VALUE extends Writable, MESSAGE extends Writable, VALUE extends Writable> {

  private static Log LOG = LogFactory.getLog(Worker.class);

  private List mAggregators;
  private List mAggregatorValues;

  private Master master;
  private Counters mCounters;
  private RuntimeContext mCtx;
  private InputSplit mInput;
  private JobConf mJob;
  private List mLastAggregatorValues;
  private Map>
      mLastStepMessage =
      new HashMap>();
  private Map mOutputs;
  private Map mVertexMutations;
  private GraphTaskAttemptID mTaskAttemptID;
  private TaskContextImpl mTaskContext;
  private WorkerComputer mWorkerComputer;

  private int mWorkerID;
  private int mWorkerNum;
  private Map mWriters;

  private Map vertices = new HashMap();

  private Writable mWorkerValue;

  private MsgManager mMsgManager;

  private Combiner mCombiner;

  /**
   * Creates a worker for one task attempt of a local graph job.
   *
   * @param job           job configuration
   * @param ctx           runtime context of the local run
   * @param m             master coordinating this worker
   * @param taskAttemptID identifier of the task attempt this worker runs as
   * @param workerID      index of this worker
   * @param workerNum     total number of workers
   * @param input         input split assigned to this worker
   * @param outputs       output tables keyed by label
   */
  @SuppressWarnings("unchecked")
  public Worker(JobConf job, RuntimeContext ctx, Master m,
                GraphTaskAttemptID taskAttemptID, int workerID, int workerNum,
                InputSplit input, Map outputs)
      throws InstantiationException, IllegalAccessException, IOException,
             ClassNotFoundException {
    // Plain state handed in by the caller.
    this.mJob = job;
    this.mCtx = ctx;
    this.master = m;
    this.mInput = input;
    this.mOutputs = outputs;
    this.mTaskAttemptID = taskAttemptID;
    this.mWorkerID = workerID;
    this.mWorkerNum = workerNum;

    // State created by the worker itself.
    this.mCounters = new Counters();
    this.mAggregators = LocalGraphRunUtils.getAggregator(job);
    this.mMsgManager = new MsgManager();

    // The task context receives this worker together with the state set up above.
    this.mTaskContext = new TaskContextImpl(mCtx, mJob, this, mWorkerID, mWorkerNum,
                                            mOutputs, mCounters);
    this.mVertexMutations = new HashMap();
  }

  /**
   * Tells whether this worker has nothing left to do.
   *
   * @return {@code true} only when every vertex reports itself halted and the
   *         message manager holds no messages for the next superstep
   */
  public boolean allVertexVoltHalt() {
    for (Vertex vertex : vertices.values()) {
      if (!vertex.isHalted()) {
        // One active vertex is enough; pending messages need not be consulted.
        return false;
      }
    }
    return !mMsgManager.hasNextStepMessages();
  }

  /**
   * Instantiates and configures the combiner declared on the job, or leaves
   * {@code mCombiner} as {@code null} when the job declares none.
   *
   * @throws RuntimeException if the combiner cannot be instantiated or configured
   */
  private void initCombiner() {
    // Wildcard type plus an explicit cast: a bare Class would make newInstance()
    // return Object, which cannot be assigned to the Combiner field.
    Class<?> combinerCls = mJob.getCombinerClass();
    if (combinerCls == null) {
      mCombiner = null;
      return;
    }

    try {
      mCombiner = (Combiner) combinerCls.newInstance();
      mCombiner.configure(mJob);
    } catch (Exception e) {
      throw new RuntimeException(
          "exception occored when Instantiate combiner ", e);
    }
  }

  /**
   * Instantiates the job's {@link WorkerComputer}, creates the worker value
   * whose type is declared through the computer's type argument, and calls
   * {@code setup} on the computer.
   *
   * The worker value falls back to {@link NullWritable} when no type argument
   * is declared or when it cannot be resolved to a class.
   *
   * @throws IOException if more than one worker value type is declared, or if
   *                     instantiation or setup fails (the failure is the cause)
   */
  @SuppressWarnings("unchecked")
  private void initWorkerComputer() throws IOException {
    try {
      Class<? extends WorkerComputer> workerComputerClass = mJob
          .getWorkerComputerClass();

      mWorkerComputer = workerComputerClass.newInstance();

      List<Class<?>> workerValueClass = ReflectionUtils
          .getTypeArguments(WorkerComputer.class, workerComputerClass);

      if (workerValueClass.size() > 1) {
        throw new IOException("more than ONE workerValue Type Declared");
      }

      if (workerValueClass.isEmpty() || workerValueClass.get(0) == null) {
        mWorkerValue = NullWritable.get();
      } else {
        mWorkerValue = (Writable) ReflectionUtils.newInstance(workerValueClass.get(0), mJob);
      }
      mWorkerComputer.setup(mTaskContext, mWorkerValue);
    } catch (IOException e) {
      // Already the declared exception type: rethrow as-is instead of wrapping
      // it in a second IOException that hides the message.
      throw e;
    } catch (Exception e) {
      throw new IOException(e);
    }
  }

  /**
   * Seeds {@code mLastAggregatorValues} with the startup value of every
   * configured aggregator, in aggregator order.
   */
  private void initAggregatorValues() throws IOException {
    mLastAggregatorValues = new ArrayList();
    final int aggregatorCount = mAggregators.size();
    for (int i = 0; i < aggregatorCount; ++i) {
      Aggregator aggregator = mAggregators.get(i);
      mLastAggregatorValues.add(aggregator.createStartupValue(mTaskContext));
    }
  }

  /**
   * Loads this worker's share of the graph: instantiates the job's
   * {@link GraphLoader}, feeds it every record of the assigned input split and
   * creates one {@link LocalRecordWriter} per output label, which is then
   * registered on the task context.
   *
   * A worker without an input split (or whose split has no table) loads nothing
   * and hands {@code null} to the loader as its input table.
   *
   * @throws IOException if the input cannot be located or read, or a writer
   *                     cannot be created
   */
  @SuppressWarnings("unchecked")
  public void loadGraph() throws IOException {
    // Explicit cast keeps this independent of how precisely the job types the loader class.
    GraphLoader graphLoader =
        (GraphLoader) ReflectionUtils.newInstance(mJob.getGraphLoaderClass(), mJob);

    // Decide once whether there is input, so the loader setup below cannot
    // dereference a missing split.
    boolean hasInputTable = mInput != null && mInput.getTable() != null;

    graphLoader.setup(mTaskContext.getConfiguration(),
                      mTaskContext.getWorkerId(),
                      hasInputTable ? mInput.getTable() : null, mTaskContext);

    BaseRecordReader reader;
    if (hasInputTable) {
      File splitDir = mInput.getFile().getParentFile();
      String tableName = mInput.getTable().getTableName();

      // Walk up from the split directory until the directory named after the table is found.
      File tableDir = splitDir;
      while (tableDir != null && !tableDir.getName().equals(tableName)) {
        tableDir = tableDir.getParentFile();
      }
      if (tableDir == null) {
        throw new IOException("cannot find directory of table " + tableName
                              + " above input file " + mInput.getFile());
      }

      reader = new LocalRecordReader(splitDir, tableDir,
                                     mCounters.findCounter(COUNTER.TASK_INPUT_RECORD),
                                     mCounters.findCounter(COUNTER.TASK_INPUT_BYTE));
    } else {
      reader = new EmptyRecordReader();
    }

    try {
      mWriters = new HashMap<String, LocalRecordWriter>();
      for (String label : mOutputs.keySet()) {
        Counter outputRecordCounter = mCounters
            .findCounter(COUNTER.TASK_OUTPUT_RECORD);
        Counter outputByteCounter = mCounters
            .findCounter(COUNTER.TASK_OUTPUT_BYTE);
        File outputFile = new File(mCtx.getOutputDir(label),
                                   this.mTaskAttemptID.toString());
        mWriters.put(label, new LocalRecordWriter(outputFile,
                                                  outputRecordCounter,
                                                  outputByteCounter));
      }

      while (reader.nextKeyValue()) {
        LongWritable recordNum = new LongWritable();
        recordNum.set((reader.getCurrentKey().get()));
        // The loader receives a clone of the reader's current record.
        graphLoader.load(recordNum, ((SQLRecord) reader.getCurrentValue()).clone(),
                         mTaskContext);
      }
    } finally {
      // Release the input even when writer creation or loading fails.
      reader.close();
    }

    mTaskContext.setOutputWriters(mWriters);
  }

  /**
   * Prepares the worker for the first superstep: sets up the combiner, the
   * worker computer and the startup aggregator values, then calls
   * {@code setup} on every vertex currently held by this worker.
   */
  @SuppressWarnings("unchecked")
  public void init() throws IOException {
    // Worker-level components first, in this fixed order.
    initCombiner();
    initWorkerComputer();
    initAggregatorValues();

    // Then each vertex gets its setup call with the shared task context.
    for (Vertex vertex : vertices.values()) {
      vertex.setup(mTaskContext);
    }
  }

  /**
   * Starts a new superstep: resets the aggregator values to each aggregator's
   * initial value and advances the message manager.
   *
   * @throws RuntimeException if an aggregator returns {@code null} as its
   *                          initial value
   */
  public void processNextStep() throws IOException {
    mAggregatorValues = new ArrayList();
    for (Aggregator aggregator : mAggregators) {
      Writable initialValue = aggregator.createInitialValue(mTaskContext);
      if (initialValue == null) {
        throw new RuntimeException("ODPS-0730001: " + aggregator.getClass().getName()
                                   + " createInitialValue return null");
      }
      mAggregatorValues.add(initialValue);
    }
    mMsgManager.nextSuperStep(mCtx);
  }

  /**
   * Runs one superstep of computation over every vertex of this worker.
   *
   * Messages are gathered first via {@code prepareMsg()}. A halted vertex that
   * has at least one message is woken up; every vertex that is not halted
   * afterwards has {@code compute} invoked with its messages.
   */
  @SuppressWarnings("unchecked")
  public void Compute() throws IOException {
    prepareMsg();

    for (Vertex vertex : vertices.values()) {
      Iterable inbox = mLastStepMessage.get(vertex);

      if (vertex.isHalted()) {
        if (inbox.iterator().hasNext()) {
          vertex.wakeUp();
        }
      }

      // Re-checked on purpose: the wake-up above may have changed the state.
      if (vertex.isHalted()) {
        continue;
      }
      vertex.compute(mTaskContext, inbox);
    }
  }

  /**
   * Applies the pending mutations of one vertex id through the resolver and
   * stores the outcome: the resolved vertex replaces the current entry, or the
   * entry is removed when the resolver returns {@code null}.
   *
   * @param id             id of the vertex being resolved
   * @param mutations      mutations recorded for that id (as stored by this worker)
   * @param vertexResolver resolver deciding the resulting vertex
   * @throws IOException if no resolver is supplied
   */
  @SuppressWarnings("unchecked")
  public void processMutation(VERTEX_ID id, LocalVertexMutations mutations,
                              VertexResolver vertexResolver) throws IOException {
    Vertex current = vertices.get(id);
    long superStep = master.getSuperStep();
    boolean hasMessage = mMsgManager.hasMessageForVertex(mCtx, superStep, id);

    if (vertexResolver == null) {
      throw new IOException(
          "ODPS-0730001: encounter mutations in compute but not set the mutation resolver.");
    }

    Vertex resolved = vertexResolver.resolve(id, current, mutations, hasMessage);
    if (resolved != null) {
      // Only verified vertices are stored.
      VerifyUtils.verifyVertex(resolved);
      vertices.put(id, resolved);
      return;
    }
    vertices.remove(id);
  }

  /**
   * Resolves every vertex id that needs attention after a superstep, then
   * clears the recorded mutations.
   *
   * The ids processed are those with recorded mutations plus those that the
   * message manager lists but for which this worker holds no vertex.
   *
   * @param vertexResolver resolver applied to each id
   * @throws IOException if resolving any id fails
   */
  @SuppressWarnings("unchecked")
  public void processWorkerMutations(VertexResolver vertexResolver)
      throws IOException {
    // Typed set: a raw HashSet could not be iterated as VERTEX_ID below.
    HashSet<VERTEX_ID> mutationIDs = new HashSet<VERTEX_ID>(mVertexMutations.keySet());

    for (WritableComparable id : mMsgManager.getVertexIDList()) {
      if (vertices.get(id) == null) {
        mutationIDs.add((VERTEX_ID) id);
      }
    }

    for (VERTEX_ID id : mutationIDs) {
      // Ids added because of messages only have no recorded mutations, so null is passed for them.
      processMutation(id, mVertexMutations.get(id), vertexResolver);
    }
    mVertexMutations = new HashMap();
  }

  /**
   * @return the aggregator values of the current superstep, as rebuilt by
   *         {@code processNextStep()}; the internal list itself, not a copy
   */
  public List getAggregatorValues() {
    return this.mAggregatorValues;
  }

  /**
   * @return the counters created for this worker in its constructor
   */
  public Counters getCounters() {
    return this.mCounters;
  }

  /**
   * @return the sum of {@code getNumEdges()} over all vertices held by this worker
   */
  public long getEgeNumber() {
    long totalEdges = 0;
    for (Vertex vertex : vertices.values()) {
      totalEdges = totalEdges + vertex.getNumEdges();
    }
    return totalEdges;
  }

  /**
   * @return the aggregator values last stored by {@code setLastAggregatedValue}
   *         (or the startup values set during {@code init()}); the internal
   *         list itself, not a copy
   */
  public List getLastAggregatedValue() {
    return this.mLastAggregatorValues;
  }

  /**
   * @return the master this worker was constructed with
   */
  public Master getMaster() {
    return this.master;
  }

  /**
   * @return the task context created for this worker, exposed as a {@link WorkerContext}
   */
  public WorkerContext getTaskContext() {
    return this.mTaskContext;
  }

  /**
   * @return the number of vertices currently held by this worker
   */
  public long getVertexNumber() {
    final long vertexCount = vertices.size();
    return vertexCount;
  }

  /**
   * Returns the mutation record for a vertex id, creating and registering an
   * empty one on first access.
   *
   * @param id vertex id whose mutations are requested
   * @return the record stored for {@code id}; never {@code null}
   */
  public LocalVertexMutations getVertexMutations(VERTEX_ID id) {
    LocalVertexMutations existing = mVertexMutations.get(id);
    if (existing != null) {
      return existing;
    }

    LocalVertexMutations created = new LocalVertexMutations();
    mVertexMutations.put(id, created);
    return created;
  }

  /**
   * @return this worker's aggregator values for the current superstep; the
   *         same list that {@code getAggregatorValues()} returns
   */
  public List partialAggregate() {
    return this.mAggregatorValues;
  }

  /**
   * Hands a message addressed to a vertex over to this worker's message manager.
   *
   * @param context   runtime context forwarded to the message manager
   * @param superStep superstep number forwarded to the message manager
   * @param vertexId  id of the destination vertex
   * @param msg       the message
   */
  public void pushMsg(RuntimeContext context, long superStep,
                      WritableComparable vertexId, Writable msg) {
    this.mMsgManager.pushMsg(context, superStep, vertexId, msg);
  }

  /**
   * Closes all output writers and persists this worker's counters to a file
   * named after the task id inside the counter directory.
   *
   * The writers only exist after {@code loadGraph()} and are dropped here, so
   * the method tolerates being called without a loaded graph or more than once
   * instead of failing with a {@link NullPointerException}.
   *
   * @throws IOException if a writer cannot be closed or the counter file
   *                     cannot be written
   */
  public void close() throws IOException {
    if (mWriters != null) {
      for (LocalRecordWriter writer : mWriters.values()) {
        writer.close();
      }
      mWriters = null;
    }

    File counterFile = new File(mCtx.getCounterDir(),
                                String.valueOf(this.mTaskAttemptID.getTaskId()));
    FileUtils.writeStringToFile(counterFile, mCounters.toString());
    LOG.debug(mCounters);
  }

  /**
   * Calls {@code cleanup} on every vertex of this worker and then on the
   * worker computer, all with the shared task context.
   */
  @SuppressWarnings("unchecked")
  public void cleanup() throws IOException {
    for (Vertex vertex : vertices.values()) {
      vertex.cleanup(mTaskContext);
    }

    // The worker computer is cleaned up last, after all vertices.
    mWorkerComputer.cleanup(mTaskContext);
  }

  /**
   * Collapses the messages of one vertex into at most one message using the
   * configured combiner.
   *
   * The first message serves as the accumulator and every further message is
   * combined into it. Without a combiner the messages are returned untouched.
   *
   * @param id   id of the vertex the messages are addressed to
   * @param msgs messages to combine
   * @return a list holding the single combined message (empty when there were
   *         no messages), or {@code msgs} itself when no combiner is configured
   */
  @SuppressWarnings("unchecked")
  private Iterable<Writable> combineMsg(WritableComparable id,
                                        Iterable<Writable> msgs) throws IOException {
    if (mCombiner == null) {
      return msgs;
    }

    Writable combined = null;
    for (Writable msg : msgs) {
      if (combined == null) {
        combined = msg;
      } else {
        mCombiner.combine(id, combined, msg);
      }
    }

    ArrayList<Writable> combinedMsgs = new ArrayList<Writable>();
    if (combined != null) {
      combinedMsgs.add(combined);
    }
    return combinedMsgs;
  }

  /**
   * Rebuilds {@code mLastStepMessage} for the master's current superstep: pops
   * the messages of every vertex from the message manager, runs them through
   * the combiner when one is configured, and stores them keyed by vertex.
   */
  private void prepareMsg() throws IOException {
    final long currentStep = master.getSuperStep();
    LOG.debug("worker super step " + currentStep + ", vertices count "
              + vertices.size());

    mLastStepMessage.clear();
    for (Vertex vertex : vertices.values()) {
      Iterable inbox = mMsgManager.popMsges(mCtx, currentStep, vertex.getId());
      if (mCombiner != null) {
        inbox = combineMsg(vertex.getId(), inbox);
      }
      mLastStepMessage.put(vertex, inbox);
    }
  }

  /**
   * Stores a private copy of the given aggregator values as the "last
   * aggregated" values of this worker.
   *
   * Every non-null element is cloned with {@link WritableUtils#clone}; null
   * elements are kept as {@code null} so positions stay aligned.
   *
   * @param lastAggrValues aggregator values to copy, one per aggregator
   */
  public void setLastAggregatedValue(List<Writable> lastAggrValues) {
    // Element type is needed here: WritableUtils.clone does not accept plain Object elements.
    mLastAggregatorValues = new ArrayList<Writable>(lastAggrValues.size());
    for (Writable original : lastAggrValues) {
      Writable value = null;
      if (original != null) {
        value = WritableUtils.clone(original, mJob);
      }
      mLastAggregatorValues.add(value);
    }
  }

  /**
   * Publishes the job-wide vertex and edge totals on this worker's task context.
   *
   * @param totalVertices total number of vertices
   * @param totalEdge     total number of edges
   */
  public void setTotalNumVerticesAndEdges(int totalVertices, int totalEdge) {
    this.mTaskContext.setTotalNumVertices(totalVertices);
    this.mTaskContext.setTotalNumEdges(totalEdge);
  }

  /**
   * @return the worker value created while initializing the worker computer
   */
  public Writable getWorkerValue() {
    return this.mWorkerValue;
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy