All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.crunch.impl.spark.SparkRuntimeContext Maven / Gradle / Ivy

There is a newer version: 1.0.0
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.crunch.impl.spark;

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import javassist.util.proxy.MethodFilter;
import javassist.util.proxy.MethodHandler;
import javassist.util.proxy.ProxyFactory;
import org.apache.crunch.CrunchRuntimeException;
import org.apache.crunch.DoFn;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.mapreduce.Counter;
import org.apache.hadoop.mapreduce.OutputCommitter;
import org.apache.hadoop.mapreduce.RecordWriter;
import org.apache.hadoop.mapreduce.StatusReporter;
import org.apache.hadoop.mapreduce.TaskAttemptID;
import org.apache.hadoop.mapreduce.TaskInputOutputContext;
import org.apache.spark.Accumulator;
import org.apache.spark.SparkFiles;
import org.apache.spark.broadcast.Broadcast;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class SparkRuntimeContext implements Serializable {

  private Broadcast broadConf;
  private Accumulator> counters;
  private transient TaskInputOutputContext context;

  public SparkRuntimeContext(
      Broadcast broadConf,
      Accumulator> counters) {
    this.broadConf = broadConf;
    this.counters = counters;
  }

  public void initialize(DoFn fn) {
    if (context == null) {
      configureLocalFiles();
      context = getTaskIOContext(broadConf, counters);
    }
    fn.setContext(context);
    fn.initialize();
  }

  private void configureLocalFiles() {
    try {
      URI[] uris = DistributedCache.getCacheFiles(getConfiguration());
      if (uris != null) {
        List allFiles = Lists.newArrayList();
        for (URI uri : uris) {
          File f = new File(uri.getPath());
          String sparkFile = SparkFiles.get(f.getName());
          allFiles.add(SparkFiles.get(f.getName()));
        }
        String sparkFiles = Joiner.on(',').join(allFiles);
        // Hacking this for Hadoop1 and Hadoop2
        getConfiguration().set("mapreduce.job.cache.local.files", sparkFiles);
        getConfiguration().set("mapred.cache.localFiles", sparkFiles);
      }
    } catch (IOException e) {
      throw new CrunchRuntimeException(e);
    }
  }

  public Configuration getConfiguration() {
    return broadConf.value();
  }

  public static TaskInputOutputContext getTaskIOContext(
      final Broadcast conf,
      final Accumulator> counters) {
    ProxyFactory factory = new ProxyFactory();
    Class superType = TaskInputOutputContext.class;
    Class[] types = new Class[0];
    Object[] args = new Object[0];
    final TaskAttemptID taskAttemptId = new TaskAttemptID();
    if (superType.isInterface()) {
      factory.setInterfaces(new Class[] { superType });
    } else {
      types = new Class[] { Configuration.class, TaskAttemptID.class, RecordWriter.class, OutputCommitter.class,
          StatusReporter.class };
      args = new Object[] { conf.value(), taskAttemptId, null, null, null };
      factory.setSuperclass(superType);
    }

    final Set handledMethods = ImmutableSet.of("getConfiguration", "getCounter",
        "progress", "getTaskAttemptID");
    factory.setFilter(new MethodFilter() {
      @Override
      public boolean isHandled(Method m) {
        return handledMethods.contains(m.getName());
      }
    });
    MethodHandler handler = new MethodHandler() {
      @Override
      public Object invoke(Object arg0, Method m, Method arg2, Object[] args) throws Throwable {
        String name = m.getName();
        if ("getConfiguration".equals(name)) {
          return conf.value();
        } else if ("progress".equals(name)) {
          // no-op
          return null;
        } else if ("getTaskAttemptID".equals(name)) {
          return taskAttemptId;
        } else if ("getCounter".equals(name)){ // getCounter
          if (args.length == 1) {
            return getCounter(counters, args[0].getClass().getName(), ((Enum) args[0]).name());
          } else {
            return getCounter(counters, (String) args[0], (String) args[1]);
          }
        } else {
          throw new IllegalStateException("Unhandled method " + name);
        }
      }
    };

    try {
      Object newInstance = factory.create(types, args, handler);
      return (TaskInputOutputContext) newInstance;
    } catch (Exception e) {
      e.printStackTrace();
      throw new RuntimeException(e);
    }
  }

  private static Counter getCounter(final Accumulator> accum, final String group,
                                    final String counterName) {
    ProxyFactory factory = new ProxyFactory();
    Class superType = Counter.class;
    Class[] types = new Class[0];
    Object[] args = new Object[0];
    if (superType.isInterface()) {
      factory.setInterfaces(new Class[] { superType });
    } else {
      types = new Class[] { String.class, String.class };
      args = new Object[] { group, counterName };
      factory.setSuperclass(superType);
    }

    final Set handledMethods = ImmutableSet.of("getDisplayName", "getName",
        "getValue", "increment", "setValue", "setDisplayName");
    factory.setFilter(new MethodFilter() {
      @Override
      public boolean isHandled(Method m) {
        return handledMethods.contains(m.getName());
      }
    });
    MethodHandler handler = new MethodHandler() {
      @Override
      public Object invoke(Object arg0, Method m, Method arg2, Object[] args) throws Throwable {
        String name = m.getName();
        if ("increment".equals(name)) {
          accum.add(ImmutableMap.of(group + ":" + counterName, (Long) args[0]));
          return null;
        } else if ("getDisplayName".equals(name)) {
          return counterName;
        } else if ("getName".equals(name)) {
          return counterName;
        } else if ("setDisplayName".equals(name)) {
          // No-op
          return null;
        } else if ("setValue".equals(name)) {
          throw new UnsupportedOperationException("Cannot set counter values in Spark, only increment them");
        } else if ("getValue".equals(name)) {
          throw new UnsupportedOperationException("Cannot read counters during Spark execution");
        } else {
          throw new IllegalStateException("Unhandled method " + name);
        }
      }
    };
    try {
      Object newInstance = factory.create(types, args, handler);
      return (Counter) newInstance;
    } catch (Exception e) {
      e.printStackTrace();
      throw new RuntimeException(e);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy