/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.paimon.flink.shuffle;

import org.apache.paimon.annotation.VisibleForTesting;
import org.apache.paimon.data.DataGetters;
import org.apache.paimon.flink.FlinkRowWrapper;
import org.apache.paimon.types.InternalRowToSizeVisitor;
import org.apache.paimon.types.RowType;
import org.apache.paimon.utils.Pair;
import org.apache.paimon.utils.SerializableSupplier;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.ListTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.InputSelectable;
import org.apache.flink.streaming.api.operators.InputSelection;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamMap;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.streaming.api.transformations.PartitionTransformation;
import org.apache.flink.streaming.api.transformations.StreamExchangeMode;
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
import org.apache.flink.streaming.runtime.partitioner.CustomPartitionerWrapper;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.operators.TableStreamOperator;
import org.apache.flink.table.runtime.util.StreamRecordCollector;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.XORShiftRandom;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

/**
 * RangeShuffle util to shuffle the input stream by the sampled range. See the
 * `rangeShuffleByKey` method for how the topology is built.
 */
public class RangeShuffle {

    /**
     * The RelNode with range-partition distribution will create the following transformations.
     *
     * <p>Explanation of the following figure: "[LSample, n]" means the LSample operator with
     * parallelism n, "LSample" means LocalSampleOperator, "GSample" means GlobalSampleOperator,
     * "ARange" means AssignRangeIndexOperator, "RRange" means RemoveRangeIndexOperator.
     *
     * <pre>{@code
     * [IN,n]->[LSample,n]->[GSample,1]-BROADCAST
     *    \                                    \
     *     -----------------------------BATCH-[ARange,n]-PARTITION->[RRange,m]->
     * }</pre>
     *
     * <p>All streams except the sample and histogram stream are blocked, so the sample and
     * histogram stream does not care about requiredExchangeMode.
     */
    public static <T> DataStream<Tuple2<T, RowData>> rangeShuffleByKey(
            DataStream<Tuple2<T, RowData>> inputDataStream,
            SerializableSupplier<Comparator<T>> keyComparator,
            TypeInformation<T> keyTypeInformation,
            int localSampleSize,
            int globalSampleSize,
            int rangeNum,
            int outParallelism,
            RowType valueRowType,
            boolean isSortBySize) {
        Transformation<Tuple2<T, RowData>> input = inputDataStream.getTransformation();

        OneInputTransformation<Tuple2<T, RowData>, Tuple2<T, Integer>> keyInput =
                new OneInputTransformation<>(
                        input,
                        "ABSTRACT KEY AND SIZE",
                        new StreamMap<>(new KeyAndSizeExtractor<>(valueRowType, isSortBySize)),
                        new TupleTypeInfo<>(keyTypeInformation, BasicTypeInfo.INT_TYPE_INFO),
                        input.getParallelism());

        // 1. Fixed-size sample in each partition.
        OneInputTransformation<Tuple2<T, Integer>, Tuple3<Double, T, Integer>> localSample =
                new OneInputTransformation<>(
                        keyInput,
                        "LOCAL SAMPLE",
                        new LocalSampleOperator<>(localSampleSize),
                        new TupleTypeInfo<>(
                                BasicTypeInfo.DOUBLE_TYPE_INFO,
                                keyTypeInformation,
                                BasicTypeInfo.INT_TYPE_INFO),
                        keyInput.getParallelism());

        // 2. Collect all the samples and gather them into a sorted key range.
        OneInputTransformation<Tuple3<Double, T, Integer>, List<T>> sampleAndHistogram =
                new OneInputTransformation<>(
                        localSample,
                        "GLOBAL SAMPLE",
                        new GlobalSampleOperator<>(globalSampleSize, keyComparator, rangeNum),
                        new ListTypeInfo<>(keyTypeInformation),
                        1);

        // 3. Take the range boundaries as broadcast input and emit the tuple of range index and
        // record for each record of the other input.
        // The shuffle mode of the input edge must be BATCH to avoid deadlock. See
        // DeadlockBreakupProcessor.
        TwoInputTransformation<List<T>, Tuple2<T, RowData>, Tuple2<Integer, Tuple2<T, RowData>>>
                preparePartition =
                        new TwoInputTransformation<>(
                                new PartitionTransformation<>(
                                        sampleAndHistogram,
                                        new BroadcastPartitioner<>(),
                                        StreamExchangeMode.BATCH),
                                new PartitionTransformation<>(
                                        input,
                                        new ForwardPartitioner<>(),
                                        StreamExchangeMode.BATCH),
                                "ASSIGN RANGE INDEX",
                                new AssignRangeIndexOperator<>(keyComparator),
                                new TupleTypeInfo<>(
                                        BasicTypeInfo.INT_TYPE_INFO, input.getOutputType()),
                                input.getParallelism());

        // 4. Remove the range index (shuffle according to the range partition).
        return new DataStream<>(
                inputDataStream.getExecutionEnvironment(),
                new OneInputTransformation<>(
                        new PartitionTransformation<>(
                                preparePartition,
                                new CustomPartitionerWrapper<>(
                                        new AssignRangeIndexOperator.RangePartitioner(rangeNum),
                                        new AssignRangeIndexOperator.Tuple2KeySelector<>()),
                                StreamExchangeMode.BATCH),
                        "REMOVE RANGE INDEX",
                        new RemoveRangeIndexOperator<>(),
                        input.getOutputType(),
                        outParallelism));
    }
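    // A minimal caller sketch of the topology above, assuming an upstream stream of
    // (key, row) pairs and a comparator supplier already exist; `keyAndRow`,
    // `keyComparatorSupplier`, `keyTypeInfo`, `valueRowType` and `sinkParallelism` are
    // hypothetical names, not part of this class:
    //
    // DataStream<Tuple2<InternalRow, RowData>> shuffled =
    //         RangeShuffle.rangeShuffleByKey(
    //                 keyAndRow,               // DataStream<Tuple2<InternalRow, RowData>>
    //                 keyComparatorSupplier,   // SerializableSupplier<Comparator<InternalRow>>
    //                 keyTypeInfo,             // TypeInformation<InternalRow>
    //                 1000,                    // localSampleSize per subtask
    //                 10000,                   // globalSampleSize
    //                 sinkParallelism * 10,    // rangeNum
    //                 sinkParallelism,         // outParallelism
    //                 valueRowType,
    //                 false);                  // balance by record count, not byte size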
    /** KeyAndSizeExtractor is responsible for extracting the sort key and row size. */
    public static class KeyAndSizeExtractor<T>
            extends RichMapFunction<Tuple2<T, RowData>, Tuple2<T, Integer>> {

        private final RowType rowType;
        private final boolean isSortBySize;
        private transient List<BiFunction<DataGetters, Integer, Integer>> fieldSizeCalculator;

        public KeyAndSizeExtractor(RowType rowType, boolean isSortBySize) {
            this.rowType = rowType;
            this.isSortBySize = isSortBySize;
        }

        /**
         * Do not annotate with {@code @Override} here to maintain compatibility with Flink
         * 1.18-.
         */
        public void open(OpenContext openContext) throws Exception {
            open(new Configuration());
        }

        /**
         * Do not annotate with {@code @Override} here to maintain compatibility with Flink
         * 2.0+.
         */
        public void open(Configuration parameters) throws Exception {
            InternalRowToSizeVisitor internalRowToSizeVisitor = new InternalRowToSizeVisitor();
            fieldSizeCalculator =
                    rowType.getFieldTypes().stream()
                            .map(dataType -> dataType.accept(internalRowToSizeVisitor))
                            .collect(Collectors.toList());
        }

        @Override
        public Tuple2<T, Integer> map(Tuple2<T, RowData> keyAndRowData) throws Exception {
            if (isSortBySize) {
                int size = 0;
                for (int i = 0; i < fieldSizeCalculator.size(); i++) {
                    size +=
                            fieldSizeCalculator
                                    .get(i)
                                    .apply(new FlinkRowWrapper(keyAndRowData.f1), i);
                }
                return new Tuple2<>(keyAndRowData.f0, size);
            } else {
                // When balancing by quantity we don't need the size of the data, so a
                // constant weight of 1 is sufficient.
                return new Tuple2<>(keyAndRowData.f0, 1);
            }
        }
    }
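    // For illustration (made-up numbers): with isSortBySize = false every record maps to
    // (key, 1), so the sampled histogram balances ranges by row count; with
    // isSortBySize = true a record whose fields occupy roughly 40 bytes maps to (key, 40),
    // so ranges are balanced by approximate data volume instead.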

    /**
     * LocalSampleOperator wraps the sample logic on the partition side (the first phase of the
     * distributed sample algorithm). Outputs sampled weight with record.
     *
     * <p>See {@link Sampler}.
     */
    @Internal
    public static class LocalSampleOperator<T>
            extends TableStreamOperator<Tuple3<Double, T, Integer>>
            implements OneInputStreamOperator<Tuple2<T, Integer>, Tuple3<Double, T, Integer>>,
                    BoundedOneInput {

        private static final long serialVersionUID = 1L;

        private final int numSample;

        private transient Collector<Tuple3<Double, T, Integer>> collector;
        private transient Sampler<Tuple2<T, Integer>> sampler;

        public LocalSampleOperator(int numSample) {
            this.numSample = numSample;
        }

        @Override
        public void open() throws Exception {
            super.open();
            this.collector = new StreamRecordCollector<>(output);
            sampler = new Sampler<>(numSample, System.nanoTime());
        }

        @Override
        public void processElement(StreamRecord<Tuple2<T, Integer>> streamRecord)
                throws Exception {
            sampler.collect(streamRecord.getValue());
        }

        @Override
        public void endInput() {
            Iterator<Tuple2<Double, Tuple2<T, Integer>>> sampled = sampler.sample();
            while (sampled.hasNext()) {
                Tuple2<Double, Tuple2<T, Integer>> next = sampled.next();
                collector.collect(new Tuple3<>(next.f0, next.f1.f0, next.f1.f1));
            }
        }
    }
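    // Sketch of the phase-1 contract (the weights below are made up): with numSample = 2
    // and inputs (k1,10), (k2,30), (k3,20), the sampler might draw random weights
    // 0.91, 0.17 and 0.56, keep the two largest, and endInput() would then emit
    // (0.91, k1, 10) and (0.56, k3, 20). Forwarding the weight lets the global phase
    // re-sample fairly across subtasks that saw different input sizes.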

    /**
     * Global sample for range partition. Inputs weight with record. Outputs a list of sampled
     * records.
     *
     * <p>See {@link Sampler}.
     */
    private static class GlobalSampleOperator<T> extends TableStreamOperator<List<T>>
            implements OneInputStreamOperator<Tuple3<Double, T, Integer>, List<T>>,
                    BoundedOneInput {

        private static final long serialVersionUID = 1L;

        private final int numSample;
        private final int rangesNum;
        private final SerializableSupplier<Comparator<T>> comparatorSupplier;

        private transient Comparator<T> keyComparator;
        private transient Collector<List<T>> collector;
        private transient Sampler<Tuple2<T, Integer>> sampler;

        public GlobalSampleOperator(
                int numSample,
                SerializableSupplier<Comparator<T>> comparatorSupplier,
                int rangesNum) {
            this.numSample = numSample;
            this.comparatorSupplier = comparatorSupplier;
            this.rangesNum = rangesNum;
        }

        @Override
        public void open() throws Exception {
            super.open();
            this.keyComparator = comparatorSupplier.get();
            this.sampler = new Sampler<>(numSample, 0L);
            this.collector = new StreamRecordCollector<>(output);
        }

        @Override
        public void processElement(StreamRecord<Tuple3<Double, T, Integer>> record)
                throws Exception {
            Tuple3<Double, T, Integer> tuple = record.getValue();
            sampler.collect(tuple.f0, new Tuple2<>(tuple.f1, tuple.f2));
        }

        @Override
        public void endInput() {
            Iterator<Tuple2<Double, Tuple2<T, Integer>>> sampled = sampler.sample();
            List<Tuple2<T, Integer>> sampledData = new ArrayList<>();
            while (sampled.hasNext()) {
                sampledData.add(sampled.next().f1);
            }
            sampledData.sort((o1, o2) -> keyComparator.compare(o1.f0, o2.f0));

            List<T> range;
            if (sampledData.isEmpty()) {
                range = new ArrayList<>();
            } else {
                range = Arrays.asList(allocateRangeBaseSize(sampledData, rangesNum));
            }
            collector.collect(range);
        }
    }
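    // Worked example of the boundary computation (see allocateRangeBaseSize at the end of
    // this class), with made-up samples: sorted samples (a,10), (b,10), (c,10), (d,10) and
    // rangesNum = 2 give a total weight of 40 and stepRange = 20, so the single boundary is
    // "b", the key at which the accumulated weight first reaches 20. The emitted boundary
    // list is therefore ["b"]: keys <= "b" fall into range 0, the rest into range 1,
    // splitting the sampled weight roughly in half.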
    /**
     * This two-input operator requires the range boundaries as broadcast input, and for each
     * record of the other input it emits a Tuple2 of range index and record.
     */
    private static class AssignRangeIndexOperator<T>
            extends TableStreamOperator<Tuple2<Integer, Tuple2<T, RowData>>>
            implements TwoInputStreamOperator<
                            List<T>, Tuple2<T, RowData>, Tuple2<Integer, Tuple2<T, RowData>>>,
                    InputSelectable {

        private static final long serialVersionUID = 1L;

        private final SerializableSupplier<Comparator<T>> keyComparatorSupplier;

        private transient List<Pair<T, RandomList>> keyIndex;
        private transient Collector<Tuple2<Integer, Tuple2<T, RowData>>> collector;
        private transient Comparator<T> keyComparator;

        public AssignRangeIndexOperator(
                SerializableSupplier<Comparator<T>> keyComparatorSupplier) {
            this.keyComparatorSupplier = keyComparatorSupplier;
        }

        @Override
        public void open() throws Exception {
            super.open();
            this.keyComparator = keyComparatorSupplier.get();
            this.collector = new StreamRecordCollector<>(output);
        }

        @Override
        public void processElement1(StreamRecord<List<T>> streamRecord) {
            keyIndex = new ArrayList<>();
            T last = null;
            int index = 0;
            for (T t : streamRecord.getValue()) {
                if (last != null && keyComparator.compare(last, t) == 0) {
                    keyIndex.get(keyIndex.size() - 1).getRight().add(index++);
                } else {
                    Pair<T, RandomList> pair = Pair.of(t, new RandomList());
                    pair.getRight().add(index++);
                    keyIndex.add(pair);
                    last = t;
                }
            }
        }

        @Override
        public void processElement2(StreamRecord<Tuple2<T, RowData>> streamRecord) {
            if (keyIndex == null) {
                throw new RuntimeException("There should be one data from the first input.");
            }
            // If the range number is 1, the range index will be 0 for all records.
            if (keyIndex.isEmpty()) {
                collector.collect(new Tuple2<>(0, streamRecord.getValue()));
            } else {
                Tuple2<T, RowData> row = streamRecord.getValue();
                collector.collect(new Tuple2<>(binarySearch(row.f0), row));
            }
        }

        @Override
        public InputSelection nextSelection() {
            return keyIndex == null ? InputSelection.FIRST : InputSelection.ALL;
        }

        private int binarySearch(T key) {
            int lastIndex = this.keyIndex.size() - 1;
            int low = 0;
            int high = lastIndex;

            while (low <= high) {
                final int mid = (low + high) >>> 1;
                final Pair<T, RandomList> indexPair = keyIndex.get(mid);
                final int result = keyComparator.compare(key, indexPair.getLeft());

                if (result > 0) {
                    low = mid + 1;
                } else if (result < 0) {
                    high = mid - 1;
                } else {
                    return indexPair.getRight().get();
                }
            }

            // Key not found, but the low index is the target
            // bucket, since the boundaries are the upper bound.
            return low > lastIndex
                    ? (keyIndex.get(lastIndex).getRight().get() + 1)
                    : keyIndex.get(low).getRight().get();
        }

        /** A {@link KeySelector} to select by f0 of tuple2. */
        public static class Tuple2KeySelector<T>
                implements KeySelector<Tuple2<Integer, Tuple2<T, RowData>>, Integer> {

            private static final long serialVersionUID = 1L;

            @Override
            public Integer getKey(Tuple2<Integer, Tuple2<T, RowData>> tuple2) throws Exception {
                return tuple2.f0;
            }
        }

        /** A {@link Partitioner} to partition by range id. */
        public static class RangePartitioner implements Partitioner<Integer> {

            private static final long serialVersionUID = 1L;

            private final int totalRangeNum;

            public RangePartitioner(int totalRangeNum) {
                this.totalRangeNum = totalRangeNum;
            }

            @Override
            public int partition(Integer key, int numPartitions) {
                Preconditions.checkArgument(
                        numPartitions <= totalRangeNum,
                        "Num of subPartitions should be <= totalRangeNum: " + totalRangeNum);
                int partition = key / (totalRangeNum / numPartitions);
                return Math.min(numPartitions - 1, partition);
            }
        }
    }
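    // Worked example for binarySearch and RangePartitioner above (made-up numbers): with
    // boundaries [10, 20, 30], binarySearch assigns key 15 -> range 1 and key 35 -> range 3
    // (one past the last boundary). If rangeNum = 4 and the downstream operator runs with
    // numPartitions = 2, RangePartitioner computes key / (totalRangeNum / numPartitions),
    // capped at numPartitions - 1, so range ids 0-1 go to subtask 0 and ids 2-3 to subtask 1.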

    /** Remove the range index and return the actual record. */
    private static class RemoveRangeIndexOperator<T>
            extends TableStreamOperator<Tuple2<T, RowData>>
            implements OneInputStreamOperator<
                    Tuple2<Integer, Tuple2<T, RowData>>, Tuple2<T, RowData>> {

        private static final long serialVersionUID = 1L;

        private transient Collector<Tuple2<T, RowData>> collector;

        @Override
        public void open() throws Exception {
            super.open();
            this.collector = new StreamRecordCollector<>(output);
        }

        @Override
        public void processElement(
                StreamRecord<Tuple2<Integer, Tuple2<T, RowData>>> streamRecord)
                throws Exception {
            collector.collect(streamRecord.getValue().f1);
        }
    }

    /**
     * A simple in-memory sampling implementation that needs only one pass over an input whose
     * size is unpredictable. The basic idea is to generate a random number for each input
     * element as its weight and to select the top K elements with the largest weights. Since
     * the weights are generated randomly, so are the selected top K elements. In the first
     * phase, random weights are generated and the top K elements of each partition are
     * selected; in the second phase, the top K elements across all first-phase outputs are
     * selected.
     *
     * <p>This implementation refers to the algorithm described in "Optimal Random Sampling
     * from Distributed Streams Revisited".
     */
    private static class Sampler<T> {

        private final int numSamples;
        private final Random random;
        private final PriorityQueue<Tuple2<Double, T>> queue;

        private int index = 0;
        private Tuple2<Double, T> smallest = null;

        /**
         * Create a new sampler with reservoir size and a seeded random number generator.
         *
         * @param numSamples maximum number of samples to retain in the reservoir, must be
         *     non-negative
         */
        Sampler(int numSamples, long seed) {
            Preconditions.checkArgument(numSamples >= 0, "numSamples should be non-negative.");
            this.numSamples = numSamples;
            this.random = new XORShiftRandom(seed);
            this.queue =
                    new PriorityQueue<>(numSamples, Comparator.comparingDouble(o -> o.f0));
        }

        void collect(T rowData) {
            collect(random.nextDouble(), rowData);
        }

        void collect(double weight, T key) {
            if (index < numSamples) {
                // Fill the queue with the first K elements from the input.
                addQueue(weight, key);
            } else {
                // Remove the element with the smallest weight,
                // and append the current element to the queue.
                if (weight > smallest.f0) {
                    queue.remove();
                    addQueue(weight, key);
                }
            }
            index++;
        }

        private void addQueue(double weight, T row) {
            queue.add(new Tuple2<>(weight, row));
            smallest = queue.peek();
        }

        Iterator<Tuple2<Double, T>> sample() {
            return queue.iterator();
        }
    }

    /** Contains integers and returns a random one on get(). */
    private static class RandomList {

        private static final Random RANDOM = new Random();

        private final List<Integer> list = new ArrayList<>();

        public void add(int i) {
            list.add(i);
        }

        public int get() {
            return list.get(RANDOM.nextInt(list.size()));
        }
    }

    @VisibleForTesting
    static <T> T[] allocateRangeBaseSize(List<Tuple2<T, Integer>> sampledData, int rangesNum) {
        int sampleNum = sampledData.size();
        int boundarySize = rangesNum - 1;
        @SuppressWarnings("unchecked")
        T[] boundaries = (T[]) new Object[boundarySize];

        if (!sampledData.isEmpty()) {
            long restSize = sampledData.stream().mapToLong(t -> (long) t.f1).sum();
            double stepRange = restSize / (double) rangesNum;

            int currentWeight = 0;
            int index = 0;

            for (int i = 0; i < boundarySize; i++) {
                while (currentWeight < stepRange && index < sampleNum) {
                    boundaries[i] = sampledData.get(Math.min(index, sampleNum - 1)).f0;
                    int sampleWeight = sampledData.get(index++).f1;
                    currentWeight += sampleWeight;
                    restSize -= sampleWeight;
                }
                currentWeight = 0;
                stepRange = restSize / (double) (rangesNum - i - 1);
            }
        }

        for (int i = 0; i < boundarySize; i++) {
            if (boundaries[i] == null) {
                boundaries[i] = sampledData.get(sampleNum - 1).f0;
            }
        }
        return boundaries;
    }
}
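// Illustrative edge case: if the samples run out before all boundaries are filled (say two
// distinct samples but rangesNum = 4), allocateRangeBaseSize backfills the remaining null
// boundaries with the largest sampled key. AssignRangeIndexOperator#processElement1 then
// collapses the duplicated boundaries into a single entry whose RandomList holds several
// range ids, and RandomList#get() spreads the matching keys randomly across those ids.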




