All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.flink.api.java.utils.DataSetUtils Maven / Gradle / Ivy

There is a newer version: 1.20.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.api.java.utils;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Keys;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.Utils;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.functions.SampleInCoordinator;
import org.apache.flink.api.java.functions.SampleInPartition;
import org.apache.flink.api.java.functions.SampleWithFraction;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.operators.PartitionOperator;
import org.apache.flink.api.java.summarize.aggregation.SummaryAggregatorFactory;
import org.apache.flink.api.java.summarize.aggregation.TupleSummaryAggregator;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
 * This class provides simple utility methods for zipping elements in a data set with an index or
 * with a unique identifier.
 */
@PublicEvolving
public final class DataSetUtils {

    /**
     * Method that goes over all the elements in each partition in order to retrieve the total
     * number of elements.
     *
     * @param input the DataSet received as input
     * @return a data set containing tuples of subtask index, number of elements mappings.
     */
    public static  DataSet> countElementsPerPartition(DataSet input) {
        return input.mapPartition(
                new RichMapPartitionFunction>() {
                    @Override
                    public void mapPartition(
                            Iterable values, Collector> out)
                            throws Exception {
                        long counter = 0;
                        for (T value : values) {
                            counter++;
                        }
                        out.collect(
                                new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
                    }
                });
    }

    /**
     * Method that assigns a unique {@link Long} value to all elements in the input data set. The
     * generated values are consecutive.
     *
     * @param input the input data set
     * @return a data set of tuple 2 consisting of consecutive ids and initial values.
     */
    public static  DataSet> zipWithIndex(DataSet input) {

        DataSet> elementCount = countElementsPerPartition(input);

        return input.mapPartition(
                        new RichMapPartitionFunction>() {

                            long start = 0;

                            @Override
                            public void open(Configuration parameters) throws Exception {
                                super.open(parameters);

                                List> offsets =
                                        getRuntimeContext()
                                                .getBroadcastVariableWithInitializer(
                                                        "counts",
                                                        new BroadcastVariableInitializer<
                                                                Tuple2,
                                                                List>>() {
                                                            @Override
                                                            public List>
                                                                    initializeBroadcastVariable(
                                                                            Iterable<
                                                                                            Tuple2<
                                                                                                    Integer,
                                                                                                    Long>>
                                                                                    data) {
                                                                // sort the list by task id to
                                                                // calculate the correct offset
                                                                List>
                                                                        sortedData =
                                                                                new ArrayList<>();
                                                                for (Tuple2 datum :
                                                                        data) {
                                                                    sortedData.add(datum);
                                                                }
                                                                Collections.sort(
                                                                        sortedData,
                                                                        new Comparator<
                                                                                Tuple2<
                                                                                        Integer,
                                                                                        Long>>() {
                                                                            @Override
                                                                            public int compare(
                                                                                    Tuple2<
                                                                                                    Integer,
                                                                                                    Long>
                                                                                            o1,
                                                                                    Tuple2<
                                                                                                    Integer,
                                                                                                    Long>
                                                                                            o2) {
                                                                                return o1.f0
                                                                                        .compareTo(
                                                                                                o2.f0);
                                                                            }
                                                                        });
                                                                return sortedData;
                                                            }
                                                        });

                                // compute the offset for each partition
                                for (int i = 0;
                                        i < getRuntimeContext().getIndexOfThisSubtask();
                                        i++) {
                                    start += offsets.get(i).f1;
                                }
                            }

                            @Override
                            public void mapPartition(
                                    Iterable values, Collector> out)
                                    throws Exception {
                                for (T value : values) {
                                    out.collect(new Tuple2<>(start++, value));
                                }
                            }
                        })
                .withBroadcastSet(elementCount, "counts");
    }

    /**
     * Method that assigns a unique {@link Long} value to all elements in the input data set as
     * described below.
     *
     * 
    *
  • a map function is applied to the input data set *
  • each map task holds a counter c which is increased for each record *
  • c is shifted by n bits where n = log2(number of parallel tasks) *
  • to create a unique ID among all tasks, the task id is added to the counter *
  • for each record, the resulting counter is collected *
* * @param input the input data set * @return a data set of tuple 2 consisting of ids and initial values. */ public static DataSet> zipWithUniqueId(DataSet input) { return input.mapPartition( new RichMapPartitionFunction>() { long maxBitSize = getBitSize(Long.MAX_VALUE); long shifter = 0; long start = 0; long taskId = 0; long label = 0; @Override public void open(Configuration parameters) throws Exception { super.open(parameters); shifter = getBitSize(getRuntimeContext().getNumberOfParallelSubtasks() - 1); taskId = getRuntimeContext().getIndexOfThisSubtask(); } @Override public void mapPartition(Iterable values, Collector> out) throws Exception { for (T value : values) { label = (start << shifter) + taskId; if (getBitSize(start) + shifter < maxBitSize) { out.collect(new Tuple2<>(label, value)); start++; } else { throw new Exception( "Exceeded Long value range while generating labels"); } } } }); } // -------------------------------------------------------------------------------------------- // Sample // -------------------------------------------------------------------------------------------- /** * Generate a sample of DataSet by the probability fraction of each element. * * @param withReplacement Whether element can be selected more than once. * @param fraction Probability that each element is chosen, should be [0,1] without replacement, * and [0, ∞) with replacement. While fraction is larger than 1, the elements are expected * to be selected multi times into sample on average. * @return The sampled DataSet */ public static MapPartitionOperator sample( DataSet input, final boolean withReplacement, final double fraction) { return sample(input, withReplacement, fraction, Utils.RNG.nextLong()); } /** * Generate a sample of DataSet by the probability fraction of each element. * * @param withReplacement Whether element can be selected more than once. 
* @param fraction Probability that each element is chosen, should be [0,1] without replacement, * and [0, ∞) with replacement. While fraction is larger than 1, the elements are expected * to be selected multi times into sample on average. * @param seed random number generator seed. * @return The sampled DataSet */ public static MapPartitionOperator sample( DataSet input, final boolean withReplacement, final double fraction, final long seed) { return input.mapPartition(new SampleWithFraction(withReplacement, fraction, seed)); } /** * Generate a sample of DataSet which contains fixed size elements. * *

NOTE: Sample with fixed size is not as efficient as sample with fraction, * use sample with fraction unless you need exact precision. * * @param withReplacement Whether element can be selected more than once. * @param numSamples The expected sample size. * @return The sampled DataSet */ public static DataSet sampleWithSize( DataSet input, final boolean withReplacement, final int numSamples) { return sampleWithSize(input, withReplacement, numSamples, Utils.RNG.nextLong()); } /** * Generate a sample of DataSet which contains fixed size elements. * *

NOTE: Sample with fixed size is not as efficient as sample with fraction, * use sample with fraction unless you need exact precision. * * @param withReplacement Whether element can be selected more than once. * @param numSamples The expected sample size. * @param seed Random number generator seed. * @return The sampled DataSet */ public static DataSet sampleWithSize( DataSet input, final boolean withReplacement, final int numSamples, final long seed) { SampleInPartition sampleInPartition = new SampleInPartition<>(withReplacement, numSamples, seed); MapPartitionOperator mapPartitionOperator = input.mapPartition(sampleInPartition); // There is no previous group, so the parallelism of GroupReduceOperator is always 1. String callLocation = Utils.getCallLocationName(); SampleInCoordinator sampleInCoordinator = new SampleInCoordinator<>(withReplacement, numSamples, seed); return new GroupReduceOperator<>( mapPartitionOperator, input.getType(), sampleInCoordinator, callLocation); } // -------------------------------------------------------------------------------------------- // Partition // -------------------------------------------------------------------------------------------- /** Range-partitions a DataSet on the specified tuple field positions. */ public static PartitionOperator partitionByRange( DataSet input, DataDistribution distribution, int... fields) { return new PartitionOperator<>( input, PartitionOperatorBase.PartitionMethod.RANGE, new Keys.ExpressionKeys<>(fields, input.getType(), false), distribution, Utils.getCallLocationName()); } /** Range-partitions a DataSet on the specified fields. */ public static PartitionOperator partitionByRange( DataSet input, DataDistribution distribution, String... 
fields) { return new PartitionOperator<>( input, PartitionOperatorBase.PartitionMethod.RANGE, new Keys.ExpressionKeys<>(fields, input.getType()), distribution, Utils.getCallLocationName()); } /** Range-partitions a DataSet using the specified key selector function. */ public static > PartitionOperator partitionByRange( DataSet input, DataDistribution distribution, KeySelector keyExtractor) { final TypeInformation keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, input.getType()); return new PartitionOperator<>( input, PartitionOperatorBase.PartitionMethod.RANGE, new Keys.SelectorFunctionKeys<>( input.clean(keyExtractor), input.getType(), keyType), distribution, Utils.getCallLocationName()); } // -------------------------------------------------------------------------------------------- // Summarize // -------------------------------------------------------------------------------------------- /** * Summarize a DataSet of Tuples by collecting single pass statistics for all columns. * *

Example usage: * *

{@code
     * Dataset> input = // [...]
     * Tuple3 summary = DataSetUtils.summarize(input)
     *
     * summary.f0.getStandardDeviation()
     * summary.f1.getMaxLength()
     * }
* * @return the summary as a Tuple the same width as input rows */ public static R summarize(DataSet input) throws Exception { if (!input.getType().isTupleType()) { throw new IllegalArgumentException( "summarize() is only implemented for DataSet's of Tuples"); } final TupleTypeInfoBase inType = (TupleTypeInfoBase) input.getType(); DataSet> result = input.mapPartition( new MapPartitionFunction>() { @Override public void mapPartition( Iterable values, Collector> out) throws Exception { TupleSummaryAggregator aggregator = SummaryAggregatorFactory.create(inType); for (Tuple value : values) { aggregator.aggregate(value); } out.collect(aggregator); } }) .reduce( new ReduceFunction>() { @Override public TupleSummaryAggregator reduce( TupleSummaryAggregator agg1, TupleSummaryAggregator agg2) throws Exception { agg1.combine(agg2); return agg1; } }); return result.collect().get(0).result(); } // -------------------------------------------------------------------------------------------- // Checksum // -------------------------------------------------------------------------------------------- /** * Convenience method to get the count (number of elements) of a DataSet as well as the checksum * (sum over element hashes). * * @return A ChecksumHashCode that represents the count and checksum of elements in the data * set. 
* @deprecated replaced with {@code org.apache.flink.graph.asm.dataset.ChecksumHashCode} in * Gelly */ @Deprecated public static Utils.ChecksumHashCode checksumHashCode(DataSet input) throws Exception { final String id = new AbstractID().toString(); input.output(new Utils.ChecksumHashCodeHelper(id)).name("ChecksumHashCode"); JobExecutionResult res = input.getExecutionEnvironment().execute(); return res.getAccumulatorResult(id); } // ************************************************************************* // UTIL METHODS // ************************************************************************* public static int getBitSize(long value) { if (value > Integer.MAX_VALUE) { return 64 - Integer.numberOfLeadingZeros((int) (value >> 32)); } else { return 32 - Integer.numberOfLeadingZeros((int) value); } } /** Private constructor to prevent instantiation. */ private DataSetUtils() { throw new RuntimeException(); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy