All downloads are free. The search and download functionality uses the official Maven repository.

com.davidbracewell.stream.SparkPairStream Maven / Gradle / Ivy

There is a newer version: 0.5
Show newest version
package com.davidbracewell.stream;

import com.davidbracewell.conversion.Cast;
import com.davidbracewell.function.*;
import com.davidbracewell.tuple.Tuple2;
import lombok.NonNull;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;

import java.io.Serializable;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * @author David B. Bracewell
 */
public class SparkPairStream implements MPairStream, Serializable {
  private static final long serialVersionUID = 1L;

  private final JavaPairRDD rdd;

  public SparkPairStream(JavaPairRDD rdd) {
    this.rdd = rdd;
  }

  public SparkPairStream(Map map) {
    this(Spark.context(), map);
  }

  public SparkPairStream(JavaSparkContext context, Map map) {
    List> tuples = new LinkedList<>();
    map.forEach((k, v) -> tuples.add(new scala.Tuple2<>(k, v)));
    this.rdd = context
      .parallelize(tuples)
      .mapToPair(t -> t);
  }

  static  Map.Entry toMapEntry(scala.Tuple2 tuple2) {
    return Tuple2.of(tuple2._1(), tuple2._2());
  }

  @Override
  public  MPairStream> join(MPairStream stream) {
    return new SparkPairStream<>(rdd.join(toPairRDD(stream)).mapToPair(t -> new scala.Tuple2<>(t._1(), toMapEntry(t._2()))));
  }

  @Override
  public MPairStream reduceByKey(SerializableBinaryOperator operator) {
    return new SparkPairStream<>(rdd.reduceByKey(operator::apply));
  }

  @Override
  public void forEach(@NonNull SerializableBiConsumer consumer) {
    rdd.foreach(tuple -> consumer.accept(tuple._1(), tuple._2()));
  }

  @Override
  public void forEachLocal(SerializableBiConsumer consumer) {
    rdd.toLocalIterator().forEachRemaining(e -> consumer.accept(e._1(), e._2()));
  }

  @Override
  public void close() throws Exception {

  }

  @Override
  public  MStream map(@NonNull SerializableBiFunction function) {
    return new SparkStream<>(rdd.map(
      e -> function.apply(e._1(), e._2())));
  }

  @Override
  public MPairStream> groupByKey() {
    return new SparkPairStream<>(rdd.groupByKey());
  }

  @Override
  public  MPairStream mapToPair(SerializableBiFunction> function) {
    return new SparkPairStream<>(
      rdd.mapToPair((t) -> {
        Map.Entry e = function.apply(t._1(), t._2());
        return new scala.Tuple2<>(e.getKey(), e.getValue());
      }));
  }

  @Override
  public MPairStream filter(SerializableBiPredicate predicate) {
    return new SparkPairStream<>(rdd.filter(tuple -> predicate.test(tuple._1(), tuple._2())));
  }

  @Override
  public Map collectAsMap() {
    return rdd.collectAsMap();
  }

  @Override
  public MPairStream filterByKey(SerializablePredicate predicate) {
    return new SparkPairStream<>(rdd.filter(tuple -> predicate.test(tuple._1())));
  }

  @Override
  public MPairStream filterByValue(SerializablePredicate predicate) {
    return new SparkPairStream<>(rdd.filter(tuple -> predicate.test(tuple._2())));
  }

  @Override
  public List> collectAsList() {
    return rdd.map(t -> Cast.>as(Tuple2.of(t._1(), t._2()))).collect();
  }

  @Override
  public long count() {
    return rdd.count();
  }


  @Override
  public MStream keys() {
    return new SparkStream<>(rdd.keys());
  }

  private  JavaPairRDD toPairRDD(MPairStream other) {
    JavaPairRDD oRDD;
    if (other instanceof SparkPairStream) {
      oRDD = Cast.>as(other).rdd;
    } else {
      oRDD = new SparkPairStream<>(Spark.context(rdd), other.collectAsMap()).rdd;
    }
    return oRDD;
  }


  @Override
  public MPairStream union(MPairStream other) {
    return new SparkPairStream<>(rdd.union(toPairRDD(other)));
  }

  @Override
  public MStream values() {
    return new SparkStream<>(rdd.values());
  }

  @Override
  public MPairStream sortByKey(SerializableComparator comparator) {
    return new SparkPairStream<>(rdd.sortByKey(comparator));
  }

  @Override
  public MPairStream parallel() {
    return this;
  }

  /**
   * Gets context.
   *
   * @return the context
   */
  public JavaSparkContext getContext() {
    return Spark.context(rdd);
  }

  @Override
  public MPairStream shuffle(Random random) {
    return new SparkPairStream(
      rdd.mapToPair(t -> new scala.Tuple2<>(random.nextDouble(), t))
        .sortByKey()
        .mapToPair(scala.Tuple2::_2)
    );
  }
}// END OF SparkPairStream