/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.checkpoint;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.OperatorStreamStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.util.CollectionUtil;
import org.apache.flink.util.Preconditions;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Current default implementation of {@link OperatorStateRepartitioner} that redistributes state
 * in a round-robin fashion.
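 *
 * <p>A minimal usage sketch (illustrative only; the {@code previousStates} variable and the new
 * parallelism of 4 are assumptions, not part of this class):
 *
 * <pre>{@code
 * List<List<OperatorStateHandle>> previousStates = ...; // one entry per old parallel subtask
 * List<List<OperatorStateHandle>> redistributed =
 *         RoundRobinOperatorStateRepartitioner.INSTANCE.repartitionState(
 *                 previousStates, previousStates.size(), 4);
 * // redistributed.get(i) holds the operator state assigned to new subtask i
 * }</pre>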
*/
@Internal
public class RoundRobinOperatorStateRepartitioner
        implements OperatorStateRepartitioner<OperatorStateHandle> {

    public static final OperatorStateRepartitioner<OperatorStateHandle> INSTANCE =
            new RoundRobinOperatorStateRepartitioner();

    private static final boolean OPTIMIZE_MEMORY_USE = false;

    @Override
    public List<List<OperatorStateHandle>> repartitionState(
            List<List<OperatorStateHandle>> previousParallelSubtaskStates,
int oldParallelism,
int newParallelism) {
Preconditions.checkNotNull(previousParallelSubtaskStates);
Preconditions.checkArgument(newParallelism > 0);
Preconditions.checkArgument(
previousParallelSubtaskStates.size() == oldParallelism,
"This method still depends on the order of the new and old operators");
// Assemble result from all merge maps
        List<List<OperatorStateHandle>> result = new ArrayList<>(newParallelism);
        List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList;
        // We only round-robin repartition UNION state if the new parallelism equals the old one.
if (newParallelism == oldParallelism) {
            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    unionStates = collectUnionStates(previousParallelSubtaskStates);
            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    partlyFinishedBroadcastStates =
                            collectPartlyFinishedBroadcastStates(previousParallelSubtaskStates);
if (unionStates.isEmpty() && partlyFinishedBroadcastStates.isEmpty()) {
return previousParallelSubtaskStates;
}
// Initialize
mergeMapList = initMergeMapList(previousParallelSubtaskStates);
repartitionUnionState(unionStates, mergeMapList);
            // TODO: Currently, if some tasks are finished, we rescale the remaining state. A
            // better solution would be to leave the non-empty subtask states untouched and only
            // fix the empty ones.
repartitionBroadcastState(partlyFinishedBroadcastStates, mergeMapList);
} else {
// Reorganize: group by (State Name -> StreamStateHandle + Offsets)
GroupByStateNameResults nameToStateByMode =
groupByStateMode(previousParallelSubtaskStates);
if (OPTIMIZE_MEMORY_USE) {
                // free for GC, at the cost that old handles are no longer available
                previousParallelSubtaskStates.clear();
}
// Do the actual repartitioning for all named states
mergeMapList = repartition(nameToStateByMode, newParallelism);
}
for (int i = 0; i < mergeMapList.size(); ++i) {
result.add(i, new ArrayList<>(mergeMapList.get(i).values()));
}
return result;
}

    /**
     * Initializes the list of StreamStateHandle -> OperatorStateHandle maps from the given
     * parallelSubtaskStates when the parallelism has not changed.
     */
    private List<Map<StreamStateHandle, OperatorStateHandle>> initMergeMapList(
            List<List<OperatorStateHandle>> parallelSubtaskStates) {
int parallelism = parallelSubtaskStates.size();
        final List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList =
                new ArrayList<>(parallelism);
        for (List<OperatorStateHandle> previousParallelSubtaskState : parallelSubtaskStates) {
mergeMapList.add(
previousParallelSubtaskState.stream()
.collect(
Collectors.toMap(
OperatorStateHandle::getDelegateStateHandle,
Function.identity())));
}
return mergeMapList;
}
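
    // The merge maps are keyed by the delegate StreamStateHandle so that all named states backed
    // by the same underlying stream handle end up merged into a single OperatorStateHandle per
    // subtask as the repartition methods below add entries.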

    private Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
            collectUnionStates(List<List<OperatorStateHandle>> parallelSubtaskStates) {
return collectStates(parallelSubtaskStates, OperatorStateHandle.Mode.UNION).entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().entries));
}

    private Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
            collectPartlyFinishedBroadcastStates(
                    List<List<OperatorStateHandle>> parallelSubtaskStates) {
return collectStates(parallelSubtaskStates, OperatorStateHandle.Mode.BROADCAST).entrySet()
.stream()
.filter(e -> e.getValue().isPartiallyReported())
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().entries));
}

    /** Collects the states from the given parallelSubtaskStates with the specific {@code mode}. */
    private Map<String, StateEntry> collectStates(
            List<List<OperatorStateHandle>> parallelSubtaskStates, OperatorStateHandle.Mode mode) {
        Map<String, StateEntry> states =
CollectionUtil.newHashMapWithExpectedSize(parallelSubtaskStates.size());
for (int i = 0; i < parallelSubtaskStates.size(); ++i) {
final int subtaskIndex = i;
            List<OperatorStateHandle> subTaskState = parallelSubtaskStates.get(i);
for (OperatorStateHandle operatorStateHandle : subTaskState) {
if (operatorStateHandle == null) {
continue;
}
                final Set<Map.Entry<String, OperatorStateHandle.StateMetaInfo>>
partitionOffsetEntries =
operatorStateHandle.getStateNameToPartitionOffsets().entrySet();
partitionOffsetEntries.stream()
.filter(entry -> entry.getValue().getDistributionMode().equals(mode))
.forEach(
entry -> {
StateEntry stateEntry =
states.computeIfAbsent(
entry.getKey(),
k ->
new StateEntry(
parallelSubtaskStates.size()
* partitionOffsetEntries
.size(),
parallelSubtaskStates.size()));
stateEntry.addEntry(
subtaskIndex,
Tuple2.of(
operatorStateHandle.getDelegateStateHandle(),
entry.getValue()));
});
}
}
return states;
}

    /** Group by the different named states. */
    @SuppressWarnings("unchecked, rawtype")
    private GroupByStateNameResults groupByStateMode(
            List<List<OperatorStateHandle>> previousParallelSubtaskStates) {
// Reorganize: group by (State Name -> StreamStateHandle + StateMetaInfo)
        EnumMap<
                        OperatorStateHandle.Mode,
                        Map<
                                String,
                                List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>>
                nameToStateByMode = new EnumMap<>(OperatorStateHandle.Mode.class);
for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) {
nameToStateByMode.put(mode, new HashMap<>());
}
        for (List<OperatorStateHandle> previousParallelSubtaskState :
                previousParallelSubtaskStates) {
for (OperatorStateHandle operatorStateHandle : previousParallelSubtaskState) {
if (operatorStateHandle == null) {
continue;
}
                final Set<Map.Entry<String, OperatorStateHandle.StateMetaInfo>>
partitionOffsetEntries =
operatorStateHandle.getStateNameToPartitionOffsets().entrySet();
                for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e :
                        partitionOffsetEntries) {
OperatorStateHandle.StateMetaInfo metaInfo = e.getValue();
                    Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                            nameToState = nameToStateByMode.get(metaInfo.getDistributionMode());
                    List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>
                            stateLocations =
nameToState.computeIfAbsent(
e.getKey(),
k ->
new ArrayList<>(
previousParallelSubtaskStates.size()
* partitionOffsetEntries
.size()));
stateLocations.add(
Tuple2.of(operatorStateHandle.getDelegateStateHandle(), e.getValue()));
}
}
}
return new GroupByStateNameResults(nameToStateByMode);
}
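
    // The repartition() pass below handles the three distribution modes of the grouped states in
    // order: SPLIT_DISTRIBUTE partitions are dealt out round-robin across the new subtasks, UNION
    // entries are replicated to every subtask, and BROADCAST entries are assigned one per subtask.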
/** Repartition all named states. */
    private List<Map<StreamStateHandle, OperatorStateHandle>> repartition(
            GroupByStateNameResults nameToStateByMode, int newParallelism) {
// We will use this to merge w.r.t. StreamStateHandles for each parallel subtask inside the
// maps
        List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList =
                new ArrayList<>(newParallelism);
// Initialize
for (int i = 0; i < newParallelism; ++i) {
mergeMapList.add(new HashMap<>());
}
// Start with the state handles we distribute round robin by splitting by offsets
        Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                nameToDistributeState =
                        nameToStateByMode.getByMode(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);
repartitionSplitState(nameToDistributeState, newParallelism, mergeMapList);
// Now we also add the state handles marked for union to all parallel instances
        Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                nameToUnionState = nameToStateByMode.getByMode(OperatorStateHandle.Mode.UNION);
repartitionUnionState(nameToUnionState, mergeMapList);
// Now we also add the state handles marked for uniform broadcast to all parallel instances
        Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                nameToBroadcastState =
                        nameToStateByMode.getByMode(OperatorStateHandle.Mode.BROADCAST);
repartitionBroadcastState(nameToBroadcastState, mergeMapList);
return mergeMapList;
}
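
    // Illustrative trace of the round-robin split below (the numbers are assumptions for the
    // example, not taken from the code): a named state with 7 partitions rescaled to
    // newParallelism = 3 yields baseFraction = 2 and remainder = 1, so the first subtask of the
    // round receives 3 partitions and the other two receive 2 each. startParallelOp then advances
    // to the first subtask that received no extra partition, so the next named state starts its
    // distribution there and the load stays balanced across state names.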
/** Repartition SPLIT_DISTRIBUTE state. */
private void repartitionSplitState(
            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    nameToDistributeState,
            int newParallelism,
            List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList) {
int startParallelOp = 0;
// Iterate all named states and repartition one named state at a time per iteration
        for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                e : nameToDistributeState.entrySet()) {
            List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current =
                    e.getValue();
// Determine actual number of partitions for this named state
int totalPartitions = 0;
            for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> offsets : current) {
totalPartitions += offsets.f1.getOffsets().length;
}
// Repartition the state across the parallel operator instances
int lstIdx = 0;
int offsetIdx = 0;
int baseFraction = totalPartitions / newParallelism;
int remainder = totalPartitions % newParallelism;
int newStartParallelOp = startParallelOp;
for (int i = 0; i < newParallelism; ++i) {
// Preparation: calculate the actual index considering wrap around
int parallelOpIdx = (i + startParallelOp) % newParallelism;
// Now calculate the number of partitions we will assign to the parallel instance in
// this round ...
int numberOfPartitionsToAssign = baseFraction;
// ... and distribute odd partitions while we still have some, one at a time
if (remainder > 0) {
++numberOfPartitionsToAssign;
--remainder;
} else if (remainder == 0) {
                    // We are out of odd partitions now and begin our next redistribution round
                    // with the current parallel operator to ensure fair load balance
newStartParallelOp = parallelOpIdx;
--remainder;
}
                // Now start collecting the partitions for the parallel instance into this list
while (numberOfPartitionsToAssign > 0) {
                    Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithOffsets =
                            current.get(lstIdx);
long[] offsets = handleWithOffsets.f1.getOffsets();
int remaining = offsets.length - offsetIdx;
// Repartition offsets
long[] offs;
if (remaining > numberOfPartitionsToAssign) {
offs =
Arrays.copyOfRange(
offsets, offsetIdx, offsetIdx + numberOfPartitionsToAssign);
offsetIdx += numberOfPartitionsToAssign;
} else {
if (OPTIMIZE_MEMORY_USE) {
handleWithOffsets.f1 = null; // GC
}
offs = Arrays.copyOfRange(offsets, offsetIdx, offsets.length);
offsetIdx = 0;
++lstIdx;
}
numberOfPartitionsToAssign -= remaining;
                    // As a last step we merge partitions that use the same StreamStateHandle into
                    // a single OperatorStateHandle
                    Map<StreamStateHandle, OperatorStateHandle> mergeMap =
                            mergeMapList.get(parallelOpIdx);
OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithOffsets.f0);
if (operatorStateHandle == null) {
operatorStateHandle =
new OperatorStreamStateHandle(
CollectionUtil.newHashMapWithExpectedSize(
nameToDistributeState.size()),
handleWithOffsets.f0);
mergeMap.put(handleWithOffsets.f0, operatorStateHandle);
}
operatorStateHandle
.getStateNameToPartitionOffsets()
.put(
e.getKey(),
new OperatorStateHandle.StateMetaInfo(
offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
}
}
startParallelOp = newStartParallelOp;
e.setValue(null);
}
}
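
    // Union redistribution is pure replication: every (StreamStateHandle, StateMetaInfo) entry of
    // a UNION state is added to the merge map of every parallel subtask, so each subtask restores
    // the complete union of the previous state.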
/** Repartition UNION state. */
private void repartitionUnionState(
            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    unionState,
            List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList) {
        for (Map<StreamStateHandle, OperatorStateHandle> mergeMap : mergeMapList) {
            for (Map.Entry<
                            String,
                            List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    e : unionState.entrySet()) {
                for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>
                        handleWithMetaInfo : e.getValue()) {
OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0);
if (operatorStateHandle == null) {
operatorStateHandle =
new OperatorStreamStateHandle(
CollectionUtil.newHashMapWithExpectedSize(
unionState.size()),
handleWithMetaInfo.f0);
mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle);
}
operatorStateHandle
.getStateNameToPartitionOffsets()
.put(e.getKey(), handleWithMetaInfo.f1);
}
}
}
}
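
    // Broadcast redistribution gives new subtask i the old entry at index i % previousParallelism
    // for each state name; since all previous subtasks held identical copies of broadcast state,
    // any old entry is a valid source for any new subtask.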
/** Repartition BROADCAST state. */
private void repartitionBroadcastState(
            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    broadcastState,
            List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList) {
int newParallelism = mergeMapList.size();
for (int i = 0; i < newParallelism; ++i) {
            final Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(i);
// for each name, pick the i-th entry
            for (Map.Entry<
                            String,
                            List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    e : broadcastState.entrySet()) {
int previousParallelism = e.getValue().size();
                Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithMetaInfo =
                        e.getValue().get(i % previousParallelism);
OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0);
if (operatorStateHandle == null) {
operatorStateHandle =
new OperatorStreamStateHandle(
CollectionUtil.newHashMapWithExpectedSize(
broadcastState.size()),
handleWithMetaInfo.f0);
mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle);
}
operatorStateHandle
.getStateNameToPartitionOffsets()
.put(e.getKey(), handleWithMetaInfo.f1);
}
}
}
private static final class GroupByStateNameResults {
        private final EnumMap<
                        OperatorStateHandle.Mode,
                        Map<
                                String,
                                List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>>
                byMode;
GroupByStateNameResults(
EnumMap<
OperatorStateHandle.Mode,
Map<
String,
List<
Tuple2<
StreamStateHandle,
OperatorStateHandle.StateMetaInfo>>>>
byMode) {
this.byMode = Preconditions.checkNotNull(byMode);
}
        public Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                getByMode(OperatorStateHandle.Mode mode) {
return byMode.get(mode);
}
}
private static final class StateEntry {
        final List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> entries;
final BitSet reportedSubtaskIndices;
public StateEntry(int estimatedEntrySize, int parallelism) {
this.entries = new ArrayList<>(estimatedEntrySize);
this.reportedSubtaskIndices = new BitSet(parallelism);
}
void addEntry(
int subtaskIndex,
                Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> entry) {
this.entries.add(entry);
reportedSubtaskIndices.set(subtaskIndex);
}
boolean isPartiallyReported() {
return reportedSubtaskIndices.cardinality() > 0
&& reportedSubtaskIndices.cardinality() < reportedSubtaskIndices.size();
}
}
}