/*
 * (c) 2005 David B. Bracewell
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.davidbracewell.stream;

import com.davidbracewell.config.Config;
import com.davidbracewell.conversion.Cast;
import com.davidbracewell.function.SerializableBinaryOperator;
import com.davidbracewell.function.SerializableConsumer;
import com.davidbracewell.function.SerializableFunction;
import com.davidbracewell.function.SerializablePredicate;
import com.davidbracewell.io.resource.Resource;
import lombok.NonNull;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

import java.io.IOException;
import java.io.Serializable;
import java.util.*;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collector;

/**
 * An {@link MStream} implementation backed by an Apache Spark {@link JavaRDD}.
 *
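 * <p>Minimal usage sketch (assumes Spark has been configured so that {@code Spark.context()} resolves):</p>
 * <pre>{@code
 * MStream<String> words = new SparkStream<>(Arrays.asList("alpha", "beta", "gamma"));
 * long count = words.filter(w -> w.startsWith("a")).count(); // 1
 * }</pre>
 *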
 * @param <T> the type parameter
 * @author David B. Bracewell
 */
public class SparkStream<T> implements MStream<T>, Serializable {
  private static final long serialVersionUID = 1L;
  private final JavaRDD<T> rdd;

  /**
   * Instantiates a new Spark stream.
   *
   * @param rdd the JavaRDD to wrap
   */
  public SparkStream(JavaRDD<T> rdd) {
    this.rdd = rdd;
  }

  /**
   * Instantiates a new Spark stream.
   *
   * @param collection the collection to parallelize
   */
  public SparkStream(List<T> collection) {
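    // Derive the number of partitions from the configured slice size (default: 100 elements per slice).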
    int slices = Math.max(1, collection.size() / Config.get("spark.sliceSize").asIntegerValue(100));
    this.rdd = Spark.context().parallelize(collection, slices);
  }

  /**
   * Gets the underlying {@link JavaRDD}.
   *
   * @return the wrapped JavaRDD
   */
  public JavaRDD<T> getRDD() {
    return rdd;
  }

  /**
   * Gets the Spark context associated with this stream.
   *
   * @return the JavaSparkContext
   */
  public JavaSparkContext getContext() {
    return Spark.context(this);
  }

  @Override
  public void close() throws IOException {
  }

  @Override
  public MStream<T> filter(SerializablePredicate<? super T> predicate) {
    return new SparkStream<>(rdd.filter(predicate::test));
  }

  @Override
  public <R> MStream<R> map(SerializableFunction<? super T, ? extends R> function) {
    return new SparkStream<>(rdd.map(function::apply));
  }

  @Override
  public <R> MStream<R> flatMap(SerializableFunction<? super T, ? extends Iterable<? extends R>> mapper) {
    return new SparkStream<>(rdd.flatMap(t -> Cast.as(mapper.apply(t))));
  }

  @Override
  public <R, U> MPairStream<R, U> flatMapToPair(SerializableFunction<? super T, ? extends Iterable<? extends Map.Entry<? extends R, ? extends U>>> function) {
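    // Spark pair RDDs are keyed by Scala Tuple2, so each Map.Entry produced by the function is converted.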
    return new SparkPairStream<>(rdd.flatMapToPair(t -> {
      List<Tuple2<R, U>> list = new LinkedList<>();
      function.apply(t).forEach(e -> list.add(new Tuple2<>(e.getKey(), e.getValue())));
      return list;
    }));
  }

  @Override
  public <R, U> MPairStream<R, U> mapToPair(SerializableFunction<? super T, ? extends Map.Entry<? extends R, ? extends U>> function) {
    return new SparkPairStream<>(
      rdd.mapToPair(t -> {
        Map.Entry<R, U> entry = Cast.as(function.apply(t));
        return new Tuple2<>(entry.getKey(), entry.getValue());
      })
    );
  }

  @Override
  public <U> MPairStream<U, Iterable<T>> groupBy(SerializableFunction<? super T, ? extends U> function) {
    return new SparkPairStream<>(
      rdd.groupBy(function::apply)
    );
  }

  @Override
  public <R> R collect(Collector<? super T, ?, R> collector) {
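    // Materializes the RDD on the driver first, then applies the Java collector locally.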
    return collect().stream().collect(collector);
  }

  @Override
  public List<T> collect() {
    return rdd.collect();
  }

  @Override
  public Optional<T> reduce(SerializableBinaryOperator<T> reducer) {
    return Optional.of(rdd.reduce(reducer::apply));
  }

  @Override
  public T fold(T zeroValue, SerializableBinaryOperator<T> operator) {
    return rdd.fold(zeroValue, operator::apply);
  }

  @Override
  public void forEach(SerializableConsumer<? super T> consumer) {
    rdd.foreach(consumer::accept);
  }

  @Override
  public void forEachLocal(SerializableConsumer<? super T> consumer) {
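    // Pulls partitions to the driver and applies the consumer locally rather than distributing it.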
    rdd.toLocalIterator().forEachRemaining(consumer);
  }

  @Override
  public Iterator<T> iterator() {
    return rdd.toLocalIterator();
  }

  @Override
  public Optional<T> first() {
    return Optional.ofNullable(rdd.first());
  }

  @Override
  public MStream<T> sample(int number) {
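    // Convert the requested sample size into a fraction of the total count; Spark samples
    // without replacement and the returned size is approximate, not exact.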
    return new SparkStream<>(rdd.sample(false, number / (double) count()));
  }

  @Override
  public long count() {
    return rdd.count();
  }

  @Override
  public boolean isEmpty() {
    return rdd.isEmpty();
  }

  @Override
  public Map<T, Long> countByValue() {
    return rdd.countByValue();
  }

  @Override
  public MStream<T> distinct() {
    return new SparkStream<>(rdd.distinct());
  }

  @Override
  public MStream<T> limit(long number) {
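    // Pair each element with its index and keep only those whose index falls below the limit.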
    return new SparkStream<>(rdd.zipWithIndex().filter(p -> p._2() < number).map(Tuple2::_1));
  }

  @Override
  public List<T> take(int n) {
    return rdd.take(n);
  }

  @Override
  public MStream<T> skip(long n) {
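    // Keep only elements whose zero-based index is at least n, discarding the first n elements.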
    return new SparkStream<>(rdd.zipWithIndex().filter(p -> p._2() >= n).map(Tuple2::_1));
  }

  @Override
  public void onClose(Runnable closeHandler) {
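    // No-op: close handlers are not tracked for Spark-backed streams.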
  }

  @Override
  public MStream<T> sorted(boolean ascending) {
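    // Sort by the elements themselves, preserving the current number of partitions.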
    return new SparkStream<>(rdd.sortBy(t -> t, ascending, rdd.partitions().size()));
  }

  @Override
  public Optional<T> max(Comparator<? super T> comparator) {
    return Optional.ofNullable(rdd.max(Cast.as(comparator)));
  }

  @Override
  public Optional<T> min(Comparator<? super T> comparator) {
    return Optional.ofNullable(rdd.min(Cast.as(comparator)));
  }

  @Override
  public <U> MPairStream<T, U> zip(MStream<U> other) {
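    // Zipping two RDDs directly requires matching partitioning; non-Spark streams are
    // collected and re-parallelized into the same number of partitions first.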
    if (other instanceof SparkStream) {
      return new SparkPairStream<>(rdd.zip(Cast.<SparkStream<U>>as(other).rdd));
    }
    JavaSparkContext jsc = new JavaSparkContext(rdd.context());
    return new SparkPairStream<>(rdd.zip(jsc.parallelize(other.collect(), rdd.partitions().size())));
  }

  @Override
  public MPairStream<T, Long> zipWithIndex() {
    return new SparkPairStream<>(rdd.zipWithIndex());
  }

  @Override
  public MDoubleStream mapToDouble(ToDoubleFunction<? super T> function) {
    return new SparkDoubleStream(rdd.mapToDouble(function::applyAsDouble));
  }

  @Override
  public MStream<T> cache() {
    return new SparkStream<>(rdd.cache());
  }

  @Override
  public MStream<T> union(MStream<T> other) {
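    // Spark-backed streams union directly at the RDD level; other stream types are
    // collected on the driver and parallelized before the union.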
    if (other instanceof SparkStream) {
      return new SparkStream<>(rdd.union(Cast.<SparkStream<T>>as(other).rdd));
    }
    return new SparkStream<>(rdd.union(Spark.context(rdd).parallelize(other.collect())));
  }

  @Override
  public void saveAsTextFile(@NonNull Resource location) {
    rdd.saveAsTextFile(location.descriptor());
  }

  @Override
  public void saveAsTextFile(@NonNull String location) {
    rdd.saveAsTextFile(location);
  }

  @Override
  public MStream<T> parallel() {
    return this;
  }

  @Override
  public MStream<T> shuffle(@NonNull Random random) {
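    // Shuffle by assigning each element a random key and sorting on that key.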
    return new SparkStream<>(
      rdd.mapToPair(t -> new Tuple2<>(random.nextDouble(), t))
        .sortByKey()
        .map(Tuple2::_2)
    );
  }
}//END OF SparkStream