
// com.google.cloud.dataflow.sdk.transforms.ApproximateUnique (Maven / Gradle / Ivy)
/*
* Copyright (C) 2015 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.google.cloud.dataflow.sdk.transforms;
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.Coder.Context;
import com.google.cloud.dataflow.sdk.coders.CoderException;
import com.google.cloud.dataflow.sdk.coders.CoderRegistry;
import com.google.cloud.dataflow.sdk.coders.KvCoder;
import com.google.cloud.dataflow.sdk.coders.SerializableCoder;
import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.common.hash.Hashing;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
/**
* {@code PTransform}s for estimating the number of distinct elements
* in a {@code PCollection}, or the number of distinct values
* associated with each key in a {@code PCollection} of {@code KV}s.
*/
public class ApproximateUnique {
/**
 * Returns a {@code PTransform} that takes a {@code PCollection<T>}
 * and returns a {@code PCollection<Long>} containing a single value
 * that is an estimate of the number of distinct elements in the
 * input {@code PCollection}.
 *
 * <p>The {@code sampleSize} parameter controls the estimation
 * error. The error is about {@code 2 / sqrt(sampleSize)}, so for
 * {@code ApproximateUnique.globally(10000)} the estimation error is
 * about 2%. Similarly, for {@code ApproximateUnique.globally(16)} the
 * estimation error is about 50%. If there are fewer than
 * {@code sampleSize} distinct elements then the returned result
 * will be exact with extremely high probability (the chance of a
 * hash collision is about {@code sampleSize^2 / 2^65}).
 *
 * <p>This transform approximates the number of elements in a set
 * by computing the top {@code sampleSize} hash values, and using
 * that to extrapolate the size of the entire set of hash values by
 * assuming the rest of the hash values are as densely distributed
 * as the top {@code sampleSize}.
 *
 * <p>See also {@link #globally(double)}.
 *
 * <p>Example of use:
 * <pre> {@code
 * PCollection<String> pc = ...;
 * PCollection<Long> approxNumDistinct =
 *     pc.apply(ApproximateUnique.<String>globally(1000));
 * } </pre>
 *
 * @param <T> the type of the elements in the input {@code PCollection}
 * @param sampleSize the number of entries in the statistical
 *     sample; the higher this number, the more accurate the
 *     estimate will be; should be {@code >= 16}
 * @return a transform that estimates the number of distinct elements
 * @throws IllegalArgumentException if the {@code sampleSize}
 *     argument is too small
 */
public static <T> Globally<T> globally(int sampleSize) {
  return new Globally<>(sampleSize);
}
/**
 * Like {@link #globally(int)}, but specifies the desired maximum
 * estimation error instead of the sample size.
 *
 * @param <T> the type of the elements in the input {@code PCollection}
 * @param maximumEstimationError the maximum estimation error, which
 *     should be in the range {@code [0.01, 0.5]}
 * @return a transform that estimates the number of distinct elements
 * @throws IllegalArgumentException if the
 *     {@code maximumEstimationError} argument is out of range
 */
public static <T> Globally<T> globally(double maximumEstimationError) {
  return new Globally<>(maximumEstimationError);
}
/**
 * Returns a {@code PTransform} that takes a
 * {@code PCollection<KV<K, V>>} and returns a
 * {@code PCollection<KV<K, Long>>} that contains an output element
 * mapping each distinct key in the input {@code PCollection} to an
 * estimate of the number of distinct values associated with that
 * key in the input {@code PCollection}.
 *
 * <p>See {@link #globally(int)} for an explanation of the
 * {@code sampleSize} parameter. A separate sampling is computed
 * for each distinct key of the input.
 *
 * <p>See also {@link #perKey(double)}.
 *
 * <p>Example of use:
 * <pre> {@code
 * PCollection<KV<Integer, String>> pc = ...;
 * PCollection<KV<Integer, Long>> approxNumDistinctPerKey =
 *     pc.apply(ApproximateUnique.<Integer, String>perKey(1000));
 * } </pre>
 *
 * @param <K> the type of the keys in the input and output
 *     {@code PCollection}s
 * @param <V> the type of the values in the input {@code PCollection}
 * @param sampleSize the number of entries in the statistical
 *     sample; the higher this number, the more accurate the
 *     estimate will be; should be {@code >= 16}
 * @return a transform that estimates the number of distinct values per key
 * @throws IllegalArgumentException if the {@code sampleSize}
 *     argument is too small
 */
public static <K, V> PerKey<K, V> perKey(int sampleSize) {
  return new PerKey<>(sampleSize);
}
/**
 * Like {@link #perKey(int)}, but specifies the desired maximum
 * estimation error instead of the sample size.
 *
 * @param <K> the type of the keys in the input and output
 *     {@code PCollection}s
 * @param <V> the type of the values in the input {@code PCollection}
 * @param maximumEstimationError the maximum estimation error, which
 *     should be in the range {@code [0.01, 0.5]}
 * @return a transform that estimates the number of distinct values per key
 * @throws IllegalArgumentException if the
 *     {@code maximumEstimationError} argument is out of range
 */
public static <K, V> PerKey<K, V> perKey(double maximumEstimationError) {
  return new PerKey<>(maximumEstimationError);
}
/////////////////////////////////////////////////////////////////////////////
/**
* {@code PTransform} for estimating the number of distinct elements
* in a {@code PCollection}.
*
* @param the type of the elements in the input {@code PCollection}
*/
@SuppressWarnings("serial")
static class Globally extends PTransform, PCollection> {
/**
* The number of entries in the statistical sample; the higher this number,
* the more accurate the estimate will be.
*/
private final long sampleSize;
/**
* @see ApproximateUnique#globally(int)
*/
public Globally(int sampleSize) {
if (sampleSize < 16) {
throw new IllegalArgumentException(
"ApproximateUnique needs a sampleSize "
+ ">= 16 for an estimation error <= 50%. "
+ "In general, the estimation "
+ "error is about 2 / sqrt(sampleSize).");
}
this.sampleSize = sampleSize;
}
/**
* @see ApproximateUnique#globally(double)
*/
public Globally(double maximumEstimationError) {
if (maximumEstimationError < 0.01 || maximumEstimationError > 0.5) {
throw new IllegalArgumentException(
"ApproximateUnique needs an "
+ "estimation error between 1% (0.01) and 50% (0.5).");
}
this.sampleSize = sampleSizeFromEstimationError(maximumEstimationError);
}
@Override
public PCollection apply(PCollection input) {
Coder coder = input.getCoder();
return input.apply(
Combine.globally(
new ApproximateUniqueCombineFn<>(sampleSize, coder)));
}
}
/**
* {@code PTransform} for estimating the number of distinct values
* associated with each key in a {@code PCollection} of {@code KV}s.
*
* @param the type of the keys in the input and output
* {@code PCollection}s
* @param the type of the values in the input {@code PCollection}
*/
@SuppressWarnings("serial")
static class PerKey
extends PTransform>, PCollection>> {
private final long sampleSize;
/**
* @see ApproximateUnique#perKey(int)
*/
public PerKey(int sampleSize) {
if (sampleSize < 16) {
throw new IllegalArgumentException(
"ApproximateUnique needs a "
+ "sampleSize >= 16 for an estimation error <= 50%. In general, "
+ "the estimation error is about 2 / sqrt(sampleSize).");
}
this.sampleSize = sampleSize;
}
/**
* @see ApproximateUnique#perKey(double)
*/
public PerKey(double estimationError) {
if (estimationError < 0.01 || estimationError > 0.5) {
throw new IllegalArgumentException(
"ApproximateUnique.PerKey needs an "
+ "estimation error between 1% (0.01) and 50% (0.5).");
}
this.sampleSize = sampleSizeFromEstimationError(estimationError);
}
@Override
public PCollection> apply(PCollection> input) {
Coder> inputCoder = input.getCoder();
if (!(inputCoder instanceof KvCoder)) {
throw new IllegalStateException(
"ApproximateUnique.PerKey requires its input to use KvCoder");
}
@SuppressWarnings("unchecked")
final Coder coder = ((KvCoder) inputCoder).getValueCoder();
return input.apply(
Combine.perKey(new ApproximateUniqueCombineFn<>(
sampleSize, coder).asKeyedFn()));
}
}
/////////////////////////////////////////////////////////////////////////////
/**
 * {@code CombineFn} that computes an estimate of the number of
 * distinct values that were combined.
 *
 * <p>Hashes input elements, computes the top {@code sampleSize}
 * hash values, and uses those to extrapolate the size of the entire
 * set of hash values by assuming the rest of the hash values are as
 * densely distributed as the top {@code sampleSize}.
 *
 * <p>Used to implement
 * {@link #globally(int) ApproximateUnique.globally(...)} and
 * {@link #perKey(int) ApproximateUnique.perKey(...)}.
 *
 * @param <T> the type of the values being combined
 */
@SuppressWarnings("serial")
public static class ApproximateUniqueCombineFn<T> extends
    CombineFn<T, ApproximateUniqueCombineFn.LargestUnique, Long> {
  /**
   * The size of the space of hashes returned by the hash function.
   */
  static final double HASH_SPACE_SIZE =
      Long.MAX_VALUE - (double) Long.MIN_VALUE;

  /**
   * A heap utility class to efficiently track the largest added elements.
   *
   * <p>Internally a min-heap: its head is the smallest retained value, which
   * is the one evicted when a larger new value arrives.
   */
  public static class LargestUnique implements Serializable {
    private final PriorityQueue<Long> heap = new PriorityQueue<>();
    private final long sampleSize;

    /**
     * Creates a heap to track the largest {@code sampleSize} elements.
     *
     * @param sampleSize the size of the heap
     */
    public LargestUnique(long sampleSize) {
      this.sampleSize = sampleSize;
    }

    /**
     * Adds a value to the heap, returning whether the value is (large enough
     * to be) in the heap.
     *
     * @param value the value to offer; must not be {@code null}
     * @return {@code true} if the value is in the heap after the call,
     *     {@code false} if it was too small to be retained
     */
    public boolean add(Long value) {
      boolean full = heap.size() >= sampleSize;
      if (full && !heap.isEmpty()) {
        // Compare against the current minimum first: PriorityQueue.contains
        // is a linear scan, and once the heap is full most offered values
        // fall below the minimum and can be rejected in constant time.
        int comparison = value.compareTo(heap.element());
        if (comparison < 0) {
          return false; // Smaller than everything retained.
        }
        if (comparison == 0) {
          return true; // Equal to the minimum, so already retained.
        }
      }
      if (heap.contains(value)) {
        return true;
      }
      if (full) {
        // The value is larger than the current minimum: evict the minimum.
        heap.remove();
      }
      heap.add(value);
      return true;
    }

    /**
     * Returns the values in the heap, ordered largest to smallest.
     *
     * <p>This drains the heap: after the call it is empty.
     *
     * @return the retained values in descending order
     */
    public List<Long> extractOrderedList() {
      // The only way to extract the order from the heap is element-by-element
      // from smallest to largest, so fill the array from the back.
      Long[] array = new Long[heap.size()];
      for (int i = array.length - 1; i >= 0; i--) {
        array[i] = heap.remove();
      }
      return Arrays.asList(array);
    }
  }

  private final long sampleSize;
  private final Coder<T> coder;

  /**
   * Creates a combine function.
   *
   * @param sampleSize the number of largest hash values to retain
   * @param coder the coder used to encode elements before hashing
   */
  public ApproximateUniqueCombineFn(long sampleSize, Coder<T> coder) {
    this.sampleSize = sampleSize;
    this.coder = coder;
  }

  /**
   * Returns an empty accumulator sized to this function's sample size.
   */
  @Override
  public LargestUnique createAccumulator() {
    return new LargestUnique(sampleSize);
  }

  /**
   * Hashes the input element and offers the hash to the accumulator.
   *
   * @throws RuntimeException wrapping the {@code IOException} raised if the
   *     element cannot be encoded
   */
  @Override
  public LargestUnique addInput(LargestUnique heap, T input) {
    try {
      heap.add(hash(input, coder));
      return heap;
    } catch (IOException e) {
      // Only the checked encoding failures are wrapped; unchecked exceptions
      // and Errors propagate unchanged.
      throw new RuntimeException(e);
    }
  }

  /**
   * Merges all accumulators into the first one and returns it.
   *
   * <p>The other accumulators are drained in the process. An empty
   * {@code heaps} yields a fresh, empty accumulator.
   */
  @Override
  public LargestUnique mergeAccumulators(Iterable<LargestUnique> heaps) {
    Iterator<LargestUnique> iterator = heaps.iterator();
    if (!iterator.hasNext()) {
      return createAccumulator();
    }
    LargestUnique heap = iterator.next();
    while (iterator.hasNext()) {
      List<Long> largestHashes = iterator.next().extractOrderedList();
      for (long hash : largestHashes) {
        if (!heap.add(hash)) {
          break; // The remainder of this list is all smaller.
        }
      }
    }
    return heap;
  }

  /**
   * Converts the accumulated sample into an estimate of the distinct count.
   *
   * <p>If fewer than {@code sampleSize} hashes were retained, their count is
   * returned directly; otherwise the count is extrapolated from how much of
   * the hash space the sample occupies. The accumulator is drained.
   */
  @Override
  public Long extractOutput(LargestUnique heap) {
    List<Long> largestHashes = heap.extractOrderedList();
    if (largestHashes.size() < sampleSize) {
      return (long) largestHashes.size();
    } else {
      long smallestSampleHash = largestHashes.get(largestHashes.size() - 1);
      double sampleSpaceSize = Long.MAX_VALUE - (double) smallestSampleHash;
      // This formula takes into account the possibility of hash collisions,
      // which become more likely than not for 2^32 distinct elements.
      // Note that log(1+x) ~ x for small x, so for sampleSize << maxHash
      // log(1 - sampleSize/sampleSpace) / log(1 - 1/sampleSpace) ~ sampleSize
      // and hence estimate ~ sampleSize * HASH_SPACE_SIZE / sampleSpace
      // as one would expect.
      double estimate = Math.log1p(-sampleSize / sampleSpaceSize)
          / Math.log1p(-1 / sampleSpaceSize)
          * HASH_SPACE_SIZE / sampleSpaceSize;
      return Math.round(estimate);
    }
  }

  /**
   * Returns a {@code SerializableCoder} for the {@code LargestUnique}
   * accumulator; the registry and input coder are not consulted.
   */
  @Override
  public Coder<LargestUnique> getAccumulatorCoder(CoderRegistry registry,
      Coder<T> inputCoder) {
    return SerializableCoder.of(LargestUnique.class);
  }

  /**
   * Encodes the given element using the given coder and hashes the encoding.
   *
   * @param <T> the type of the element
   * @param element the element to hash
   * @param coder the coder used to encode the element
   * @return a 64-bit hash taken from the murmur3 128-bit hash of the
   *     encoded bytes
   * @throws CoderException if the element cannot be encoded
   * @throws IOException if writing the encoding fails
   */
  static <T> long hash(T element, Coder<T> coder)
      throws CoderException, IOException {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    coder.encode(element, baos, Context.OUTER);
    return Hashing.murmur3_128().hashBytes(baos.toByteArray()).asLong();
  }
}
/**
 * Derives the sample size needed to reach a target estimation error.
 *
 * <p>Computes {@code 4 / estimationError^2}, rounded up to a whole number.
 *
 * @param estimationError should be bounded by [0.01, 0.5]
 * @return the sample size needed for the desired estimation error
 */
static long sampleSizeFromEstimationError(double estimationError) {
  double errorSquared = Math.pow(estimationError, 2.0);
  double requiredSamples = Math.ceil(4.0 / errorSquared);
  return Math.round(requiredSamples);
}
}