All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.reef.io.network.group.config.AllReduceManager Maven / Gradle / Ivy

There is a newer version: 0.9
Show newest version
/*
 * Copyright 2013 Microsoft.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.microsoft.reef.io.network.group.config;

import com.microsoft.reef.driver.task.TaskConfigurationOptions;
import com.microsoft.reef.exception.evaluator.NetworkException;
import com.microsoft.reef.io.network.Connection;
import com.microsoft.reef.io.network.Message;
import com.microsoft.reef.io.network.group.impl.GCMCodec;
import com.microsoft.reef.io.network.group.impl.GroupCommNetworkHandler;
import com.microsoft.reef.io.network.group.impl.operators.faulty.AllReduceConfig;
import com.microsoft.reef.io.network.group.impl.operators.faulty.AllReduceHandler;
import com.microsoft.reef.io.network.group.impl.operators.faulty.ExceptionHandler;
import com.microsoft.reef.io.network.group.operators.Reduce.ReduceFunction;
import com.microsoft.reef.io.network.impl.MessagingTransportFactory;
import com.microsoft.reef.io.network.impl.NetworkService;
import com.microsoft.reef.io.network.impl.NetworkServiceParameters;
import com.microsoft.reef.io.network.naming.NameServerParameters;
import com.microsoft.reef.io.network.proto.ReefNetworkGroupCommProtos.GroupCommMessage;
import com.microsoft.reef.io.network.proto.ReefNetworkGroupCommProtos.GroupCommMessage.Type;
import com.microsoft.reef.io.network.util.StringIdentifierFactory;
import com.microsoft.reef.io.network.util.Utils;
import com.microsoft.tang.Configuration;
import com.microsoft.tang.JavaConfigurationBuilder;
import com.microsoft.tang.Tang;
import com.microsoft.tang.exceptions.BindException;
import com.microsoft.wake.ComparableIdentifier;
import com.microsoft.wake.EventHandler;
import com.microsoft.wake.Identifier;
import com.microsoft.wake.impl.LoggingEventHandler;
import com.microsoft.wake.impl.SingleThreadStage;
import com.microsoft.wake.remote.Codec;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 *
 */
public class AllReduceManager {

  /**
   * TANG instance
   */
  private static final Tang tang = Tang.Factory.getTang();

  private Configuration allRedBaseConf;


  /**
   * Common configs
   */
  private Class> dataCodecClass;
  private Class> redFuncClass;

  /**
   * {@link NetworkService} related configs
   */
  private final String nameServiceAddr;
  private final int nameServicePort;
  private Map id2port;
  private final NetworkService ns;
  private final StringIdentifierFactory idFac = new StringIdentifierFactory();
  private final ComparableIdentifier driverId = (ComparableIdentifier) idFac.getNewInstance("driver");


  private Map taskIdMap;
  private final ComparableIdentifier[] tasks;
  private final int numTasks;
  private int runningTasks;

  /**
   * @param dataCodec
   * @param redFunc
   * @param nameServiceAddr
   * @param nameServicePort
   * @param id2port
   * @throws BindException
   */
  public AllReduceManager(Class> dataCodec, Class> redFunc,
                          String nameServiceAddr, int nameServicePort,
                          Map id2port) throws BindException {
    dataCodecClass = dataCodec;
    redFuncClass = redFunc;
    this.nameServiceAddr = nameServiceAddr;
    this.nameServicePort = nameServicePort;
    this.id2port = id2port;
    taskIdMap = new HashMap();
    int i = 1;
    tasks = new ComparableIdentifier[id2port.size() + 1];
    for (ComparableIdentifier id : id2port.keySet()) {
      tasks[i] = id;
      taskIdMap.put(id, i++);
    }
    numTasks = tasks.length - 1;
    runningTasks = numTasks;
    JavaConfigurationBuilder jcb = tang.newConfigurationBuilder();
    jcb.bindNamedParameter(AllReduceConfig.DataCodec.class, dataCodecClass);
    jcb.bindNamedParameter(AllReduceConfig.ReduceFunction.class, redFuncClass);
    jcb.bindNamedParameter(NetworkServiceParameters.NetworkServiceCodec.class,
        GCMCodec.class);
    jcb.bindNamedParameter(
        NetworkServiceParameters.NetworkServiceHandler.class,
        AllReduceHandler.class);
    jcb.bindNamedParameter(
        NetworkServiceParameters.NetworkServiceExceptionHandler.class,
        ExceptionHandler.class);
    jcb.bindNamedParameter(NameServerParameters.NameServerAddr.class,
        nameServiceAddr);
    jcb.bindNamedParameter(NameServerParameters.NameServerPort.class,
        Integer.toString(nameServicePort));
    allRedBaseConf = jcb.build();

    ns = new NetworkService<>(
        idFac, 0, nameServiceAddr, nameServicePort, new GCMCodec(),
        new MessagingTransportFactory(), new LoggingEventHandler>(),
        new LoggingEventHandler());
  }

  /**
   * @param taskId
   * @return
   */
  public synchronized double estimateVarInc(final ComparableIdentifier taskId) {
    double actVarDrop = 1.0 / numTasks;
    int childrenLost = getChildren(taskId) + 1;
    double curVarDrop = 1.0 / (runningTasks - childrenLost);
    return (curVarDrop / actVarDrop) - 1;
  }

  /**
   * @param taskId
   * @return
   */
  private synchronized int getChildren(final ComparableIdentifier taskId) {
    int idx = taskIdMap.get(taskId);
    int leftChildren, rightChildren;
    if (leftChild(idx) > numTasks) {
      return 0;
    } else {
      leftChildren = getChildren(tasks[leftChild(idx)]) + 1;
      if (rightChild(idx) > numTasks) {
        return leftChildren;
      } else {
        rightChildren = getChildren(tasks[rightChild(idx)]) + 1;
      }
    }
    return leftChildren + rightChildren;
  }


  /**
   * @param failedTaskId
   * @throws NetworkException
   */
  public synchronized void remove(final ComparableIdentifier failedTaskId) {

    System.out.println("All Reduce Manager removing " + failedTaskId);
    final ComparableIdentifier from = failedTaskId;
    final ComparableIdentifier to = tasks[parent(taskIdMap.get(failedTaskId))];

    final SingleThreadStage senderStage = new SingleThreadStage<>("SrcDeadMsgSender", new EventHandler() {

      @Override
      public void onNext(GroupCommMessage srcDeadMsg) {
        Connection link = ns.newConnection(to);
        try {
          link.open();
          System.out.println("Sending source dead msg " + srcDeadMsg + " to parent " + to);
          link.write(srcDeadMsg);
        } catch (NetworkException e) {
          e.printStackTrace();
          throw new RuntimeException("Unable to send failed task msg to parent of " + to, e);
        }
      }
    }, 5);

    final GroupCommMessage srcDeadMsg = Utils.bldGCM(Type.SourceDead, from, to, new byte[0]);
    senderStage.onNext(srcDeadMsg);
    --runningTasks;
  }

  /**
   * @return
   */
  public List getReceivers() {
    List retVal = new ArrayList<>();
    int end = (numTasks == 1) ? 1 : parent(numTasks);
    for (int i = 1; i <= end; i++)
      retVal.add(tasks[i]);
    return retVal;
  }

  /**
   * @return
   */
  public List getSenders() {
    List retVal = new ArrayList<>();
    int start = (numTasks == 1) ? 1 : parent(numTasks);
    for (int i = start + 1; i <= numTasks; i++)
      retVal.add(tasks[i]);
    return retVal;
  }

  private int parent(int i) {
    return i >> 1;
  }

  private int leftChild(int i) {
    return i << 1;
  }

  private int rightChild(int i) {
    return (i << 1) + 1;
  }

  /**
   * @param taskId
   * @return
   * @throws BindException
   */
  public Configuration getConfig(final ComparableIdentifier taskId) throws BindException {
    JavaConfigurationBuilder jcb = tang.newConfigurationBuilder(allRedBaseConf);
    jcb.bindNamedParameter(AllReduceConfig.SelfId.class, taskId.toString());
    List ids = new ArrayList<>();
    int idx = taskIdMap.get(taskId);
    if (idx != 1) {
      ComparableIdentifier par = tasks[parent(idx)];
      ids.add(par);
      jcb.bindNamedParameter(AllReduceConfig.ParentId.class, par.toString());
    }

    int lcId = leftChild(idx);
    if (lcId <= numTasks) {
      ComparableIdentifier lc = tasks[lcId];
      ids.add(lc);
      jcb.bindSetEntry(AllReduceConfig.ChildIds.class, lc.toString());
      int rcId = rightChild(idx);
      if (rcId <= numTasks) {
        ComparableIdentifier rc = tasks[rcId];
        ids.add(rc);
        jcb.bindSetEntry(AllReduceConfig.ChildIds.class, rc.toString());
      }
    }

    jcb.addConfiguration(createNetworkServiceConf(nameServiceAddr, nameServicePort, taskId, ids, id2port.get(taskId)));
    return jcb.build();
  }

  /**
   * Create {@link Configuration} for {@link GroupCommNetworkHandler}
   * using base conf + list of identifiers
   *
   * @param ids
   * @return
   * @throws BindException
   */
  private Configuration createHandlerConf(
      List ids) throws BindException {
    JavaConfigurationBuilder jcb = tang
        .newConfigurationBuilder();
    for (ComparableIdentifier comparableIdentifier : ids) {
      jcb.bindSetEntry(AllReduceHandler.IDs.class, comparableIdentifier.toString());
    }
    return jcb.build();
  }

  /**
   * Create {@link NetworkService} {@link Configuration} for each task
   * using base conf + per task parameters
   *
   * @param nameServiceAddr
   * @param nameServicePort
   * @param self
   * @param ids
   * @param nsPort
   * @return per task {@link NetworkService} {@link Configuration} for the specified task
   * @throws BindException
   */
  private Configuration createNetworkServiceConf(
      String nameServiceAddr, int nameServicePort, Identifier self,
      List ids, int nsPort) throws BindException {
    JavaConfigurationBuilder jcb = tang
        .newConfigurationBuilder();

    jcb.bindNamedParameter(TaskConfigurationOptions.Identifier.class, self.toString());
    jcb.bindNamedParameter(
        NetworkServiceParameters.NetworkServicePort.class,
        Integer.toString(nsPort));

    jcb.addConfiguration(createHandlerConf(ids));
    return jcb.build();
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy