All downloads are free. Search and download functionality uses the official Maven repository.

co.cask.wrangler.sampling.Reservoir Maven / Gradle / Ivy

There is a newer version: 3.2.2
Show newest version
/*
 * Copyright © 2017 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.wrangler.sampling;

import co.cask.wrangler.api.Sampler;
import com.google.common.base.Preconditions;

import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Random;

/**
 * A simple in memory implementation of Reservoir Sampling without replacement, and with only one
 * pass through the input iteration whose size is unpredictable. The basic idea behind this sampler
 * implementation is to generate a random number for each input element as its weight, select the
 * top K elements with max weight. As the weights are generated randomly, so are the selected
 * top K elements. In the first phase, we generate random numbers as the weights for each element and
 * select top K elements as the output of each partitions.
 *
 * @param  The type of the sampler.
 */
public class Reservoir extends Sampler {
  private final int numSamples;
  private final Random random;

  /**
   * Create a new sampler with reservoir size and a supplied random number generator.
   *
   * @param numSamples Maximum number of samples to retain in reservoir, must be non-negative.
   * @param random     Instance of random number generator for sampling.
   */
  public Reservoir(int numSamples, Random random) {
    Preconditions.checkArgument(numSamples >= 0, "numSamples should be non-negative.");
    this.numSamples = numSamples;
    this.random = random;
  }

  /**
   * Create a new sampler with reservoir size and a default random number generator.
   *
   * @param numSamples Maximum number of samples to retain in reservoir, must be non-negative.
   */
  public Reservoir(int numSamples) {
    this(numSamples, new XORShiftRNG());
  }

  /**
   * Create a new sampler with reservoir size and the seed for random number generator.
   *
   * @param numSamples Maximum number of samples to retain in reservoir, must be non-negative.
   * @param seed       Random number generator seed.
   */
  public Reservoir(int numSamples, long seed) {
    this(numSamples, new XORShiftRNG(seed));
  }

  @Override
  public Iterator sample(Iterator input) {
    if (numSamples == 0) {
      return EMPTY_ITERABLE;
    }

    // This queue holds fixed number elements with the top K weight for current partition.
    PriorityQueue> queue = new PriorityQueue<>(numSamples);
    int index = 0;
    IntermediateSample smallest = null;
    while (input.hasNext()) {
      T element = input.next();
      if (index < numSamples) {
        // Fill the queue with first K elements from input.
        queue.add(new IntermediateSample(random.nextDouble(), element));
        smallest = queue.peek();
      } else {
        double rand = random.nextDouble();
        // Remove the element with the smallest weight, and append current element into the queue.
        if (rand > smallest.getWeight()) {
          queue.remove();
          queue.add(new IntermediateSample(rand, element));
          smallest = queue.peek();
        }
      }
      index++;
    }

    final Iterator> itr = queue.iterator();
    return new Iterator() {
      @Override
      public boolean hasNext() {
        return itr.hasNext();
      }

      @Override
      public T next() {
        return itr.next().getElement();
      }

      @Override
      public void remove() {
        itr.remove();
      }
    };
  }

  public static class IntermediateSample implements Comparable> {
    private double weight;
    private T element;

    public IntermediateSample(double weight, T element) {
      this.weight = weight;
      this.element = element;
    }

    public double getWeight() {
      return weight;
    }

    public T getElement() {
      return element;
    }

    @Override
    public int compareTo(IntermediateSample other) {
      return this.weight >= other.getWeight() ? 1 : -1;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy