All downloads are free. Search and download functionality uses the official Maven repository.

com.yahoo.sketches.pig.sampling.ReservoirSampling Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016, Yahoo! Inc.
 * Licensed under the terms of the Apache License 2.0. See LICENSE file at the project root for terms.
 */

package com.yahoo.sketches.pig.sampling;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.pig.AccumulatorEvalFunc;
import org.apache.pig.Algebraic;
import org.apache.pig.EvalFunc;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.DataType;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import org.apache.pig.impl.logicalLayer.FrontendException;
import org.apache.pig.impl.logicalLayer.schema.Schema;

import com.yahoo.sketches.sampling.ReservoirItemsSketch;
import com.yahoo.sketches.sampling.ReservoirItemsUnion;
import com.yahoo.sketches.sampling.SamplingPigUtil;

/**
 * This is a Pig UDF that applies reservoir sampling to input tuples. It implements both
 * the Accumulator and Algebraic interfaces for efficient performance.
 *
 * <p>Every result produced by this class (and by its nested {@link Initial} and
 * {@link IntermediateFinal} stages) is a 3-field tuple:</p>
 * <ol>
 *   <li>{@code n} (long) - the number of items seen,</li>
 *   <li>{@code k} (int) - the reservoir size,</li>
 *   <li>{@code samples} (bag) - the sampled tuples.</li>
 * </ol>
 *
 * @author Jon Malkin
 */
public class ReservoirSampling extends AccumulatorEvalFunc<Tuple> implements Algebraic {
  // defined for test consistency
  static final String N_ALIAS = "n";
  static final String K_ALIAS = "k";
  static final String SAMPLES_ALIAS = "samples";

  // reservoir size used by the no-argument constructors
  private static final int DEFAULT_TARGET_K = 1024;

  private final int targetK_;

  // created lazily by accumulate(); null means nothing has been accumulated since cleanup()
  private ReservoirItemsSketch<Tuple> reservoir_;

  /**
   * Reservoir sampling constructor.
   * @param kStr String indicating the maximum number of desired entries in the reservoir.
   * @throws IllegalArgumentException if the parsed value is less than 2 (a
   *         {@link NumberFormatException} if kStr is not a parsable integer)
   */
  public ReservoirSampling(final String kStr) {
    targetK_ = parseTargetK(kStr);
  }

  /**
   * Default constructor, using a reservoir size of {@code DEFAULT_TARGET_K}.
   */
  ReservoirSampling() {
    targetK_ = DEFAULT_TARGET_K;
  }

  /**
   * Parses and validates a reservoir size given as a String. Shared by the constructors of
   * this class and of the nested algebraic stages so that all three apply the same rule.
   *
   * @param kStr String holding the desired reservoir size
   * @return the parsed reservoir size, always &ge; 2
   * @throws IllegalArgumentException if the parsed value is less than 2 (a
   *         {@link NumberFormatException} if kStr is not a parsable integer)
   */
  private static int parseTargetK(final String kStr) {
    final int k = Integer.parseInt(kStr);

    if (k < 2) {
      throw new IllegalArgumentException("ReservoirSampling requires target reservoir size >= 2: "
              + k);
    }
    return k;
  }

  /**
   * Samples the bag held in field 0 of the input tuple.
   *
   * @param inputTuple tuple whose first field is a DataBag of tuples to sample
   * @return a (n, k, samples) tuple, or null if the input is null, empty, or has a null
   *         first field
   * @throws IOException if reading the input tuple fails
   */
  @Override
  public Tuple exec(final Tuple inputTuple) throws IOException {
    if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
      return null;
    }

    final DataBag samples = (DataBag) inputTuple.get(0);

    // if entire input data fits in reservoir, shortcut result
    if (samples.size() <= targetK_) {
      return createResultTuple(samples.size(), targetK_, samples);
    }
    return super.exec(inputTuple);
  }

  /**
   * Feeds every tuple of the bag held in field 0 of the input into the reservoir, creating
   * the reservoir on first use. Null or empty input is ignored.
   *
   * @param inputTuple tuple whose first field is a DataBag of tuples to sample
   * @throws IOException if reading the input tuple fails
   */
  @Override
  public void accumulate(final Tuple inputTuple) throws IOException {
    if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
      return;
    }

    final DataBag samples = (DataBag) inputTuple.get(0);

    if (reservoir_ == null) {
      reservoir_ = ReservoirItemsSketch.newInstance(targetK_);
    }

    for (final Tuple t : samples) {
      reservoir_.update(t);
    }
  }

  /**
   * Returns the current state of the accumulated reservoir.
   *
   * @return a (n, k, samples) tuple, or null if nothing has been accumulated
   */
  @Override
  public Tuple getValue() {
    if (reservoir_ == null) {
      return null;
    }

    final List<Tuple> data = SamplingPigUtil.getRawSamplesAsList(reservoir_);
    final DataBag sampleBag = BagFactory.getInstance().newDefaultBag(data);

    return createResultTuple(reservoir_.getN(), reservoir_.getK(), sampleBag);
  }

  /**
   * Discards the accumulated reservoir so the next accumulate() call starts fresh.
   */
  @Override
  public void cleanup() {
    reservoir_ = null;
  }

  /**
   * Builds the output schema: a tuple holding a long {@code n}, an int {@code k} and a bag
   * {@code samples} whose contents use the input's record schema.
   *
   * @param input the input schema
   * @return the output schema, or null if the input schema is null or empty
   * @throws RuntimeException wrapping any FrontendException raised while building the schema
   */
  @Override
  public Schema outputSchema(final Schema input) {
    if (input != null && input.size() > 0) {
      try {
        Schema source = input;

        // if we have a bag, grab one level down to get a tuple
        if (source.size() == 1 && source.getField(0).type == DataType.BAG) {
          source = source.getField(0).schema;
        }

        final Schema recordSchema = new Schema();
        recordSchema.add(new Schema.FieldSchema(N_ALIAS, DataType.LONG));
        recordSchema.add(new Schema.FieldSchema(K_ALIAS, DataType.INTEGER));

        // this should add a bag to the output
        recordSchema.add(new Schema.FieldSchema(SAMPLES_ALIAS, source, DataType.BAG));

        return new Schema(new Schema.FieldSchema(getSchemaName(this
                .getClass().getName().toLowerCase(), source), recordSchema, DataType.TUPLE));
      }
      catch (final FrontendException e) {
        throw new RuntimeException(e);
      }
    }
    return null;
  }

  /**
   * Packs the three result fields into a new 3-field tuple.
   *
   * @param n number of items seen
   * @param k reservoir size
   * @param samples bag holding the sampled tuples
   * @return the tuple (n, k, samples)
   * @throws RuntimeException wrapping any ExecException raised while setting a field
   */
  static Tuple createResultTuple(final long n, final int k, final DataBag samples) {
    final Tuple output = TupleFactory.getInstance().newTuple(3);

    try {
      output.set(0, n);
      output.set(1, k);
      output.set(2, samples);
    } catch (final ExecException e) {
      throw new RuntimeException("Pig error: " + e.getMessage(), e);
    }

    return output;
  }


  /**
   * @return the class name of the initial (map-side) algebraic stage
   */
  @Override
  public String getInitial() {
    return Initial.class.getName();
  }

  /**
   * @return the class name of the intermediate algebraic stage; the same class is used for
   *         the final stage
   */
  @Override
  public String getIntermed() {
    return IntermediateFinal.class.getName();
  }

  /**
   * @return the class name of the final algebraic stage; the same class is used for the
   *         intermediate stage
   */
  @Override
  public String getFinal() {
    return IntermediateFinal.class.getName();
  }

  /**
   * Initial algebraic stage: turns a bag of raw records into a (n, k, samples) tuple.
   */
  public static class Initial extends EvalFunc<Tuple> {
    private final int targetK_;

    /**
     * Default constructor, using a reservoir size of {@code DEFAULT_TARGET_K}.
     */
    public Initial() {
      targetK_ = DEFAULT_TARGET_K;
    }

    /**
     * Map-side constructor for reservoir sampling UDF
     * @param kStr String indicating the maximum number of desired entries in the reservoir.
     * @throws IllegalArgumentException if the parsed value is less than 2
     * */
    public Initial(final String kStr) {
      targetK_ = parseTargetK(kStr);
    }

    /**
     * Samples the bag held in field 0 of the input tuple. When the bag holds no more than
     * the target number of records it is passed through unchanged; otherwise its records are
     * run through a reservoir sketch.
     *
     * @param inputTuple tuple whose first field is a DataBag of records
     * @return a (n, k, samples) tuple, where n is the size of the input bag, or null if the
     *         input is null, empty, or has a null first field
     * @throws IOException if reading or populating a tuple fails
     */
    @Override
    public Tuple exec(final Tuple inputTuple) throws IOException {
      if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
        return null;
      }

      final DataBag records = (DataBag) inputTuple.get(0);

      final DataBag outputBag;
      int k = targetK_;
      if (records.size() <= targetK_) {
        // everything fits: no need to build a sketch
        outputBag = records;
      } else {
        final ReservoirItemsSketch<Tuple> reservoir = ReservoirItemsSketch.newInstance(targetK_);
        for (final Tuple t : records) {
          reservoir.update(t);
        }
        // newDefaultBag(List) does *not* copy values
        final List<Tuple> data = SamplingPigUtil.getRawSamplesAsList(reservoir);
        outputBag = BagFactory.getInstance().newDefaultBag(data);
        k = reservoir.getK();
      }

      final Tuple output = TupleFactory.getInstance().newTuple(3);
      output.set(0, records.size());
      output.set(1, k);
      output.set(2, outputBag);

      return output;
    }
  }

  /**
   * Intermediate and final algebraic stage: merges a bag of (n, k, samples) tuples into a
   * single (n, k, samples) tuple by way of a reservoir union.
   */
  public static class IntermediateFinal extends EvalFunc<Tuple> {
    private final int targetK_;

    /**
     * Default constructor, using a reservoir size of {@code DEFAULT_TARGET_K}.
     */
    public IntermediateFinal() {
      targetK_ = DEFAULT_TARGET_K;
    }

    /**
     * Combiner and reducer side constructor for reservoir sampling UDF
     * @param kStr String indicating the maximum number of desired entries in the reservoir.
     * @throws IllegalArgumentException if the parsed value is less than 2
     * */
    public IntermediateFinal(final String kStr) {
      targetK_ = parseTargetK(kStr);
    }

    /**
     * Unions all partial results found in the bag held in field 0 of the input tuple.
     *
     * @param inputTuple tuple whose first field is a DataBag of (n, k, samples) tuples
     * @return the merged (n, k, samples) tuple, or null if the input is null, empty, or has
     *         a null first field
     * @throws IOException if reading or populating a tuple fails
     */
    @Override
    public Tuple exec(final Tuple inputTuple) throws IOException {
      if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
        return null;
      }

      final ReservoirItemsUnion<Tuple> union = ReservoirItemsUnion.newInstance(targetK_);

      final DataBag outerBag = (DataBag) inputTuple.get(0);
      for (final Tuple reservoir : outerBag) {
        final long n = (long) reservoir.get(0);
        final int k  = (int) reservoir.get(1);

        if (n <= k && k <= targetK_) {
          // partial result holds every item it saw, so feed the items in one at a time
          for (final Tuple t : (DataBag) reservoir.get(2)) {
            union.update(t);
          }
        } else {
          // partial result is a true sample: hand it to the union together with its n and k
          final ArrayList<Tuple> samples = dataBagToArrayList((DataBag) reservoir.get(2));
          union.update(n, k, samples);
        }
      }

      final ReservoirItemsSketch<Tuple> result = union.getResult();
      if (result == null) {
        // NOTE(review): the union is never updated when every partial result is empty, and
        // getResult() is believed to return null in that case - confirm against
        // ReservoirItemsUnion. Report an empty sample rather than failing on a null sketch.
        return createResultTuple(0, targetK_, BagFactory.getInstance().newDefaultBag());
      }

      final List<Tuple> data = SamplingPigUtil.getRawSamplesAsList(result);
      final DataBag sampleBag = BagFactory.getInstance().newDefaultBag(data);

      final Tuple output = TupleFactory.getInstance().newTuple(3);
      output.set(0, result.getN());
      output.set(1, result.getK());
      output.set(2, sampleBag);

      return output;
    }
  }

  /**
   * Copies the tuples of a bag, in iteration order, into a new ArrayList.
   *
   * @param bag the bag to copy
   * @return a new list holding the bag's tuples
   */
  static ArrayList<Tuple> dataBagToArrayList(final DataBag bag) {
    final int arrayLength = (int) bag.size();
    final ArrayList<Tuple> output = new ArrayList<>(arrayLength);

    for (final Tuple t : bag) {
      output.add(t);
    }

    return output;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy