All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.runtime.checkpoint.StateAssignmentOperation Maven / Gradle / Ivy

There is a newer version: 1.13.6
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.checkpoint;

import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.util.Preconditions;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

import static java.util.Collections.emptyList;
import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;

/**
 * This class encapsulates the operation of assigning restored state when restoring from a checkpoint.
 */
@Internal
public class StateAssignmentOperation {

	private static final Logger LOG = LoggerFactory.getLogger(StateAssignmentOperation.class);

	private final Set tasks;
	private final Map operatorStates;

	private final long restoreCheckpointId;
	private final boolean allowNonRestoredState;

	public StateAssignmentOperation(
		long restoreCheckpointId,
		Set tasks,
		Map operatorStates,
		boolean allowNonRestoredState) {

		this.restoreCheckpointId = restoreCheckpointId;
		this.tasks = Preconditions.checkNotNull(tasks);
		this.operatorStates = Preconditions.checkNotNull(operatorStates);
		this.allowNonRestoredState = allowNonRestoredState;
	}

	public void assignStates() {
		Map localOperators = new HashMap<>(operatorStates);

		checkStateMappingCompleteness(allowNonRestoredState, operatorStates, tasks);

		for (ExecutionJobVertex executionJobVertex : this.tasks) {

			// find the states of all operators belonging to this task
			List operatorIDPairs = executionJobVertex.getOperatorIDs();
			List operatorStates = new ArrayList<>(operatorIDPairs.size());
			boolean statelessSubTasks = true;
			for (OperatorIDPair operatorIDPair : operatorIDPairs) {
				OperatorID operatorID = operatorIDPair.getUserDefinedOperatorID().orElse(operatorIDPair.getGeneratedOperatorID());

				OperatorState operatorState = localOperators.remove(operatorID);
				if (operatorState == null) {
					operatorState = new OperatorState(
						operatorID,
						executionJobVertex.getParallelism(),
						executionJobVertex.getMaxParallelism());
				} else if (operatorState.getNumberCollectedStates() > 0) {
					statelessSubTasks = false;
				}
				operatorStates.add(operatorState);
			}
			if (!statelessSubTasks) { // skip tasks where no operator has any state
				assignAttemptState(executionJobVertex, operatorStates);
			}
		}

	}

	private void assignAttemptState(ExecutionJobVertex executionJobVertex, List operatorStates) {

		List operatorIDs = executionJobVertex.getOperatorIDs();

		//1. first compute the new parallelism
		checkParallelismPreconditions(operatorStates, executionJobVertex);

		int newParallelism = executionJobVertex.getParallelism();

		List keyGroupPartitions = createKeyGroupPartitions(
			executionJobVertex.getMaxParallelism(),
			newParallelism);

		final int expectedNumberOfSubTasks = newParallelism * operatorIDs.size();

		/*
		 * Redistribute ManagedOperatorStates and RawOperatorStates from old parallelism to new parallelism.
		 *
		 * The old ManagedOperatorStates with old parallelism 3:
		 *
		 * 		parallelism0 parallelism1 parallelism2
		 * op0   states0,0    state0,1	   state0,2
		 * op1
		 * op2   states2,0    state2,1	   state1,2
		 * op3   states3,0    state3,1     state3,2
		 *
		 * The new ManagedOperatorStates with new parallelism 4:
		 *
		 * 		parallelism0 parallelism1 parallelism2 parallelism3
		 * op0   state0,0	  state0,1 	   state0,2		state0,3
		 * op1
		 * op2   state2,0	  state2,1 	   state2,2		state2,3
		 * op3   state3,0	  state3,1 	   state3,2		state3,3
		 */
		Map> newManagedOperatorStates = reDistributePartitionableStates(
			operatorStates,
			newParallelism,
			operatorIDs,
			OperatorSubtaskState::getManagedOperatorState,
			RoundRobinOperatorStateRepartitioner.INSTANCE);
		Map> newRawOperatorStates = reDistributePartitionableStates(
			operatorStates,
			newParallelism,
			operatorIDs,
			OperatorSubtaskState::getRawOperatorState,
			RoundRobinOperatorStateRepartitioner.INSTANCE);
		final Map> newInputChannelState = reDistributePartitionableStates(
			operatorStates,
			newParallelism,
			operatorIDs,
			OperatorSubtaskState::getInputChannelState,
			channelStateNonRescalingRepartitioner("input channel"));
		final Map> newResultSubpartitionState = reDistributePartitionableStates(
			operatorStates,
			newParallelism,
			operatorIDs,
			OperatorSubtaskState::getResultSubpartitionState,
			channelStateNonRescalingRepartitioner("result subpartition"));

		Map> newManagedKeyedState = new HashMap<>(expectedNumberOfSubTasks);
		Map> newRawKeyedState = new HashMap<>(expectedNumberOfSubTasks);

		reDistributeKeyedStates(
			operatorStates,
			newParallelism,
			operatorIDs,
			keyGroupPartitions,
			newManagedKeyedState,
			newRawKeyedState);

		/*
		 *  An executionJobVertex's all state handles needed to restore are something like a matrix
		 *
		 * 		parallelism0 parallelism1 parallelism2 parallelism3
		 * op0   sh(0,0)     sh(0,1)       sh(0,2)	    sh(0,3)
		 * op1   sh(1,0)	 sh(1,1)	   sh(1,2)	    sh(1,3)
		 * op2   sh(2,0)	 sh(2,1)	   sh(2,2)		sh(2,3)
		 * op3   sh(3,0)	 sh(3,1)	   sh(3,2)		sh(3,3)
		 *
		 */
		assignTaskStateToExecutionJobVertices(
			executionJobVertex,
			newManagedOperatorStates,
			newRawOperatorStates,
			newInputChannelState,
			newResultSubpartitionState,
			newManagedKeyedState,
			newRawKeyedState,
			newParallelism);
	}

	private void assignTaskStateToExecutionJobVertices(
			ExecutionJobVertex executionJobVertex,
			Map> subManagedOperatorState,
			Map> subRawOperatorState,
			Map> inputChannelStates,
			Map> resultSubpartitionStates,
			Map> subManagedKeyedState,
			Map> subRawKeyedState,
			int newParallelism) {

		List operatorIDs = executionJobVertex.getOperatorIDs();

		for (int subTaskIndex = 0; subTaskIndex < newParallelism; subTaskIndex++) {

			Execution currentExecutionAttempt = executionJobVertex.getTaskVertices()[subTaskIndex]
				.getCurrentExecutionAttempt();

			TaskStateSnapshot taskState = new TaskStateSnapshot(operatorIDs.size());
			boolean statelessTask = true;

			for (OperatorIDPair operatorID : operatorIDs) {
				OperatorInstanceID instanceID = OperatorInstanceID.of(subTaskIndex, operatorID.getGeneratedOperatorID());

				OperatorSubtaskState operatorSubtaskState = operatorSubtaskStateFrom(
					instanceID,
					subManagedOperatorState,
					subRawOperatorState,
					inputChannelStates,
					resultSubpartitionStates,
					subManagedKeyedState,
					subRawKeyedState);

				if (operatorSubtaskState.hasState()) {
					statelessTask = false;
				}
				taskState.putSubtaskStateByOperatorID(operatorID.getGeneratedOperatorID(), operatorSubtaskState);
			}

			if (!statelessTask) {
				JobManagerTaskRestore taskRestore = new JobManagerTaskRestore(restoreCheckpointId, taskState);
				currentExecutionAttempt.setInitialState(taskRestore);
			}
		}
	}

	public static OperatorSubtaskState operatorSubtaskStateFrom(
			OperatorInstanceID instanceID,
			Map> subManagedOperatorState,
			Map> subRawOperatorState,
			Map> inputChannelStates,
			Map> resultSubpartitionStates,
			Map> subManagedKeyedState,
			Map> subRawKeyedState) {

		if (!subManagedOperatorState.containsKey(instanceID) &&
			!subRawOperatorState.containsKey(instanceID) &&
			!inputChannelStates.containsKey(instanceID) &&
			!resultSubpartitionStates.containsKey(instanceID) &&
			!subManagedKeyedState.containsKey(instanceID) &&
			!subRawKeyedState.containsKey(instanceID)) {

			return new OperatorSubtaskState();
		}
		if (!subManagedKeyedState.containsKey(instanceID)) {
			checkState(!subRawKeyedState.containsKey(instanceID));
		}
		return new OperatorSubtaskState(
			new StateObjectCollection<>(subManagedOperatorState.getOrDefault(instanceID, emptyList())),
			new StateObjectCollection<>(subRawOperatorState.getOrDefault(instanceID, emptyList())),
			new StateObjectCollection<>(subManagedKeyedState.getOrDefault(instanceID, emptyList())),
			new StateObjectCollection<>(subRawKeyedState.getOrDefault(instanceID, emptyList())),
			new StateObjectCollection<>(inputChannelStates.getOrDefault(instanceID, emptyList())),
			new StateObjectCollection<>(resultSubpartitionStates.getOrDefault(instanceID, emptyList())));
	}

	public void checkParallelismPreconditions(List operatorStates, ExecutionJobVertex executionJobVertex) {
		for (OperatorState operatorState : operatorStates) {
			checkParallelismPreconditions(operatorState, executionJobVertex);
		}
	}

	private void reDistributeKeyedStates(
			List oldOperatorStates,
			int newParallelism,
			List newOperatorIDs,
			List newKeyGroupPartitions,
			Map> newManagedKeyedState,
			Map> newRawKeyedState) {
		//TODO: rewrite this method to only use OperatorID
		checkState(newOperatorIDs.size() == oldOperatorStates.size(),
			"This method still depends on the order of the new and old operators");

		for (int operatorIndex = 0; operatorIndex < newOperatorIDs.size(); operatorIndex++) {
			OperatorState operatorState = oldOperatorStates.get(operatorIndex);
			int oldParallelism = operatorState.getParallelism();
			for (int subTaskIndex = 0; subTaskIndex < newParallelism; subTaskIndex++) {
				OperatorInstanceID instanceID = OperatorInstanceID.of(subTaskIndex, newOperatorIDs.get(operatorIndex).getGeneratedOperatorID());
				Tuple2, List> subKeyedStates = reAssignSubKeyedStates(
					operatorState,
					newKeyGroupPartitions,
					subTaskIndex,
					newParallelism,
					oldParallelism);
				newManagedKeyedState.put(instanceID, subKeyedStates.f0);
				newRawKeyedState.put(instanceID, subKeyedStates.f1);
			}
		}
	}

	// TODO rewrite based on operator id
	private Tuple2, List> reAssignSubKeyedStates(
			OperatorState operatorState,
			List keyGroupPartitions,
			int subTaskIndex,
			int newParallelism,
			int oldParallelism) {

		List subManagedKeyedState;
		List subRawKeyedState;

		if (newParallelism == oldParallelism) {
			if (operatorState.getState(subTaskIndex) != null) {
				subManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState().asList();
				subRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState().asList();
			} else {
				subManagedKeyedState = emptyList();
				subRawKeyedState = emptyList();
			}
		} else {
			subManagedKeyedState = getManagedKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex));
			subRawKeyedState = getRawKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex));
		}

		if (subManagedKeyedState.isEmpty() && subRawKeyedState.isEmpty()) {
			return new Tuple2<>(emptyList(), emptyList());
		} else {
			return new Tuple2<>(subManagedKeyedState, subRawKeyedState);
		}
	}

	public static  Map>  reDistributePartitionableStates(
			List oldOperatorStates,
			int newParallelism,
			List newOperatorIDs,
			Function> extractHandle,
			OperatorStateRepartitioner stateRepartitioner) {

		//TODO: rewrite this method to only use OperatorID
		checkState(newOperatorIDs.size() == oldOperatorStates.size(),
			"This method still depends on the order of the new and old operators");

		// The nested list wraps as the level of operator -> subtask -> state object collection
		List>> oldStates = splitManagedAndRawOperatorStates(oldOperatorStates, extractHandle);

		Map> result = new HashMap<>();
		for (int operatorIndex = 0; operatorIndex < newOperatorIDs.size(); operatorIndex++) {
			result.putAll(applyRepartitioner(
				newOperatorIDs.get(operatorIndex).getGeneratedOperatorID(),
				stateRepartitioner,
				oldStates.get(operatorIndex),
				oldOperatorStates.get(operatorIndex).getParallelism(),
				newParallelism));
		}

		return result;
	}

	private static  List>> splitManagedAndRawOperatorStates(
			List operatorStates,
			Function> extracthandle) {
		List>> result = new ArrayList<>();

		for (OperatorState operatorState : operatorStates) {
			List> statePerSubtask = new ArrayList<>(operatorState.getParallelism());

			for (int subTaskIndex = 0; subTaskIndex < operatorState.getParallelism(); subTaskIndex++) {
				OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
				statePerSubtask.add(subtaskState == null ? emptyList() : extracthandle.apply(subtaskState).asList());
			}
			result.add(statePerSubtask);
		}
		return result;
	}

	/**
	 * Collect {@link KeyGroupsStateHandle  managedKeyedStateHandles} which have intersection with given
	 * {@link KeyGroupRange} from {@link TaskState operatorState}.
	 *
	 * @param operatorState        all state handles of a operator
	 * @param subtaskKeyGroupRange the KeyGroupRange of a subtask
	 * @return all managedKeyedStateHandles which have intersection with given KeyGroupRange
	 */
	public static List getManagedKeyedStateHandles(
			OperatorState operatorState,
			KeyGroupRange subtaskKeyGroupRange) {

		final int parallelism = operatorState.getParallelism();

		List subtaskKeyedStateHandles = null;

		for (int i = 0; i < parallelism; i++) {
			if (operatorState.getState(i) != null) {

				Collection keyedStateHandles = operatorState.getState(i).getManagedKeyedState();

				if (subtaskKeyedStateHandles == null) {
					subtaskKeyedStateHandles = new ArrayList<>(parallelism * keyedStateHandles.size());
				}

				extractIntersectingState(
					keyedStateHandles,
					subtaskKeyGroupRange,
					subtaskKeyedStateHandles);
			}
		}

		return subtaskKeyedStateHandles;
	}

	/**
	 * Collect {@link KeyGroupsStateHandle  rawKeyedStateHandles} which have intersection with given
	 * {@link KeyGroupRange} from {@link TaskState operatorState}.
	 *
	 * @param operatorState        all state handles of a operator
	 * @param subtaskKeyGroupRange the KeyGroupRange of a subtask
	 * @return all rawKeyedStateHandles which have intersection with given KeyGroupRange
	 */
	public static List getRawKeyedStateHandles(
			OperatorState operatorState,
			KeyGroupRange subtaskKeyGroupRange) {

		final int parallelism = operatorState.getParallelism();

		List extractedKeyedStateHandles = null;

		for (int i = 0; i < parallelism; i++) {
			if (operatorState.getState(i) != null) {

				Collection rawKeyedState = operatorState.getState(i).getRawKeyedState();

				if (extractedKeyedStateHandles == null) {
					extractedKeyedStateHandles = new ArrayList<>(parallelism * rawKeyedState.size());
				}

				extractIntersectingState(
					rawKeyedState,
					subtaskKeyGroupRange,
					extractedKeyedStateHandles);
			}
		}

		return extractedKeyedStateHandles;
	}

	/**
	 * Extracts certain key group ranges from the given state handles and adds them to the collector.
	 */
	@VisibleForTesting
	public static void extractIntersectingState(
			Collection originalSubtaskStateHandles,
			KeyGroupRange rangeToExtract,
			List extractedStateCollector) {

		for (KeyedStateHandle keyedStateHandle : originalSubtaskStateHandles) {

			if (keyedStateHandle != null) {

				KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(rangeToExtract);

				if (intersectedKeyedStateHandle != null) {
					extractedStateCollector.add(intersectedKeyedStateHandle);
				}
			}
		}
	}

	/**
	 * Groups the available set of key groups into key group partitions. A key group partition is
	 * the set of key groups which is assigned to the same task. Each set of the returned list
	 * constitutes a key group partition.
	 *
	 * 

IMPORTANT: The assignment of key groups to partitions has to be in sync with the * KeyGroupStreamPartitioner. * * @param numberKeyGroups Number of available key groups (indexed from 0 to numberKeyGroups - 1) * @param parallelism Parallelism to generate the key group partitioning for * @return List of key group partitions */ public static List createKeyGroupPartitions(int numberKeyGroups, int parallelism) { Preconditions.checkArgument(numberKeyGroups >= parallelism); List result = new ArrayList<>(parallelism); for (int i = 0; i < parallelism; ++i) { result.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(numberKeyGroups, parallelism, i)); } return result; } /** * Verifies conditions in regards to parallelism and maxParallelism that must be met when restoring state. * * @param operatorState state to restore * @param executionJobVertex task for which the state should be restored */ private static void checkParallelismPreconditions(OperatorState operatorState, ExecutionJobVertex executionJobVertex) { //----------------------------------------max parallelism preconditions------------------------------------- if (operatorState.getMaxParallelism() < executionJobVertex.getParallelism()) { throw new IllegalStateException("The state for task " + executionJobVertex.getJobVertexId() + " can not be restored. The maximum parallelism (" + operatorState.getMaxParallelism() + ") of the restored state is lower than the configured parallelism (" + executionJobVertex.getParallelism() + "). Please reduce the parallelism of the task to be lower or equal to the maximum parallelism." 
); } // check that the number of key groups have not changed or if we need to override it to satisfy the restored state if (operatorState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) { if (!executionJobVertex.isMaxParallelismConfigured()) { // if the max parallelism was not explicitly specified by the user, we derive it from the state LOG.debug("Overriding maximum parallelism for JobVertex {} from {} to {}", executionJobVertex.getJobVertexId(), executionJobVertex.getMaxParallelism(), operatorState.getMaxParallelism()); executionJobVertex.setMaxParallelism(operatorState.getMaxParallelism()); } else { // if the max parallelism was explicitly specified, we complain on mismatch throw new IllegalStateException("The maximum parallelism (" + operatorState.getMaxParallelism() + ") with which the latest " + "checkpoint of the execution job vertex " + executionJobVertex + " has been taken and the current maximum parallelism (" + executionJobVertex.getMaxParallelism() + ") changed. This " + "is currently not supported."); } } } /** * Verifies that all operator states can be mapped to an execution job vertex. 
* * @param allowNonRestoredState if false an exception will be thrown if a state could not be mapped * @param operatorStates operator states to map * @param tasks task to map to */ private static void checkStateMappingCompleteness( boolean allowNonRestoredState, Map operatorStates, Set tasks) { Set allOperatorIDs = new HashSet<>(); for (ExecutionJobVertex executionJobVertex : tasks) { for (OperatorIDPair operatorIDPair : executionJobVertex.getOperatorIDs()) { allOperatorIDs.add(operatorIDPair.getGeneratedOperatorID()); operatorIDPair.getUserDefinedOperatorID().ifPresent(allOperatorIDs::add); } } for (Map.Entry operatorGroupStateEntry : operatorStates.entrySet()) { OperatorState operatorState = operatorGroupStateEntry.getValue(); //----------------------------------------find operator for state--------------------------------------------- if (!allOperatorIDs.contains(operatorGroupStateEntry.getKey())) { if (allowNonRestoredState) { LOG.info("Skipped checkpoint state for operator {}.", operatorState.getOperatorID()); } else { throw new IllegalStateException("There is no operator for the state " + operatorState.getOperatorID()); } } } } public static Map> applyRepartitioner( OperatorID operatorID, OperatorStateRepartitioner opStateRepartitioner, List> chainOpParallelStates, int oldParallelism, int newParallelism) { List> states = applyRepartitioner( opStateRepartitioner, chainOpParallelStates, oldParallelism, newParallelism); Map> result = new HashMap<>(states.size()); for (int subtaskIndex = 0; subtaskIndex < states.size(); subtaskIndex++) { checkNotNull(states.get(subtaskIndex) != null, "states.get(subtaskIndex) is null"); result.put(OperatorInstanceID.of(subtaskIndex, operatorID), states.get(subtaskIndex)); } return result; } /** * Repartitions the given operator state using the given {@link OperatorStateRepartitioner} with respect to the new * parallelism. 
* * @param opStateRepartitioner partitioner to use * @param chainOpParallelStates state to repartition * @param oldParallelism parallelism with which the state is currently partitioned * @param newParallelism parallelism with which the state should be partitioned * @return repartitioned state */ // TODO rewrite based on operator id public static List> applyRepartitioner( OperatorStateRepartitioner opStateRepartitioner, List> chainOpParallelStates, int oldParallelism, int newParallelism) { if (chainOpParallelStates == null) { return emptyList(); } return opStateRepartitioner.repartitionState( chainOpParallelStates, oldParallelism, newParallelism); } static > OperatorStateRepartitioner channelStateNonRescalingRepartitioner(String logStateName) { return (previousParallelSubtaskStates, oldParallelism, newParallelism) -> { Preconditions.checkArgument( oldParallelism == newParallelism || previousParallelSubtaskStates.stream() .flatMap(s -> s.stream().map(l -> l.getOffsets())) .allMatch(List::isEmpty), String.format("rescaling not supported for %s state (old: %d, new: %d)", logStateName, oldParallelism, newParallelism)); return previousParallelSubtaskStates; }; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy