// com.yahoo.sketches.pig.sampling.VarOptSampling Maven / Gradle / Ivy
/*
* Copyright 2017, Yahoo! Inc.
* Licensed under the terms of the Apache License 2.0. See LICENSE file at the project root for terms.
*/
package com.yahoo.sketches.pig.sampling;
import static com.yahoo.sketches.pig.sampling.VarOptCommonImpl.DEFAULT_TARGET_K;
import static com.yahoo.sketches.pig.sampling.VarOptCommonImpl.DEFAULT_WEIGHT_IDX;
import static com.yahoo.sketches.pig.sampling.VarOptCommonImpl.RECORD_ALIAS;
import static com.yahoo.sketches.pig.sampling.VarOptCommonImpl.WEIGHT_ALIAS;
import static com.yahoo.sketches.pig.sampling.VarOptCommonImpl.createDataBagFromSketch;
import static com.yahoo.sketches.pig.sampling.VarOptCommonImpl.unionSketches;
import java.io.IOException;
import org.apache.pig.AccumulatorEvalFunc;
import org.apache.pig.Algebraic;
import org.apache.pig.EvalFunc;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.DataType;
import org.apache.pig.data.Tuple;
import org.apache.pig.impl.logicalLayer.FrontendException;
import org.apache.pig.impl.logicalLayer.schema.Schema;
import com.yahoo.sketches.sampling.VarOptItemsSketch;
import com.yahoo.sketches.sampling.VarOptItemsUnion;
/**
 * Applies VarOpt sampling to input tuples. Implements both the Accumulator and
 * Algebraic interfaces for efficient performance.
 *
 * <p>The input is expected to be a tuple whose first field is a bag of records; each record
 * carries its weight in the column given by the weight index.</p>
 *
 * @author Jon Malkin
 */
public class VarOptSampling extends AccumulatorEvalFunc<DataBag> implements Algebraic {
  private final int targetK_;
  private final int weightIdx_;
  private VarOptItemsSketch<Tuple> sketch_;

  /**
   * VarOpt sampling constructor.
   * @param kStr String indicating the maximum number of desired samples to return.
   * @throws IllegalArgumentException if kStr is not an integer or is less than 1
   */
  public VarOptSampling(final String kStr) {
    targetK_ = validateTargetK(Integer.parseInt(kStr));
    weightIdx_ = DEFAULT_WEIGHT_IDX;
  }

  /**
   * VarOpt sampling constructor.
   * @param kStr String indicating the maximum number of desired samples to return.
   * @param weightIdxStr String indicating column index (0-based) of weight values
   * @throws IllegalArgumentException if either argument is not an integer, if the target
   *         sample size is less than 1, or if the weight index is negative
   */
  public VarOptSampling(final String kStr, final String weightIdxStr) {
    // Parse both values before validating either, so malformed numbers are reported first.
    final int k = Integer.parseInt(kStr);
    final int weightIdx = Integer.parseInt(weightIdxStr);
    targetK_ = validateTargetK(k);
    weightIdx_ = validateWeightIdx(weightIdx);
  }

  /**
   * Package-private constructor using DEFAULT_TARGET_K and DEFAULT_WEIGHT_IDX.
   */
  VarOptSampling() {
    targetK_ = DEFAULT_TARGET_K;
    weightIdx_ = DEFAULT_WEIGHT_IDX;
  }

  /**
   * Ensures the target sample size is at least 1.
   * @param k candidate target sample size
   * @return k, unchanged, when valid
   * @throws IllegalArgumentException if k is less than 1
   */
  private static int validateTargetK(final int k) {
    if (k < 1) {
      throw new IllegalArgumentException("VarOptSampling requires target sample size >= 1: "
          + k);
    }
    return k;
  }

  /**
   * Ensures the weight column index is non-negative.
   * @param weightIdx candidate 0-based weight column index
   * @return weightIdx, unchanged, when valid
   * @throws IllegalArgumentException if weightIdx is negative
   */
  private static int validateWeightIdx(final int weightIdx) {
    if (weightIdx < 0) {
      throw new IllegalArgumentException("VarOptSampling requires weight index >= 0: "
          + weightIdx);
    }
    return weightIdx;
  }

  /**
   * Feeds every record of the bag in field 0 of the input into the sketch, creating the sketch
   * on first use. A null or empty input tuple, or a null first field, is ignored.
   * @param inputTuple tuple whose first field is a bag of weighted records
   * @throws IOException if a field cannot be read from a tuple
   */
  @Override
  public void accumulate(final Tuple inputTuple) throws IOException {
    if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
      return;
    }

    final DataBag samples = (DataBag) inputTuple.get(0);

    if (sketch_ == null) {
      sketch_ = VarOptItemsSketch.newInstance(targetK_);
    }

    for (final Tuple t : samples) {
      // outputSchema() accepts FLOAT as well as DOUBLE weights. Casting the boxed value
      // straight to double only works for Double, so go through Number instead.
      final double weight = ((Number) t.get(weightIdx_)).doubleValue();
      sketch_.update(t, weight);
    }
  }

  /**
   * Returns the current samples as a bag.
   * @return bag built from the sketch, or null if nothing has been accumulated
   */
  @Override
  public DataBag getValue() {
    if (sketch_ == null) {
      return null;
    }

    return createDataBagFromSketch(sketch_);
  }

  /**
   * Drops the current sketch; the next call to accumulate() starts a new one.
   */
  @Override
  public void cleanup() {
    sketch_ = null;
  }

  /**
   * Validates the input schema and describes the output: a bag of (weight, record) tuples.
   * @param input schema of the UDF input; field 0 must be a bag whose records hold a
   *        double or float at the weight index
   * @return schema of the resulting bag
   * @throws IllegalArgumentException if the input schema does not match the expected shape
   */
  @Override
  public Schema outputSchema(final Schema input) {
    try {
      if (input == null || input.size() == 0) {
        throw new IllegalArgumentException("Degenerate input schema to VarOptSampling");
      }

      // first element must be a bag, weightIdx_ element of tuples must be a float or double
      if (input.getField(0).type != DataType.BAG) {
        throw new IllegalArgumentException("VarOpt input must be a data bag: "
            + input.toString());
      }

      final Schema record = input.getField(0).schema; // record has a tuple in field 0
      if (record == null || record.size() == 0) {
        // Without a declared inner schema the weight column type cannot be verified.
        throw new IllegalArgumentException("VarOpt input bag must declare a tuple schema: "
            + input.toString());
      }

      final Schema fields = record.getField(0).schema;
      if (fields == null) {
        throw new IllegalArgumentException("VarOpt input bag must declare a tuple schema: "
            + input.toString());
      }

      final byte weightType = fields.getField(weightIdx_).type;
      if (weightType != DataType.DOUBLE && weightType != DataType.FLOAT) {
        // Report the type of the column that was actually checked.
        throw new IllegalArgumentException("weightIndex item of VarOpt tuple must be a "
            + "weight (double/float), found " + weightType
            + ": " + fields.toString());
      }

      final Schema weightedSampleSchema = new Schema();
      weightedSampleSchema.add(new Schema.FieldSchema(WEIGHT_ALIAS, DataType.DOUBLE));
      weightedSampleSchema.add(new Schema.FieldSchema(RECORD_ALIAS, record, DataType.TUPLE));

      return new Schema(new Schema.FieldSchema(getSchemaName(this
          .getClass().getName().toLowerCase(), record), weightedSampleSchema, DataType.BAG));
    }
    catch (final FrontendException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * @return class name of the Initial stage, VarOptCommonImpl.RawTuplesToSketchTuple
   */
  @Override
  public String getInitial() {
    return VarOptCommonImpl.RawTuplesToSketchTuple.class.getName();
  }

  /**
   * @return class name of the Intermediate stage, VarOptCommonImpl.UnionSketchesAsTuple
   */
  @Override
  public String getIntermed() {
    return VarOptCommonImpl.UnionSketchesAsTuple.class.getName();
  }

  /**
   * @return class name of the Final stage, {@link Final}
   */
  @Override
  public String getFinal() {
    return Final.class.getName();
  }

  /**
   * Final stage of the Algebraic implementation: unions the incoming sketches and returns the
   * resulting samples as a bag.
   */
  public static class Final extends EvalFunc<DataBag> {
    private final int targetK_;
    private final int weightIdx_;

    /**
     * Default constructor using DEFAULT_TARGET_K and DEFAULT_WEIGHT_IDX.
     */
    public Final() {
      targetK_ = DEFAULT_TARGET_K;
      weightIdx_ = DEFAULT_WEIGHT_IDX;
    }

    /**
     * Reducer VarOpt sampling constructor.
     * @param kStr String indicating the maximum number of desired samples to return.
     * @throws IllegalArgumentException if kStr is not an integer or is less than 1
     */
    public Final(final String kStr) {
      targetK_ = validateTargetK(Integer.parseInt(kStr));
      weightIdx_ = DEFAULT_WEIGHT_IDX;
    }

    /**
     * VarOpt sampling constructor.
     * @param kStr String indicating the maximum number of desired samples to return.
     * @param weightIdxStr String indicating column index (0-based) of weight values
     * @throws IllegalArgumentException if either argument is not an integer, if the target
     *         sample size is less than 1, or if the weight index is negative
     */
    public Final(final String kStr, final String weightIdxStr) {
      // Parse both values before validating either, so malformed numbers are reported first.
      final int k = Integer.parseInt(kStr);
      final int weightIdx = Integer.parseInt(weightIdxStr);
      targetK_ = validateTargetK(k);
      weightIdx_ = validateWeightIdx(weightIdx);
    }

    /**
     * Unions the sketches carried by the input tuple and returns the resulting samples.
     * @param inputTuple tuple passed to unionSketches(); ignored when null, empty, or when
     *        its first field is null
     * @return bag built from the union result, or null when the input is ignored
     * @throws IOException if the input cannot be read
     */
    @Override
    public DataBag exec(final Tuple inputTuple) throws IOException {
      if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
        return null;
      }

      final VarOptItemsUnion<Tuple> union = unionSketches(inputTuple, targetK_);
      return createDataBagFromSketch(union.getResult());
    }
  }
}