// org.apache.flink.api.java.utils.DataSetUtils (Maven / Gradle / Ivy)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.utils;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Keys;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.Utils;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.functions.SampleInCoordinator;
import org.apache.flink.api.java.functions.SampleInPartition;
import org.apache.flink.api.java.functions.SampleWithFraction;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.operators.PartitionOperator;
import org.apache.flink.api.java.summarize.aggregation.SummaryAggregatorFactory;
import org.apache.flink.api.java.summarize.aggregation.TupleSummaryAggregator;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
* This class provides simple utility methods for zipping elements in a data set with an index or
* with a unique identifier.
*/
@PublicEvolving
public final class DataSetUtils {
/**
* Method that goes over all the elements in each partition in order to retrieve the total
* number of elements.
*
* @param input the DataSet received as input
* @return a data set containing tuples of subtask index, number of elements mappings.
*/
public static DataSet> countElementsPerPartition(DataSet input) {
return input.mapPartition(
new RichMapPartitionFunction>() {
@Override
public void mapPartition(
Iterable values, Collector> out)
throws Exception {
long counter = 0;
for (T value : values) {
counter++;
}
out.collect(
new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
}
});
}
/**
* Method that assigns a unique {@link Long} value to all elements in the input data set. The
* generated values are consecutive.
*
* @param input the input data set
* @return a data set of tuple 2 consisting of consecutive ids and initial values.
*/
public static DataSet> zipWithIndex(DataSet input) {
DataSet> elementCount = countElementsPerPartition(input);
return input.mapPartition(
new RichMapPartitionFunction>() {
long start = 0;
@Override
public void open(Configuration parameters) throws Exception {
super.open(parameters);
List> offsets =
getRuntimeContext()
.getBroadcastVariableWithInitializer(
"counts",
new BroadcastVariableInitializer<
Tuple2,
List>>() {
@Override
public List>
initializeBroadcastVariable(
Iterable<
Tuple2<
Integer,
Long>>
data) {
// sort the list by task id to
// calculate the correct offset
List>
sortedData =
new ArrayList<>();
for (Tuple2 datum :
data) {
sortedData.add(datum);
}
Collections.sort(
sortedData,
new Comparator<
Tuple2<
Integer,
Long>>() {
@Override
public int compare(
Tuple2<
Integer,
Long>
o1,
Tuple2<
Integer,
Long>
o2) {
return o1.f0
.compareTo(
o2.f0);
}
});
return sortedData;
}
});
// compute the offset for each partition
for (int i = 0;
i < getRuntimeContext().getIndexOfThisSubtask();
i++) {
start += offsets.get(i).f1;
}
}
@Override
public void mapPartition(
Iterable values, Collector> out)
throws Exception {
for (T value : values) {
out.collect(new Tuple2<>(start++, value));
}
}
})
.withBroadcastSet(elementCount, "counts");
}
/**
* Method that assigns a unique {@link Long} value to all elements in the input data set as
* described below.
*
*
* - a map function is applied to the input data set
*
- each map task holds a counter c which is increased for each record
*
- c is shifted by n bits where n = log2(number of parallel tasks)
*
- to create a unique ID among all tasks, the task id is added to the counter
*
- for each record, the resulting counter is collected
*
*
* @param input the input data set
* @return a data set of tuple 2 consisting of ids and initial values.
*/
public static DataSet> zipWithUniqueId(DataSet input) {
return input.mapPartition(
new RichMapPartitionFunction>() {
long maxBitSize = getBitSize(Long.MAX_VALUE);
long shifter = 0;
long start = 0;
long taskId = 0;
long label = 0;
@Override
public void open(Configuration parameters) throws Exception {
super.open(parameters);
shifter = getBitSize(getRuntimeContext().getNumberOfParallelSubtasks() - 1);
taskId = getRuntimeContext().getIndexOfThisSubtask();
}
@Override
public void mapPartition(Iterable values, Collector> out)
throws Exception {
for (T value : values) {
label = (start << shifter) + taskId;
if (getBitSize(start) + shifter < maxBitSize) {
out.collect(new Tuple2<>(label, value));
start++;
} else {
throw new Exception(
"Exceeded Long value range while generating labels");
}
}
}
});
}
// --------------------------------------------------------------------------------------------
// Sample
// --------------------------------------------------------------------------------------------
/**
* Generate a sample of DataSet by the probability fraction of each element.
*
* @param withReplacement Whether element can be selected more than once.
* @param fraction Probability that each element is chosen, should be [0,1] without replacement,
* and [0, ∞) with replacement. While fraction is larger than 1, the elements are expected
* to be selected multi times into sample on average.
* @return The sampled DataSet
*/
public static MapPartitionOperator sample(
DataSet input, final boolean withReplacement, final double fraction) {
return sample(input, withReplacement, fraction, Utils.RNG.nextLong());
}
/**
* Generate a sample of DataSet by the probability fraction of each element.
*
* @param withReplacement Whether element can be selected more than once.
* @param fraction Probability that each element is chosen, should be [0,1] without replacement,
* and [0, ∞) with replacement. While fraction is larger than 1, the elements are expected
* to be selected multi times into sample on average.
* @param seed random number generator seed.
* @return The sampled DataSet
*/
public static MapPartitionOperator sample(
DataSet input,
final boolean withReplacement,
final double fraction,
final long seed) {
return input.mapPartition(new SampleWithFraction(withReplacement, fraction, seed));
}
/**
* Generate a sample of DataSet which contains fixed size elements.
*
* NOTE: Sample with fixed size is not as efficient as sample with fraction,
* use sample with fraction unless you need exact precision.
*
* @param withReplacement Whether element can be selected more than once.
* @param numSamples The expected sample size.
* @return The sampled DataSet
*/
public static DataSet sampleWithSize(
DataSet input, final boolean withReplacement, final int numSamples) {
return sampleWithSize(input, withReplacement, numSamples, Utils.RNG.nextLong());
}
/**
* Generate a sample of DataSet which contains fixed size elements.
*
* NOTE: Sample with fixed size is not as efficient as sample with fraction,
* use sample with fraction unless you need exact precision.
*
* @param withReplacement Whether element can be selected more than once.
* @param numSamples The expected sample size.
* @param seed Random number generator seed.
* @return The sampled DataSet
*/
public static DataSet sampleWithSize(
DataSet input,
final boolean withReplacement,
final int numSamples,
final long seed) {
SampleInPartition sampleInPartition =
new SampleInPartition<>(withReplacement, numSamples, seed);
MapPartitionOperator mapPartitionOperator = input.mapPartition(sampleInPartition);
// There is no previous group, so the parallelism of GroupReduceOperator is always 1.
String callLocation = Utils.getCallLocationName();
SampleInCoordinator sampleInCoordinator =
new SampleInCoordinator<>(withReplacement, numSamples, seed);
return new GroupReduceOperator<>(
mapPartitionOperator, input.getType(), sampleInCoordinator, callLocation);
}
// --------------------------------------------------------------------------------------------
// Partition
// --------------------------------------------------------------------------------------------
/** Range-partitions a DataSet on the specified tuple field positions. */
public static PartitionOperator partitionByRange(
DataSet input, DataDistribution distribution, int... fields) {
return new PartitionOperator<>(
input,
PartitionOperatorBase.PartitionMethod.RANGE,
new Keys.ExpressionKeys<>(fields, input.getType(), false),
distribution,
Utils.getCallLocationName());
}
/** Range-partitions a DataSet on the specified fields. */
public static PartitionOperator partitionByRange(
DataSet input, DataDistribution distribution, String... fields) {
return new PartitionOperator<>(
input,
PartitionOperatorBase.PartitionMethod.RANGE,
new Keys.ExpressionKeys<>(fields, input.getType()),
distribution,
Utils.getCallLocationName());
}
/** Range-partitions a DataSet using the specified key selector function. */
public static > PartitionOperator partitionByRange(
DataSet input, DataDistribution distribution, KeySelector keyExtractor) {
final TypeInformation keyType =
TypeExtractor.getKeySelectorTypes(keyExtractor, input.getType());
return new PartitionOperator<>(
input,
PartitionOperatorBase.PartitionMethod.RANGE,
new Keys.SelectorFunctionKeys<>(
input.clean(keyExtractor), input.getType(), keyType),
distribution,
Utils.getCallLocationName());
}
// --------------------------------------------------------------------------------------------
// Summarize
// --------------------------------------------------------------------------------------------
/**
* Summarize a DataSet of Tuples by collecting single pass statistics for all columns.
*
* Example usage:
*
*
{@code
* Dataset> input = // [...]
* Tuple3 summary = DataSetUtils.summarize(input)
*
* summary.f0.getStandardDeviation()
* summary.f1.getMaxLength()
* }
*
* @return the summary as a Tuple the same width as input rows
*/
public static R summarize(DataSet input)
throws Exception {
if (!input.getType().isTupleType()) {
throw new IllegalArgumentException(
"summarize() is only implemented for DataSet's of Tuples");
}
final TupleTypeInfoBase> inType = (TupleTypeInfoBase>) input.getType();
DataSet> result =
input.mapPartition(
new MapPartitionFunction>() {
@Override
public void mapPartition(
Iterable values,
Collector> out)
throws Exception {
TupleSummaryAggregator aggregator =
SummaryAggregatorFactory.create(inType);
for (Tuple value : values) {
aggregator.aggregate(value);
}
out.collect(aggregator);
}
})
.reduce(
new ReduceFunction>() {
@Override
public TupleSummaryAggregator reduce(
TupleSummaryAggregator agg1,
TupleSummaryAggregator agg2)
throws Exception {
agg1.combine(agg2);
return agg1;
}
});
return result.collect().get(0).result();
}
// --------------------------------------------------------------------------------------------
// Checksum
// --------------------------------------------------------------------------------------------
/**
* Convenience method to get the count (number of elements) of a DataSet as well as the checksum
* (sum over element hashes).
*
* @return A ChecksumHashCode that represents the count and checksum of elements in the data
* set.
* @deprecated This method will be removed at some point.
*/
@Deprecated
public static Utils.ChecksumHashCode checksumHashCode(DataSet input) throws Exception {
final String id = new AbstractID().toString();
input.output(new Utils.ChecksumHashCodeHelper(id)).name("ChecksumHashCode");
JobExecutionResult res = input.getExecutionEnvironment().execute();
return res.getAccumulatorResult(id);
}
// *************************************************************************
// UTIL METHODS
// *************************************************************************
public static int getBitSize(long value) {
if (value > Integer.MAX_VALUE) {
return 64 - Integer.numberOfLeadingZeros((int) (value >> 32));
} else {
return 32 - Integer.numberOfLeadingZeros((int) value);
}
}
/** Private constructor to prevent instantiation. */
private DataSetUtils() {
throw new RuntimeException();
}
}