/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.utils;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Keys;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.Utils;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.functions.SampleInCoordinator;
import org.apache.flink.api.java.functions.SampleInPartition;
import org.apache.flink.api.java.functions.SampleWithFraction;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.operators.PartitionOperator;
import org.apache.flink.api.java.sampling.IntermediateSampleData;
import org.apache.flink.api.java.summarize.aggregation.SummaryAggregatorFactory;
import org.apache.flink.api.java.summarize.aggregation.TupleSummaryAggregator;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
* This class provides simple utility methods for zipping elements in a data set with an index
* or with a unique identifier, as well as utilities for sampling, range partitioning, and
* summarizing data sets.
*/
@PublicEvolving
public final class DataSetUtils {
/**
* Method that goes over all the elements in each partition in order to retrieve
* the total number of elements.
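*
* <p>Example usage (an illustrative sketch; {@code env} is an assumed {@link org.apache.flink.api.java.ExecutionEnvironment}):
* <pre>{@code
* DataSet<String> words = env.fromElements("a", "b", "c", "d");
* // prints one (subtask index, element count) tuple per parallel subtask
* DataSetUtils.countElementsPerPartition(words).print();
* }</pre>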
*
* @param input the DataSet received as input
* @return a data set of {@code Tuple2<Integer, Long>} mapping each subtask index to the number of elements in its partition.
*/
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
    return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
            long counter = 0;
            for (T value : values) {
                counter++;
            }
            out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
        }
    });
}
/**
* Method that assigns a unique {@link Long} value to all elements in the input data set. The generated values are
* consecutive.
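*
* <p>Example usage (an illustrative sketch; {@code env} is an assumed {@link org.apache.flink.api.java.ExecutionEnvironment}):
* <pre>{@code
* DataSet<String> in = env.fromElements("A", "B", "C");
* // yields (0,A), (1,B), (2,C); the assigned ids are consecutive across all partitions
* DataSet<Tuple2<Long, String>> zipped = DataSetUtils.zipWithIndex(in);
* }</pre>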
*
* @param input the input data set
* @return a data set of tuple 2 consisting of consecutive ids and initial values.
*/
public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) {
    DataSet<Tuple2<Integer, Long>> elementCount = countElementsPerPartition(input);

    return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {

        long start = 0;

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);

            List<Tuple2<Integer, Long>> offsets = getRuntimeContext().getBroadcastVariableWithInitializer(
                "counts",
                new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
                    @Override
                    public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
                        // sort the list by task id to calculate the correct offset
                        List<Tuple2<Integer, Long>> sortedData = new ArrayList<>();
                        for (Tuple2<Integer, Long> datum : data) {
                            sortedData.add(datum);
                        }
                        Collections.sort(sortedData, new Comparator<Tuple2<Integer, Long>>() {
                            @Override
                            public int compare(Tuple2<Integer, Long> o1, Tuple2<Integer, Long> o2) {
                                return o1.f0.compareTo(o2.f0);
                            }
                        });
                        return sortedData;
                    }
                });

            // compute the offset for each partition
            for (int i = 0; i < getRuntimeContext().getIndexOfThisSubtask(); i++) {
                start += offsets.get(i).f1;
            }
        }

        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
            for (T value : values) {
                out.collect(new Tuple2<>(start++, value));
            }
        }
    }).withBroadcastSet(elementCount, "counts");
}
/**
* Method that assigns a unique {@link Long} value to all elements in the input data set as described below.
* <ul>
*  <li> a map function is applied to the input data set
*  <li> each map task holds a counter c which is increased for each record
*  <li> c is shifted by n bits where n = log2(number of parallel tasks)
*  <li> to create a unique ID among all tasks, the task id is added to the counter
*  <li> for each record, the resulting counter is collected
* </ul>
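*
* <p>Example usage (an illustrative sketch; {@code env} is an assumed {@link org.apache.flink.api.java.ExecutionEnvironment}):
* <pre>{@code
* DataSet<String> in = env.fromElements("A", "B", "C");
* // ids are unique across all subtasks but, unlike zipWithIndex, not necessarily consecutive:
* // each id is the per-task counter shifted left, with the task id in the low bits
* DataSet<Tuple2<Long, String>> zipped = DataSetUtils.zipWithUniqueId(in);
* }</pre>
*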
* @param input the input data set
* @return a data set of tuple 2 consisting of ids and initial values.
*/
public static <T> DataSet<Tuple2<Long, T>> zipWithUniqueId(DataSet<T> input) {
    return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {

        long maxBitSize = getBitSize(Long.MAX_VALUE);
        long shifter = 0;
        long start = 0;
        long taskId = 0;
        long label = 0;

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            shifter = getBitSize(getRuntimeContext().getNumberOfParallelSubtasks() - 1);
            taskId = getRuntimeContext().getIndexOfThisSubtask();
        }

        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
            for (T value : values) {
                label = (start << shifter) + taskId;

                if (getBitSize(start) + shifter < maxBitSize) {
                    out.collect(new Tuple2<>(label, value));
                    start++;
                } else {
                    throw new Exception("Exceeded Long value range while generating labels");
                }
            }
        }
    });
}
// --------------------------------------------------------------------------------------------
// Sample
// --------------------------------------------------------------------------------------------
/**
* Generate a sample of the DataSet by independently selecting each element with the given probability fraction.
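*
* <p>Example usage (an illustrative sketch):
* <pre>{@code
* // keep each element with probability 0.1, without replacement
* DataSet<Long> sampled = DataSetUtils.sample(input, false, 0.1);
* }</pre>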
*
* @param withReplacement Whether an element can be selected more than once.
* @param fraction Probability that each element is chosen; must be in [0,1] for sampling
* without replacement and in [0, ∞) for sampling with replacement. When fraction is larger
* than 1, each element is expected to be selected multiple times into the sample on average.
* @return The sampled DataSet
*/
public static <T> MapPartitionOperator<T, T> sample(
        DataSet<T> input,
        final boolean withReplacement,
        final double fraction) {
    return sample(input, withReplacement, fraction, Utils.RNG.nextLong());
}
/**
* Generate a sample of the DataSet by independently selecting each element with the given probability fraction.
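*
* <p>Example usage (an illustrative sketch; the fixed seed makes the sample reproducible across runs):
* <pre>{@code
* DataSet<Long> sampled = DataSetUtils.sample(input, false, 0.1, 42L);
* }</pre>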
*
* @param withReplacement Whether an element can be selected more than once.
* @param fraction Probability that each element is chosen; must be in [0,1] for sampling
* without replacement and in [0, ∞) for sampling with replacement. When fraction is larger
* than 1, each element is expected to be selected multiple times into the sample on average.
* @param seed Random number generator seed.
* @return The sampled DataSet
*/
public static <T> MapPartitionOperator<T, T> sample(
        DataSet<T> input,
        final boolean withReplacement,
        final double fraction,
        final long seed) {
    return input.mapPartition(new SampleWithFraction<T>(withReplacement, fraction, seed));
}
/**
* Generate a sample of the DataSet that contains a fixed number of elements.
*
* <p>NOTE: Sampling with a fixed size is not as efficient as sampling with a fraction;
* use fraction-based sampling unless you need an exact sample size.
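*
* <p>Example usage (an illustrative sketch):
* <pre>{@code
* // draw a sample of 100 elements, without replacement
* DataSet<Long> sampled = DataSetUtils.sampleWithSize(input, false, 100);
* }</pre>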
*
* @param withReplacement Whether an element can be selected more than once.
* @param numSamples The expected sample size.
* @return The sampled DataSet
*/
public static <T> DataSet<T> sampleWithSize(
        DataSet<T> input,
        final boolean withReplacement,
        final int numSamples) {
    return sampleWithSize(input, withReplacement, numSamples, Utils.RNG.nextLong());
}
/**
* Generate a sample of the DataSet that contains a fixed number of elements.
*
* <p>NOTE: Sampling with a fixed size is not as efficient as sampling with a fraction;
* use fraction-based sampling unless you need an exact sample size.
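*
* <p>Example usage (an illustrative sketch; the fixed seed makes the sample reproducible across runs):
* <pre>{@code
* DataSet<Long> sampled = DataSetUtils.sampleWithSize(input, true, 100, 42L);
* }</pre>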
*
* @param withReplacement Whether an element can be selected more than once.
* @param numSamples The expected sample size.
* @param seed Random number generator seed.
* @return The sampled DataSet
*/
public static <T> DataSet<T> sampleWithSize(
        DataSet<T> input,
        final boolean withReplacement,
        final int numSamples,
        final long seed) {
    SampleInPartition<T> sampleInPartition = new SampleInPartition<>(withReplacement, numSamples, seed);
    MapPartitionOperator<T, IntermediateSampleData<T>> mapPartitionOperator = input.mapPartition(sampleInPartition);

    // There is no previous group, so the parallelism of GroupReduceOperator is always 1.
    String callLocation = Utils.getCallLocationName();
    SampleInCoordinator<T> sampleInCoordinator = new SampleInCoordinator<>(withReplacement, numSamples, seed);
    return new GroupReduceOperator<>(mapPartitionOperator, input.getType(), sampleInCoordinator, callLocation);
}
// --------------------------------------------------------------------------------------------
// Partition
// --------------------------------------------------------------------------------------------
/**
* Range-partitions a DataSet on the specified tuple field positions.
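*
* <p>Example usage (an illustrative sketch; {@code dist} is an assumed {@link DataDistribution} over the key range):
* <pre>{@code
* DataSet<Tuple2<Integer, String>> data = // [...]
* // range-partition on the first tuple field according to the given distribution
* DataSetUtils.partitionByRange(data, dist, 0);
* }</pre>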
*/
public static <T> PartitionOperator<T> partitionByRange(DataSet<T> input, DataDistribution distribution, int... fields) {
    return new PartitionOperator<>(input, PartitionOperatorBase.PartitionMethod.RANGE, new Keys.ExpressionKeys<>(fields, input.getType(), false), distribution, Utils.getCallLocationName());
}
/**
* Range-partitions a DataSet on the specified fields.
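*
* <p>Example usage (an illustrative sketch; {@code Point} is an assumed POJO with a public {@code x} field and {@code dist} an assumed {@link DataDistribution}):
* <pre>{@code
* DataSet<Point> points = // [...]
* DataSetUtils.partitionByRange(points, dist, "x");
* }</pre>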
*/
public static <T> PartitionOperator<T> partitionByRange(DataSet<T> input, DataDistribution distribution, String... fields) {
    return new PartitionOperator<>(input, PartitionOperatorBase.PartitionMethod.RANGE, new Keys.ExpressionKeys<>(fields, input.getType()), distribution, Utils.getCallLocationName());
}
/**
* Range-partitions a DataSet using the specified key selector function.
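*
* <p>Example usage (an illustrative sketch; {@code dist} is an assumed {@link DataDistribution}):
* <pre>{@code
* DataSet<String> words = // [...]
* // range-partition by word length
* DataSetUtils.partitionByRange(words, dist, new KeySelector<String, Integer>() {
*     public Integer getKey(String word) {
*         return word.length();
*     }
* });
* }</pre>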
*/
public static <T, K extends Comparable<K>> PartitionOperator<T> partitionByRange(DataSet<T> input, DataDistribution distribution, KeySelector<T, K> keyExtractor) {
    final TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, input.getType());
    return new PartitionOperator<>(input, PartitionOperatorBase.PartitionMethod.RANGE, new Keys.SelectorFunctionKeys<>(input.clean(keyExtractor), input.getType(), keyType), distribution, Utils.getCallLocationName());
}
// --------------------------------------------------------------------------------------------
// Summarize
// --------------------------------------------------------------------------------------------
/**
* Summarize a DataSet of Tuples by collecting single pass statistics for all columns.
*
* <p>Example usage:
* <pre>
* {@code
* DataSet<Tuple3<Double, String, Boolean>> input = // [...]
* Tuple3<NumericColumnSummary<Double>, StringColumnSummary, BooleanColumnSummary> summary = DataSetUtils.summarize(input);
*
* summary.f0.getStandardDeviation();
* summary.f1.getMaxLength();
* }
* </pre>
*
* @return the summary as a Tuple the same width as input rows
*/
public static <R extends Tuple, T extends Tuple> R summarize(DataSet<T> input) throws Exception {
    if (!input.getType().isTupleType()) {
        throw new IllegalArgumentException("summarize() is only implemented for DataSets of Tuples");
    }
    final TupleTypeInfoBase<?> inType = (TupleTypeInfoBase<?>) input.getType();
    DataSet<TupleSummaryAggregator<R>> result = input.mapPartition(new MapPartitionFunction<T, TupleSummaryAggregator<R>>() {
        @Override
        public void mapPartition(Iterable<T> values, Collector<TupleSummaryAggregator<R>> out) throws Exception {
            TupleSummaryAggregator<R> aggregator = SummaryAggregatorFactory.create(inType);
            for (Tuple value : values) {
                aggregator.aggregate(value);
            }
            out.collect(aggregator);
        }
    }).reduce(new ReduceFunction<TupleSummaryAggregator<R>>() {
        @Override
        public TupleSummaryAggregator<R> reduce(TupleSummaryAggregator<R> agg1, TupleSummaryAggregator<R> agg2) throws Exception {
            agg1.combine(agg2);
            return agg1;
        }
    });
    return result.collect().get(0).result();
}
// --------------------------------------------------------------------------------------------
// Checksum
// --------------------------------------------------------------------------------------------
/**
* Convenience method to get the count (number of elements) of a DataSet
* as well as the checksum (sum over element hashes).
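*
* <p>Example usage (an illustrative sketch; {@code data} is an assumed {@code DataSet}; note that this immediately triggers a job execution):
* <pre>{@code
* Utils.ChecksumHashCode cks = DataSetUtils.checksumHashCode(data);
* cks.getCount();    // number of elements
* cks.getChecksum(); // sum over the elements' hash codes
* }</pre>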
*
* @return A ChecksumHashCode that represents the count and checksum of elements in the data set.
* @deprecated replaced with {@code org.apache.flink.graph.asm.dataset.ChecksumHashCode} in Gelly
*/
@Deprecated
public static <T> Utils.ChecksumHashCode checksumHashCode(DataSet<T> input) throws Exception {
    final String id = new AbstractID().toString();
    input.output(new Utils.ChecksumHashCodeHelper<T>(id)).name("ChecksumHashCode");

    JobExecutionResult res = input.getExecutionEnvironment().execute();
    return res.<Utils.ChecksumHashCode>getAccumulatorResult(id);
}
// *************************************************************************
// UTIL METHODS
// *************************************************************************
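/**
* Returns the number of bits needed to represent the given non-negative value,
* i.e. the (1-based) position of its highest set bit; returns 0 for the value 0.
*/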
public static int getBitSize(long value) {
    if (value > Integer.MAX_VALUE) {
        return 64 - Integer.numberOfLeadingZeros((int) (value >> 32));
    } else {
        return 32 - Integer.numberOfLeadingZeros((int) value);
    }
}
/**
* Private constructor to prevent instantiation.
*/
private DataSetUtils() {
    throw new RuntimeException();
}
}