All downloads are free. Search and download functionality uses the official Maven repository.

co.cask.hydrator.plugin.spark.dynamic.ScalaSparkCompute Maven / Gradle / Ivy

/*
 * Copyright © 2017 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.hydrator.plugin.spark.dynamic;

import co.cask.cdap.api.annotation.Description;
import co.cask.cdap.api.annotation.Macro;
import co.cask.cdap.api.annotation.Name;
import co.cask.cdap.api.annotation.Plugin;
import co.cask.cdap.api.data.format.StructuredRecord;
import co.cask.cdap.api.data.schema.Schema;
import co.cask.cdap.api.plugin.PluginConfig;
import co.cask.cdap.api.spark.dynamic.CompilationFailureException;
import co.cask.cdap.api.spark.dynamic.SparkInterpreter;
import co.cask.cdap.etl.api.PipelineConfigurer;
import co.cask.cdap.etl.api.StageConfigurer;
import co.cask.cdap.etl.api.batch.SparkCompute;
import co.cask.cdap.etl.api.batch.SparkExecutionPluginContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.rdd.RDD;
import scala.reflect.ClassTag$;

import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import javax.annotation.Nullable;

/**
 * A {@link SparkCompute} that takes any scala code and executes it.
 *
 * <p>The user-provided Scala code is wrapped inside a generated Scala {@code object} (see
 * {@link #generateSourceClass()}), compiled through a {@link SparkInterpreter}, and its
 * {@code transform} method is then located and invoked via reflection.
 */
@Plugin(type = SparkCompute.PLUGIN_TYPE)
@Name("ScalaSparkCompute")
@Description("Executes user-provided Spark code written in Scala that performs RDD to RDD transformation")
public class ScalaSparkCompute extends SparkCompute<StructuredRecord, StructuredRecord> {

  private static final String PACKAGE_NAME = "co.cask.hydrator.plugin.spark.dynamic.generated";
  private static final String CLASS_NAME = "UserSparkCompute";
  private static final String FULL_CLASS_NAME = PACKAGE_NAME + "." + CLASS_NAME;

  private final Config config;
  // A strong reference is needed to keep the compiled classes around
  @SuppressWarnings("FieldCanBeLocal")
  private SparkInterpreter interpreter;
  // The user's 'transform' method, resolved in initialize()
  private Method method;
  // Whether the resolved 'transform' method also accepts a SparkExecutionPluginContext
  private boolean takeContext;

  /**
   * Creates the plugin with the given configuration.
   *
   * @param config the plugin configuration carrying the Scala code and the optional output schema
   */
  public ScalaSparkCompute(Config config) {
    this.config = config;
  }

  /**
   * Sets the output schema of the stage and, when possible, validates the Scala code by compiling it.
   *
   * <p>Both steps are skipped for a property that contains a macro. The compilation check is also
   * skipped when {@link SparkCompilers#createInterpreter()} returns {@code null}.
   *
   * @param pipelineConfigurer the configurer used to access the stage configurer
   * @throws IllegalArgumentException if the output schema cannot be parsed or the Scala code fails to compile
   */
  @Override
  public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws IllegalArgumentException {
    StageConfigurer stageConfigurer = pipelineConfigurer.getStageConfigurer();
    try {
      if (!config.containsMacro("schema")) {
        // No explicit schema means the output schema is the same as the input schema
        stageConfigurer.setOutputSchema(
          config.getSchema() == null ? stageConfigurer.getInputSchema() : Schema.parseJson(config.getSchema())
        );
      }
    } catch (IOException e) {
      throw new IllegalArgumentException("Unable to parse output schema " + config.getSchema(), e);
    }

    if (!config.containsMacro("scalaCode")) {
      // This interpreter is used for validation only; the compiled classes are not kept.
      // NOTE(review): the interpreter is never released here; if SparkInterpreter exposes a close
      // method it should probably be closed after compilation - confirm against the CDAP API.
      SparkInterpreter validatingInterpreter = SparkCompilers.createInterpreter();
      if (validatingInterpreter != null) {
        try {
          validatingInterpreter.compile(generateSourceClass());
        } catch (CompilationFailureException e) {
          throw new IllegalArgumentException(e.getMessage(), e);
        }
      }
    }
  }

  /**
   * Compiles the user code and resolves its {@code transform} method through reflection.
   *
   * <p>The two-argument form {@code transform(RDD[StructuredRecord], SparkExecutionPluginContext)} is
   * preferred; if it does not exist, the one-argument form {@code transform(RDD[StructuredRecord])} is used.
   *
   * @param context the plugin context used to create the {@link SparkInterpreter}
   * @throws IllegalArgumentException if no suitable {@code transform} method exists or its parameter or
   *                                  return types are not as expected
   * @throws Exception if compiling the code or loading the generated class fails
   */
  @Override
  public void initialize(SparkExecutionPluginContext context) throws Exception {
    interpreter = context.createSparkInterpreter();
    // Same source as the one validated in configurePipeline, so both stages agree on what compiles
    interpreter.compile(generateSourceClass());

    // Use reflection to load the class and get the transform method
    try {
      Class<?> computeClass = interpreter.getClassLoader().loadClass(FULL_CLASS_NAME);

      // First see if it has the transform(RDD[StructuredRecord], SparkExecutionPluginContext) method
      try {
        method = computeClass.getDeclaredMethod("transform", RDD.class, SparkExecutionPluginContext.class);
        takeContext = true;
      } catch (NoSuchMethodException e) {
        // Try to find the transform(RDD[StructuredRecord])
        method = computeClass.getDeclaredMethod("transform", RDD.class);
        takeContext = false;
      }

      Type[] parameterTypes = method.getGenericParameterTypes();

      // The first parameter should be of type RDD[StructuredRecord]
      validateRDDType(parameterTypes[0],
                      "The first parameter of the 'transform' method should have type as 'RDD[StructuredRecord]'");

      // If it has second parameter, then must be SparkExecutionPluginContext
      if (takeContext && !SparkExecutionPluginContext.class.equals(parameterTypes[1])) {
        throw new IllegalArgumentException(
          "The second parameter of the 'transform' method should have type as SparkExecutionPluginContext");
      }

      // The return type of the method must be RDD[StructuredRecord]
      validateRDDType(method.getGenericReturnType(),
                      "The return type of the 'transform' method should be 'RDD[StructuredRecord]'");

      method.setAccessible(true);
    } catch (NoSuchMethodException e) {
      throw new IllegalArgumentException(
        "Missing a `transform` method that has signature either as " +
        "'def transform(rdd: RDD[StructuredRecord]) : RDD[StructuredRecord]' or " +
        "'def transform(rdd: RDD[StructuredRecord], context: SparkExecutionPluginContext) : RDD[StructuredRecord]'", e);
    }
  }

  /**
   * Invokes the user's {@code transform} method on the given RDD.
   *
   * @param sparkExecutionPluginContext the context, passed to the user method only if it declares it
   * @param javaRDD the input records
   * @return the RDD returned by the user method, wrapped as a {@link JavaRDD}
   * @throws Exception if the reflective invocation fails
   */
  @Override
  public JavaRDD<StructuredRecord> transform(SparkExecutionPluginContext sparkExecutionPluginContext,
                                             JavaRDD<StructuredRecord> javaRDD) throws Exception {
    // The method is invoked with a null receiver, i.e. as a static method of the generated class
    Object result = takeContext
      ? method.invoke(null, javaRDD.rdd(), sparkExecutionPluginContext)
      : method.invoke(null, javaRDD.rdd());

    // The generic return type was checked to be RDD[StructuredRecord] in initialize()
    @SuppressWarnings("unchecked")
    RDD<StructuredRecord> rdd = (RDD<StructuredRecord>) result;
    return JavaRDD.fromRDD(rdd, ClassTag$.MODULE$.<StructuredRecord>apply(StructuredRecord.class));
  }

  /**
   * Generates the Scala source that wraps the user code in an {@code object} named
   * {@link #CLASS_NAME} inside the package {@link #PACKAGE_NAME}.
   *
   * <p>This is the single source template for both the deploy-time validation and the runtime
   * compilation, so that code accepted by one is accepted by the other.
   *
   * @return the complete Scala source to compile
   */
  private String generateSourceClass() {
    StringWriter writer = new StringWriter();

    try (PrintWriter sourceWriter = new PrintWriter(writer, false)) {
      sourceWriter.println("package " + PACKAGE_NAME);
      // Includes some commonly used imports.
      sourceWriter.println("import co.cask.cdap.api.data.format._");
      sourceWriter.println("import co.cask.cdap.api.data.schema._");
      sourceWriter.println("import co.cask.cdap.etl.api.batch._");
      sourceWriter.println("import org.apache.spark._");
      sourceWriter.println("import org.apache.spark.api.java._");
      sourceWriter.println("import org.apache.spark.rdd._");
      sourceWriter.println("import org.apache.spark.sql._");
      sourceWriter.println("import org.apache.spark.SparkContext._");
      sourceWriter.println("import scala.collection.JavaConversions._");
      sourceWriter.println("object " + CLASS_NAME + " {");
      sourceWriter.println(config.getScalaCode());
      sourceWriter.println("}");
    }
    return writer.toString();
  }

  /**
   * Validates the given type is {@code RDD[StructuredRecord]}, i.e. a parameterized {@link RDD}
   * whose first type argument is {@link StructuredRecord}.
   *
   * @param rddType the generic type to check
   * @param errorMessage the message of the exception thrown when the check fails
   * @throws IllegalArgumentException if the type is not {@code RDD[StructuredRecord]}
   */
  private void validateRDDType(Type rddType, String errorMessage) {
    if (!(rddType instanceof ParameterizedType)) {
      throw new IllegalArgumentException(errorMessage);
    }
    ParameterizedType parameterizedType = (ParameterizedType) rddType;
    if (!RDD.class.equals(parameterizedType.getRawType())) {
      throw new IllegalArgumentException(errorMessage);
    }

    Type[] typeParams = parameterizedType.getActualTypeArguments();
    if (typeParams.length < 1 || !typeParams[0].equals(StructuredRecord.class)) {
      throw new IllegalArgumentException(errorMessage);
    }
  }

  /**
   * Configuration object for the plugin
   */
  public static final class Config extends PluginConfig {

    @Description("Spark code in Scala defining how to transform RDD to RDD. " +
      "The code must implement a function " +
      "called 'transform', which has signature as either \n" +
      "  def transform(rdd: RDD[StructuredRecord]) : RDD[StructuredRecord]\n" +
      "  or\n" +
      "  def transform(rdd: RDD[StructuredRecord], context: SparkExecutionPluginContext) : RDD[StructuredRecord]\n" +
      "For example:\n" +
      "'def transform(rdd: RDD[StructuredRecord]) : RDD[StructuredRecord] = {\n" +
      "   rdd.filter(_.get(\"gender\") == null)\n" +
      " }'\n" +
      "will keep only the incoming records that do not have the 'gender' field."
    )
    @Macro
    private final String scalaCode;

    @Description("The schema of output objects. If no schema is given, it is assumed that the output schema is " +
      "the same as the input schema.")
    @Nullable
    @Macro
    private final String schema;

    /**
     * Creates a configuration.
     *
     * @param scalaCode the Scala code that defines the {@code transform} function
     * @param schema the output schema as a JSON string, or {@code null} to reuse the input schema
     */
    public Config(String scalaCode, @Nullable String schema) {
      this.scalaCode = scalaCode;
      this.schema = schema;
    }

    /**
     * @return the user-provided Scala code
     */
    public String getScalaCode() {
      return scalaCode;
    }

    /**
     * @return the output schema as a JSON string, or {@code null} if none was given
     */
    @Nullable
    public String getSchema() {
      return schema;
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy