All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.hudi.client.common.HoodieSparkEngineContext Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hudi.client.common;

import org.apache.hudi.client.SparkTaskContextSupplier;
import org.apache.hudi.common.data.HoodieAccumulator;
import org.apache.hudi.common.data.HoodieData;
import org.apache.hudi.common.data.HoodieData.HoodieDataCacheKey;
import org.apache.hudi.common.engine.EngineProperty;
import org.apache.hudi.common.engine.HoodieEngineContext;
import org.apache.hudi.common.function.SerializableBiFunction;
import org.apache.hudi.common.function.SerializableConsumer;
import org.apache.hudi.common.function.SerializableFunction;
import org.apache.hudi.common.function.SerializablePairFlatMapFunction;
import org.apache.hudi.common.function.SerializablePairFunction;
import org.apache.hudi.common.util.Functions;
import org.apache.hudi.common.util.Option;
import org.apache.hudi.common.util.collection.ImmutablePair;
import org.apache.hudi.common.util.collection.Pair;
import org.apache.hudi.data.HoodieJavaRDD;
import org.apache.hudi.data.HoodieSparkLongAccumulator;
import org.apache.hudi.exception.HoodieException;
import org.apache.hudi.hadoop.fs.HadoopFSUtils;

import org.apache.hadoop.conf.Configuration;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.sql.SQLContext;

import javax.annotation.concurrent.ThreadSafe;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import scala.Tuple2;

/**
 * A Spark engine implementation of HoodieEngineContext.
 */
@ThreadSafe
public class HoodieSparkEngineContext extends HoodieEngineContext {

  private final JavaSparkContext javaSparkContext;
  private final SQLContext sqlContext;
  private final Map> cachedRddIds = new HashMap<>();

  public HoodieSparkEngineContext(JavaSparkContext jsc) {
    this(jsc, SQLContext.getOrCreate(jsc.sc()));
  }

  public HoodieSparkEngineContext(JavaSparkContext jsc, SQLContext sqlContext) {
    super(HadoopFSUtils.getStorageConfWithCopy(jsc.hadoopConfiguration()), new SparkTaskContextSupplier());
    this.javaSparkContext = jsc;
    this.sqlContext = sqlContext;
  }

  public JavaSparkContext getJavaSparkContext() {
    return javaSparkContext;
  }

  public JavaSparkContext jsc() {
    return javaSparkContext;
  }

  public SQLContext getSqlContext() {
    return sqlContext;
  }

  public static JavaSparkContext getSparkContext(HoodieEngineContext context) {
    return ((HoodieSparkEngineContext) context).getJavaSparkContext();
  }

  @Override
  public HoodieAccumulator newAccumulator() {
    HoodieSparkLongAccumulator accumulator = HoodieSparkLongAccumulator.create();
    javaSparkContext.sc().register(accumulator.getAccumulator());
    return accumulator;
  }

  @Override
  public  HoodieData emptyHoodieData() {
    return HoodieJavaRDD.of(javaSparkContext.emptyRDD());
  }

  public boolean supportsFileGroupReader() {
    return true;
  }

  @Override
  public  HoodieData parallelize(List data, int parallelism) {
    return HoodieJavaRDD.of(javaSparkContext.parallelize(data, parallelism));
  }

  @Override
  public  List map(List data, SerializableFunction func, int parallelism) {
    return javaSparkContext.parallelize(data, parallelism).map(func::apply).collect();
  }

  @Override
  public  List mapToPairAndReduceByKey(List data, SerializablePairFunction mapToPairFunc, SerializableBiFunction reduceFunc, int parallelism) {
    return javaSparkContext.parallelize(data, parallelism).mapToPair(input -> {
      Pair pair = mapToPairFunc.call(input);
      return new Tuple2<>(pair.getLeft(), pair.getRight());
    }).reduceByKey(reduceFunc::apply).map(Tuple2::_2).collect();
  }

  @Override
  public  Stream> mapPartitionsToPairAndReduceByKey(
      Stream data, SerializablePairFlatMapFunction, K, V> flatMapToPairFunc,
      SerializableBiFunction reduceFunc, int parallelism) {
    return javaSparkContext.parallelize(data.collect(Collectors.toList()), parallelism)
        .mapPartitionsToPair((PairFlatMapFunction, K, V>) iterator ->
            flatMapToPairFunc.call(iterator).collect(Collectors.toList()).stream()
                .map(e -> new Tuple2<>(e.getKey(), e.getValue())).iterator()
        )
        .reduceByKey(reduceFunc::apply)
        .map(e -> new ImmutablePair<>(e._1, e._2))
        .collect().stream();
  }

  @Override
  public  List reduceByKey(
      List> data, SerializableBiFunction reduceFunc, int parallelism) {
    return javaSparkContext.parallelize(data, parallelism).mapToPair(pair -> new Tuple2(pair.getLeft(), pair.getRight()))
        .reduceByKey(reduceFunc::apply).map(Tuple2::_2).collect();
  }

  @Override
  public  List flatMap(List data, SerializableFunction> func, int parallelism) {
    return javaSparkContext.parallelize(data, parallelism).flatMap(x -> func.apply(x).iterator()).collect();
  }

  @Override
  public  void foreach(List data, SerializableConsumer consumer, int parallelism) {
    javaSparkContext.parallelize(data, parallelism).foreach(consumer::accept);
  }

  @Override
  public  Map mapToPair(List data, SerializablePairFunction func, Integer parallelism) {
    if (Objects.nonNull(parallelism)) {
      return javaSparkContext.parallelize(data, parallelism).mapToPair(input -> {
        Pair pair = func.call(input);
        return new Tuple2(pair.getLeft(), pair.getRight());
      }).collectAsMap();
    } else {
      return javaSparkContext.parallelize(data).mapToPair(input -> {
        Pair pair = func.call(input);
        return new Tuple2(pair.getLeft(), pair.getRight());
      }).collectAsMap();
    }
  }

  @Override
  public void setProperty(EngineProperty key, String value) {
    if (key.equals(EngineProperty.COMPACTION_POOL_NAME)
        || key.equals(EngineProperty.CLUSTERING_POOL_NAME)
        || key.equals(EngineProperty.DELTASYNC_POOL_NAME)) {
      javaSparkContext.setLocalProperty("spark.scheduler.pool", value);
    } else {
      throw new HoodieException("Unknown engine property :" + key);
    }
  }

  @Override
  public Option getProperty(EngineProperty key) {
    if (key == EngineProperty.EMBEDDED_SERVER_HOST) {
      return Option.ofNullable(javaSparkContext.getConf().get("spark.driver.host", null));
    }
    throw new HoodieException("Unknown engine property :" + key);
  }

  @Override
  public void setJobStatus(String activeModule, String activityDescription) {
    javaSparkContext.setJobGroup(activeModule, activityDescription);
  }

  @Override
  public void putCachedDataIds(HoodieDataCacheKey cacheKey, int... ids) {
    synchronized (cachedRddIds) {
      cachedRddIds.putIfAbsent(cacheKey, new ArrayList<>());
      for (int id : ids) {
        cachedRddIds.get(cacheKey).add(id);
      }
    }
  }

  @Override
  public List getCachedDataIds(HoodieDataCacheKey cacheKey) {
    synchronized (cachedRddIds) {
      return cachedRddIds.getOrDefault(cacheKey, Collections.emptyList());
    }
  }

  @Override
  public List removeCachedDataIds(HoodieDataCacheKey cacheKey) {
    synchronized (cachedRddIds) {
      List removed = cachedRddIds.remove(cacheKey);
      return removed == null ? Collections.emptyList() : removed;
    }
  }

  @Override
  public void cancelJob(String groupId) {
    javaSparkContext.cancelJobGroup(groupId);
  }

  @Override
  public void cancelAllJobs() {
    javaSparkContext.cancelAllJobs();
  }

  @Override
  public  O aggregate(HoodieData data, O zeroValue, Functions.Function2 seqOp, Functions.Function2 combOp) {
    Function2 seqOpFunc = seqOp::apply;
    Function2 combOpFunc = combOp::apply;
    return HoodieJavaRDD.getJavaRDD(data).aggregate(zeroValue, seqOpFunc, combOpFunc);
  }

  public SparkConf getConf() {
    return javaSparkContext.getConf();
  }

  public Configuration hadoopConfiguration() {
    return javaSparkContext.hadoopConfiguration();
  }

  public  JavaRDD emptyRDD() {
    return javaSparkContext.emptyRDD();
  }
}