All downloads are free. Search and download functionality uses the official Maven repository.

com.aliyun.odps.graph.local.TaskContextImpl Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.aliyun.odps.graph.local;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;

import org.apache.commons.io.IOUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.aliyun.odps.Column;
import com.aliyun.odps.conf.Configuration;
import com.aliyun.odps.counter.Counter;
import com.aliyun.odps.counter.Counters;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.Edge;
import com.aliyun.odps.graph.GRAPH_CONF;
import com.aliyun.odps.graph.JobConf;
import com.aliyun.odps.graph.JobConf.JobState;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.common.COMMON_GRAPH_CONF;
import com.aliyun.odps.graph.local.utils.LocalGraphRunUtils;
import com.aliyun.odps.graph.local.worker.Worker;
import com.aliyun.odps.graph.utils.VerifyUtils;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableComparable;
import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.local.common.TableMeta;
import com.aliyun.odps.local.common.WareHouse;
import com.aliyun.odps.local.common.utils.LocalRunUtils;
import com.aliyun.odps.local.common.utils.SchemaUtils;

@SuppressWarnings("rawtypes")
public class TaskContextImpl, VERTEX_VALUE extends Writable, EDGE_VALUE extends Writable, MESSAGE extends Writable>
    extends ComputeContext {

  private static Log LOG = LogFactory.getLog(TaskContextImpl.class);
  private List mAggregators;

  private Counters mCounters;
  private RuntimeContext mCtx;
  private JobConf mJob;
  private Map mOutputRecords;
  private Map mOutputs;

  private int mTaskNum;
  long mTotalNumEdges;
  long mTotalNumVertices;
  private Worker mWorker;
  private int mWorkerID;

  private Map mWriters;

  private int maxUserDefinedCountersNum = 64;
  private Map userCounters =
      new HashMap();

  public TaskContextImpl(RuntimeContext ctx, JobConf conf, Worker worker,
                         int workerID, int taskNum, Map outputs,
                         Counters counters) {
    mWorker = worker;
    mCtx = ctx;
    mJob = conf;
    mWorkerID = workerID;
    mTaskNum = taskNum;
    mOutputs = outputs;
    mCounters = counters;

    mAggregators = LocalGraphRunUtils.getAggregator(mJob);

    mOutputRecords = new HashMap();

    maxUserDefinedCountersNum =
        mJob.getInt(COMMON_GRAPH_CONF.JOB_MAX_USER_DEFINED_COUNTERS_NUM, 64);
  }

  @SuppressWarnings("unchecked")
  public LocalVertexMutations getRealVertexMutations(VERTEX_ID id) {
    if (mJob.getRuntimePartitioning()) {
      return mWorker.getMaster().getVertexMutations(id);
    } else {
      return mWorker.getVertexMutations(id);
    }

  }

  @SuppressWarnings("unchecked")
  @Override
  public void addVertexRequest(
      Vertex vertex)
      throws IOException {
    VerifyUtils.verifyVertex(vertex);
    LocalVertexMutations vertexMutations = getRealVertexMutations(vertex
                                                                      .getId());
    vertexMutations.addVertex(vertex);
  }

  @SuppressWarnings("unchecked")
  @Override
  public void addEdgeRequest(VERTEX_ID sourceVertexId,
                             Edge edge) throws IOException {
    VerifyUtils.verifyVertexId(sourceVertexId);
    VerifyUtils.verifyVertexEdge(edge);
    LocalVertexMutations vertexMutations = getRealVertexMutations(sourceVertexId);
    vertexMutations.addEdge(edge);

  }

  @SuppressWarnings("unchecked")
  @Override
  public void removeEdgeRequest(VERTEX_ID sourceVertexId,
                                VERTEX_ID targetVertexId) throws IOException {
    VerifyUtils.verifyVertexId(sourceVertexId);
    VerifyUtils.verifyVertexId(targetVertexId);
    LocalVertexMutations vertexMutations = getRealVertexMutations(sourceVertexId);
    vertexMutations.removeEdge(targetVertexId);
  }

  @Override
  public void removeVertexRequest(VERTEX_ID vertexId) throws IOException {
    VerifyUtils.verifyVertexId(vertexId);
    LocalVertexMutations vertexMutations = getRealVertexMutations(vertexId);
    vertexMutations.removeVertex();
  }

  @SuppressWarnings("unchecked")
  @Override
  public void sendMessage(VERTEX_ID destVertexID, MESSAGE msg)
      throws IOException {
    if (!mJob.getRuntimePartitioning()) {
      throw new RuntimeException(
          "ODPS-0730001: vertex partitioning disabled, cannot send message");
    }
    if (msg == null) {
      throw new IllegalArgumentException(
          "ODPS-0730001: sendMessage: Cannot send null message to "
          + destVertexID);
    }
    mWorker.getMaster().pushMsg(mCtx, getSuperstep() + 1, destVertexID, msg);
  }

  @Override
  public void sendMessage(Iterable destVertexIDs, MESSAGE msg)
      throws IOException {
    for (VERTEX_ID vertexId : destVertexIDs) {
      sendMessage(vertexId, msg);
    }
  }

  /**
   * 发送消息到给定点的所有邻接点,在下一个超步中会传给这些邻接点的
   * {@link Vertex#compute(ComputeContext, Iterable)} 方法进行处理.
   *
   * @param msg
   *     待发送的消息
   * @throws IOException
   */
  @Override
  public void sendMessageToNeighbors(
      Vertex vertex, MESSAGE msg)
      throws IOException {
    if (vertex.hasEdges()) {
      for (Edge edge : vertex.getEdges()) {
        sendMessage(edge.getDestVertexId(), msg);
      }
    }
  }

  @Override
  public void aggregate(Object item) throws IOException {
    for (int i = 0; i < mAggregators.size(); ++i) {
      aggregate(i, item);
    }

  }

  @SuppressWarnings("unchecked")
  @Override
  public void aggregate(int aggregatorIndex, Object value) throws IOException {
    List aggregatorValues = mWorker.getAggregatorValues();
    Writable aggValue = aggregatorValues.get(aggregatorIndex);
    mAggregators.get(aggregatorIndex).aggregate(aggValue, value);
  }

  private WritableRecord createOutputRecord(String label) throws IOException {
    Column[] cols = SchemaUtils.readSchema(mCtx.getOutputDir(label)).getCols();
    return new SQLRecord(cols);
  }


  @Override
  public Configuration getConfiguration() {
    return new JobConf(mJob, JobState.RUNNING);
  }

  @Override
  public Counter getCounter(Enum name) {
    if (name == null) {
      throw new RuntimeException("ODPS-0730001: Counter name must be not null.");
    }
    return getCounter(name.getDeclaringClass().getName(), name.toString());
  }

  @Override
  public Counter getCounter(String group, String name) {
    String key = group + "#" + name;
    if (userCounters.containsKey(key)) {
      return userCounters.get(key);
    }
    checkUserDefinedCounters(group, name);
    Counter counter = mCounters.findCounter(group, name);
    userCounters.put(key, counter);
    return counter;
  }

  @SuppressWarnings("unchecked")
  @Override
  public  VALUE getLastAggregatedValue(
      int aggregatorIndex) {
    return ((VALUE) mWorker.getLastAggregatedValue().get(aggregatorIndex));
  }

  @Override
  public long getMaxIteration() {
    return getConfiguration().getLong(GRAPH_CONF.MAX_ITERATION, -1);
  }

  @Override
  public int getNumWorkers() {
    return mTaskNum;
  }

  @Override
  public TableInfo getOutputTable() throws IOException {
    return getOutputTable("");
  }

  @Override
  public TableInfo getOutputTable(String label) throws IOException {
    return mOutputs.get(label);
  }

  @Override
  public long getSuperstep() {
    return mWorker.getMaster().getSuperStep();
  }

  @Override
  public long getTotalNumEdges() {
    return mTotalNumEdges;
  }

  @Override
  public long getTotalNumVertices() {
    return mTotalNumVertices;
  }

  @Override
  public Writable getWorkerValue() {
    return mWorker.getWorkerValue();
  }

  @Override
  public Writable getComputeValue() {
    return null;
  }

  @Override
  public int getWorkerId() {
    return mWorkerID;
  }

  @Override
  public long getWorkerNumEdges() {
    return mWorker.getEgeNumber();
  }

  @Override
  public long getWorkerNumVertices() {
    return mWorker.getVertexNumber();
  }

  @Override
  public void progress() {
    LOG.debug("Graph Local Mode Just Mock progress method. Not Calculate Time");
  }

  public void closeWriters() throws IOException {
    for (LocalRecordWriter writer : mWriters.values()) {
      writer.close();
    }
  }


  public void setOutputWriters(Map writers) {
    this.mWriters = writers;
  }

  public void setTotalNumEdges(long totalEdges) {
    mTotalNumEdges = totalEdges;
  }

  public void setTotalNumVertices(long totalVertices) {
    mTotalNumVertices = totalVertices;
  }

  @Override
  public void write(String label, Writable... fieldVals) throws IOException {
    LocalRecordWriter writer = mWriters.get(label);

    if (writer == null) {
      throw new IOException("The label " + label + " is not found in output");
    }

    if (mOutputRecords.get(label) == null) {
      mOutputRecords.put(label, createOutputRecord(label));
    }

    WritableRecord record = mOutputRecords.get(label);
    record.set(fieldVals);
    writer.write(record);

  }

  @Override
  public void write(Writable... fieldVals) throws IOException {
    write("", fieldVals);
  }

  private void checkUserDefinedCounters(String groupName,
                                        String counterName) {
    if (counterName == null || counterName.isEmpty()) {
      throw new RuntimeException(
          "ODPS-0730001: CounterName must be not null or empty.");
    }
    if (groupName == null || groupName.isEmpty()) {
      throw new RuntimeException(
          "ODPS-0730001: groupName must be not null or empty.");
    }

    if (groupName.contains("#")) {
      throw new RuntimeException("ODPS-0730001: Group name: " + groupName
                                 + " is invalid, It should not contain '#'");
    }
    if (counterName.contains("#")) {
      throw new RuntimeException("ODPS-0730001: Counter name: " + counterName
                                 + " is invalid, It should not contain '#'");
    }

    int maxLength = 100;

    if (groupName.length() + counterName.length() > maxLength) {
      throw new RuntimeException("ODPS-0730001: Group name '" + groupName
                                 + "' and Counter name '" + counterName
                                 + "' is too long, sum of their length must <= " + maxLength);
    }
    if (userCounters.size() >= maxUserDefinedCountersNum) {
      throw new RuntimeException(
          "ODPS-0730001: Total num of user defined counters is too more, must be <= "
          + maxUserDefinedCountersNum);
    }
  }

  @Override
  public byte[] readCacheFile(String resourceName) throws IOException {
    return IOUtils.toByteArray(readCacheFileAsStream(resourceName));
  }

  @Override
  public BufferedInputStream readCacheFileAsStream(String resourceName)
      throws IOException {
    File file = new File(mCtx.getResourceDir(), resourceName);
    return new BufferedInputStream(new FileInputStream(file));
  }

  @Override
  public Iterable readCacheArchive(String resourceName)
      throws IOException {
    return readCacheArchive(resourceName, "");
  }

  @Override
  public Iterable readCacheArchive(String resourceName,
                                           String relativePath) throws IOException {
    File baseDir = new File(mCtx.getResourceDir(), resourceName);
    File dir = new File(baseDir, relativePath);
    File[] files = dir.listFiles();
    final List list = new ArrayList();
    for (File file : files) {
      list.add(IOUtils.toByteArray(
          new BufferedInputStream(new FileInputStream(file))));
    }
    return new Iterable() {
      @Override
      public Iterator iterator() {
        return list.iterator();
      }
    };
  }

  @Override
  public Iterable readCacheArchiveAsStream(
      String resourceName) throws IOException {
    return readCacheArchiveAsStream(resourceName, "");
  }

  @Override
  public Iterable readCacheArchiveAsStream(
      String resourceName, String relativePath) throws IOException {
    File baseDir = new File(mCtx.getResourceDir(), resourceName);
    File dir = new File(baseDir, relativePath);
    File[] files = dir.listFiles();
    final List list = new ArrayList();
    for (File file : files) {
      list.add(new BufferedInputStream(new FileInputStream(file)));
    }
    return new Iterable() {
      @Override
      public Iterator iterator() {
        return list.iterator();
      }
    };
  }

  @Override
  public Iterable readResourceTable(String resourceName)
      throws IOException {
    final File tableDir = new File(mCtx.getResourceDir(), resourceName);
    if (!tableDir.exists()) {
      throw new RuntimeException("resource " + resourceName + " not found!");
    }
    if (tableDir.isFile()) {
      throw new RuntimeException("resource " + resourceName + " is not a table resource!");
    }
    final List dataFiles = new ArrayList();
    LocalRunUtils.listAllDataFiles(tableDir, dataFiles);

    return new Iterable() {
      @Override
      public Iterator iterator() {
        return new WrappedRecordIterator(tableDir, dataFiles);
      }

    };
  }

  private class WrappedRecordIterator implements Iterator {

    LocalRecordReader reader;
    WritableRecord current;
    boolean fetched;
    Iterator fileIter;
    File tableDir;

    WrappedRecordIterator(File tableDir, List dataFiles) {
      this.tableDir = tableDir;
      this.fileIter = dataFiles.iterator();
    }

    @Override
    public boolean hasNext() {
      if (fetched) {
        return current != null;
      }
      // Fetch new one
      try {
        fetch();
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
      return current != null;
    }

    private void fetch() throws IOException {
      fetched = true;
      while (true) {
        // first time or next reader
        if (reader == null) {
          if (!fileIter.hasNext()) {
            current = null;
            return;
          }
          File dataFile = fileIter.next();
          reader = new LocalRecordReader(dataFile.getParentFile(),
                                         tableDir, null, null);
        }
        // read next
        if (reader.nextKeyValue()) {
          current = reader.getCurrentValue();
          return;
        } else {
          reader = null;
          continue;
        }
      }
    }

    @Override
    public WritableRecord next() {
      if (!hasNext()) {
        throw new NoSuchElementException();
      }
      fetched = false;
      return current;
    }

    @Override
    public void remove() {
      throw new UnsupportedOperationException();
    }

  }

  @Override
  public TableInfo getResourceTable(String resourceName) throws IOException {
    File dir = new File(mCtx.getResourceDir(), resourceName);
    if (!dir.exists()) {
      throw new RuntimeException("resource " + resourceName + " not found!");
    }
    if (dir.isFile()) {
      throw new RuntimeException("resource " + resourceName + " is not a table resource!");
    }
    TableMeta meta = SchemaUtils.readSchema(dir);
    return WareHouse.getInstance().getReferencedTable(meta.getProjName(), resourceName);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy