// org.apache.flink.runtime.checkpoint.TaskStateAssignment Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.checkpoint;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StateObject;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import static java.util.Collections.emptyMap;
import static org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.NO_MAPPINGS;
import static org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.NO_SUBTASKS;
import static org.apache.flink.util.Preconditions.checkState;
/**
* Used by {@link StateAssignmentOperation} to store temporary information while creating {@link
* OperatorSubtaskState}.
*/
class TaskStateAssignment {
final ExecutionJobVertex executionJobVertex;
final Map oldState;
final boolean hasState;
final int newParallelism;
final OperatorID inputOperatorID;
final OperatorID outputOperatorID;
final Map> subManagedOperatorState;
final Map> subRawOperatorState;
final Map> subManagedKeyedState;
final Map> subRawKeyedState;
final Map> inputChannelStates;
final Map> resultSubpartitionStates;
/** The subtask mapping when the output operator was rescaled. */
Map> outputSubtaskMappings = emptyMap();
/** The subtask mapping when the input operator was rescaled. */
Map> inputSubtaskMappings = emptyMap();
/**
* The subpartitions mappings of the upstream task per input set when its output operator was
* rescaled.
*/
final Map upstreamAssignments;
/**
* The input channel mappings of the downstream task per partition set when its input operator
* was rescaled.
*/
final Map downstreamAssignments;
public TaskStateAssignment(
ExecutionJobVertex executionJobVertex, Map oldState) {
this.executionJobVertex = executionJobVertex;
this.oldState = oldState;
this.hasState =
oldState.values().stream()
.anyMatch(operatorState -> operatorState.getNumberCollectedStates() > 0);
newParallelism = executionJobVertex.getParallelism();
final int expectedNumberOfSubtasks = newParallelism * oldState.size();
subManagedOperatorState = new HashMap<>(expectedNumberOfSubtasks);
subRawOperatorState = new HashMap<>(expectedNumberOfSubtasks);
inputChannelStates = new HashMap<>(expectedNumberOfSubtasks);
resultSubpartitionStates = new HashMap<>(expectedNumberOfSubtasks);
subManagedKeyedState = new HashMap<>(expectedNumberOfSubtasks);
subRawKeyedState = new HashMap<>(expectedNumberOfSubtasks);
final List operatorIDs = executionJobVertex.getOperatorIDs();
outputOperatorID = operatorIDs.get(0).getGeneratedOperatorID();
inputOperatorID = operatorIDs.get(operatorIDs.size() - 1).getGeneratedOperatorID();
upstreamAssignments = new HashMap<>(executionJobVertex.getInputs().size());
downstreamAssignments = new HashMap<>(executionJobVertex.getProducedDataSets().length);
}
public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
checkState(
subManagedKeyedState.containsKey(instanceID)
|| !subRawKeyedState.containsKey(instanceID),
"If an operator has no managed key state, it should also not have a raw keyed state.");
return OperatorSubtaskState.builder()
.setManagedOperatorState(getState(instanceID, subManagedOperatorState))
.setRawOperatorState(getState(instanceID, subRawOperatorState))
.setManagedKeyedState(getState(instanceID, subManagedKeyedState))
.setRawKeyedState(getState(instanceID, subRawKeyedState))
.setInputChannelState(getState(instanceID, inputChannelStates))
.setResultSubpartitionState(getState(instanceID, resultSubpartitionStates))
.setInputRescalingDescriptor(
inputOperatorID.equals(instanceID.getOperatorId())
? createRescalingDescriptor(
instanceID,
upstreamAssignments,
assignment -> assignment.outputSubtaskMappings,
inputSubtaskMappings)
: InflightDataRescalingDescriptor.NO_RESCALE)
.setOutputRescalingDescriptor(
outputOperatorID.equals(instanceID.getOperatorId())
? createRescalingDescriptor(
instanceID,
downstreamAssignments,
assignment -> assignment.inputSubtaskMappings,
outputSubtaskMappings)
: InflightDataRescalingDescriptor.NO_RESCALE)
.build();
}
private InflightDataRescalingDescriptor createRescalingDescriptor(
OperatorInstanceID instanceID,
Map assignments,
Function>> mappingRetriever,
Map> subtaskMappings) {
if (assignments.isEmpty() && subtaskMappings.isEmpty()) {
return InflightDataRescalingDescriptor.NO_RESCALE;
}
final Set oldTaskInstances =
subtaskMappings.isEmpty()
? NO_SUBTASKS
: subtaskMappings.get(instanceID.getSubtaskId());
final Map rescaledChannelsMappings =
assignments.isEmpty()
? NO_MAPPINGS
: assignments.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey,
assignment ->
new RescaledChannelsMapping(
mappingRetriever.apply(
assignment.getValue()))));
return new InflightDataRescalingDescriptor(oldTaskInstances, rescaledChannelsMappings);
}
private StateObjectCollection getState(
OperatorInstanceID instanceID,
Map> subManagedOperatorState) {
List value = subManagedOperatorState.get(instanceID);
return value != null ? new StateObjectCollection<>(value) : StateObjectCollection.empty();
}
}