All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.runtime.state.ttl.TtlStateFactory Maven / Gradle / Ivy

There is a newer version: 1.13.6
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.state.ttl;

import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.CompositeSerializer;
import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot;
import org.apache.flink.api.common.typeutils.CompositeTypeSerializerUtil;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot;
import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
import org.apache.flink.api.common.typeutils.base.ListSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.common.typeutils.base.MapSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
import org.apache.flink.runtime.state.internal.InternalKvState;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.SupplierWithException;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * This state factory wraps state objects, produced by backends, with TTL logic.
 */
public class TtlStateFactory {
	public static  IS createStateAndWrapWithTtlIfEnabled(
		TypeSerializer namespaceSerializer,
		StateDescriptor stateDesc,
		KeyedStateBackend stateBackend,
		TtlTimeProvider timeProvider) throws Exception {
		Preconditions.checkNotNull(namespaceSerializer);
		Preconditions.checkNotNull(stateDesc);
		Preconditions.checkNotNull(stateBackend);
		Preconditions.checkNotNull(timeProvider);
		return  stateDesc.getTtlConfig().isEnabled() ?
			new TtlStateFactory(
				namespaceSerializer, stateDesc, stateBackend, timeProvider)
				.createState() :
			stateBackend.createInternalState(namespaceSerializer, stateDesc);
	}

	private final Map, SupplierWithException> stateFactories;

	@Nonnull
	private final TypeSerializer namespaceSerializer;
	@Nonnull
	private final StateDescriptor stateDesc;
	@Nonnull
	private final KeyedStateBackend stateBackend;
	@Nonnull
	private final StateTtlConfig ttlConfig;
	@Nonnull
	private final TtlTimeProvider timeProvider;
	private final long ttl;
	@Nullable
	private final TtlIncrementalCleanup incrementalCleanup;

	private TtlStateFactory(
		@Nonnull TypeSerializer namespaceSerializer,
		@Nonnull StateDescriptor stateDesc,
		@Nonnull KeyedStateBackend stateBackend,
		@Nonnull TtlTimeProvider timeProvider) {
		this.namespaceSerializer = namespaceSerializer;
		this.stateDesc = stateDesc;
		this.stateBackend = stateBackend;
		this.ttlConfig = stateDesc.getTtlConfig();
		this.timeProvider = timeProvider;
		this.ttl = ttlConfig.getTtl().toMilliseconds();
		this.stateFactories = createStateFactories();
		this.incrementalCleanup = getTtlIncrementalCleanup();
	}

	@SuppressWarnings("deprecation")
	private Map, SupplierWithException> createStateFactories() {
		return Stream.of(
			Tuple2.of(ValueStateDescriptor.class, (SupplierWithException) this::createValueState),
			Tuple2.of(ListStateDescriptor.class, (SupplierWithException) this::createListState),
			Tuple2.of(MapStateDescriptor.class, (SupplierWithException) this::createMapState),
			Tuple2.of(ReducingStateDescriptor.class, (SupplierWithException) this::createReducingState),
			Tuple2.of(AggregatingStateDescriptor.class, (SupplierWithException) this::createAggregatingState),
			Tuple2.of(FoldingStateDescriptor.class, (SupplierWithException) this::createFoldingState)
		).collect(Collectors.toMap(t -> t.f0, t -> t.f1));
	}

	@SuppressWarnings("unchecked")
	private IS createState() throws Exception {
		SupplierWithException stateFactory = stateFactories.get(stateDesc.getClass());
		if (stateFactory == null) {
			String message = String.format("State %s is not supported by %s",
				stateDesc.getClass(), TtlStateFactory.class);
			throw new FlinkRuntimeException(message);
		}
		IS state = stateFactory.get();
		if (incrementalCleanup != null) {
			incrementalCleanup.setTtlState((AbstractTtlState) state);
		}
		return state;
	}

	@SuppressWarnings("unchecked")
	private IS createValueState() throws Exception {
		ValueStateDescriptor> ttlDescriptor = new ValueStateDescriptor<>(
			stateDesc.getName(), new TtlSerializer<>(LongSerializer.INSTANCE, stateDesc.getSerializer()));
		return (IS) new TtlValueState<>(createTtlStateContext(ttlDescriptor));
	}

	@SuppressWarnings("unchecked")
	private  IS createListState() throws Exception {
		ListStateDescriptor listStateDesc = (ListStateDescriptor) stateDesc;
		ListStateDescriptor> ttlDescriptor = new ListStateDescriptor<>(
			stateDesc.getName(), new TtlSerializer<>(LongSerializer.INSTANCE, listStateDesc.getElementSerializer()));
		return (IS) new TtlListState<>(createTtlStateContext(ttlDescriptor));
	}

	@SuppressWarnings("unchecked")
	private  IS createMapState() throws Exception {
		MapStateDescriptor mapStateDesc = (MapStateDescriptor) stateDesc;
		MapStateDescriptor> ttlDescriptor = new MapStateDescriptor<>(
			stateDesc.getName(),
			mapStateDesc.getKeySerializer(),
			new TtlSerializer<>(LongSerializer.INSTANCE, mapStateDesc.getValueSerializer()));
		return (IS) new TtlMapState<>(createTtlStateContext(ttlDescriptor));
	}

	@SuppressWarnings("unchecked")
	private IS createReducingState() throws Exception {
		ReducingStateDescriptor reducingStateDesc = (ReducingStateDescriptor) stateDesc;
		ReducingStateDescriptor> ttlDescriptor = new ReducingStateDescriptor<>(
			stateDesc.getName(),
			new TtlReduceFunction<>(reducingStateDesc.getReduceFunction(), ttlConfig, timeProvider),
			new TtlSerializer<>(LongSerializer.INSTANCE, stateDesc.getSerializer()));
		return (IS) new TtlReducingState<>(createTtlStateContext(ttlDescriptor));
	}

	@SuppressWarnings("unchecked")
	private  IS createAggregatingState() throws Exception {
		AggregatingStateDescriptor aggregatingStateDescriptor =
			(AggregatingStateDescriptor) stateDesc;
		TtlAggregateFunction ttlAggregateFunction = new TtlAggregateFunction<>(
			aggregatingStateDescriptor.getAggregateFunction(), ttlConfig, timeProvider);
		AggregatingStateDescriptor, OUT> ttlDescriptor = new AggregatingStateDescriptor<>(
			stateDesc.getName(), ttlAggregateFunction, new TtlSerializer<>(LongSerializer.INSTANCE, stateDesc.getSerializer()));
		return (IS) new TtlAggregatingState<>(createTtlStateContext(ttlDescriptor), ttlAggregateFunction);
	}

	@SuppressWarnings({"deprecation", "unchecked"})
	private  IS createFoldingState() throws Exception {
		FoldingStateDescriptor foldingStateDescriptor = (FoldingStateDescriptor) stateDesc;
		SV initAcc = stateDesc.getDefaultValue();
		TtlValue ttlInitAcc = initAcc == null ? null : new TtlValue<>(initAcc, Long.MAX_VALUE);
		FoldingStateDescriptor> ttlDescriptor = new FoldingStateDescriptor<>(
			stateDesc.getName(),
			ttlInitAcc,
			new TtlFoldFunction<>(foldingStateDescriptor.getFoldFunction(), ttlConfig, timeProvider, initAcc),
			new TtlSerializer<>(LongSerializer.INSTANCE, stateDesc.getSerializer()));
		return (IS) new TtlFoldingState<>(createTtlStateContext(ttlDescriptor));
	}

	@SuppressWarnings("unchecked")
	private  TtlStateContext
		createTtlStateContext(StateDescriptor ttlDescriptor) throws Exception {

		ttlDescriptor.enableTimeToLive(stateDesc.getTtlConfig()); // also used by RocksDB backend for TTL compaction filter config
		OIS originalState = (OIS) stateBackend.createInternalState(
			namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory());
		return new TtlStateContext<>(
			originalState, ttlConfig, timeProvider, (TypeSerializer) stateDesc.getSerializer(),
			registerTtlIncrementalCleanupCallback((InternalKvState) originalState));
	}

	private TtlIncrementalCleanup getTtlIncrementalCleanup() {
		StateTtlConfig.IncrementalCleanupStrategy config =
			ttlConfig.getCleanupStrategies().getIncrementalCleanupStrategy();
		return config != null ? new TtlIncrementalCleanup<>(config.getCleanupSize()) : null;
	}

	private Runnable registerTtlIncrementalCleanupCallback(InternalKvState originalState) {
		StateTtlConfig.IncrementalCleanupStrategy config =
			ttlConfig.getCleanupStrategies().getIncrementalCleanupStrategy();
		boolean cleanupConfigured = config != null && incrementalCleanup != null;
		boolean isCleanupActive = cleanupConfigured &&
			isStateIteratorSupported(originalState, incrementalCleanup.getCleanupSize());
		Runnable callback = isCleanupActive ? incrementalCleanup::stateAccessed : () -> { };
		if (isCleanupActive && config.runCleanupForEveryRecord()) {
			stateBackend.registerKeySelectionListener(stub -> callback.run());
		}
		return callback;
	}

	private boolean isStateIteratorSupported(InternalKvState originalState, int size) {
		boolean stateIteratorSupported = false;
		try {
			stateIteratorSupported = originalState.getStateIncrementalVisitor(size) != null;
		} catch (Throwable t) {
			// ignore
		}
		return stateIteratorSupported;
	}

	private StateSnapshotTransformFactory getSnapshotTransformFactory() {
		if (!ttlConfig.getCleanupStrategies().inFullSnapshot()) {
			return StateSnapshotTransformFactory.noTransform();
		} else {
			return new TtlStateSnapshotTransformer.Factory<>(timeProvider, ttl);
		}
	}

	/**
	 * Serializer for user state value with TTL. Visibility is public for usage with external tools.
	 */
	public static class TtlSerializer extends CompositeSerializer>
			implements TypeSerializerConfigSnapshot.SelfResolvingTypeSerializer> {
		private static final long serialVersionUID = 131020282727167064L;

		@SuppressWarnings("WeakerAccess")
		public TtlSerializer(TypeSerializer timestampSerializer, TypeSerializer userValueSerializer) {
			super(true, timestampSerializer, userValueSerializer);
		}

		@SuppressWarnings("WeakerAccess")
		public TtlSerializer(PrecomputedParameters precomputed, TypeSerializer ... fieldSerializers) {
			super(precomputed, fieldSerializers);
		}

		@SuppressWarnings("unchecked")
		@Override
		public TtlValue createInstance(@Nonnull Object ... values) {
			Preconditions.checkArgument(values.length == 2);
			return new TtlValue<>((T) values[1], (long) values[0]);
		}

		@Override
		protected void setField(@Nonnull TtlValue v, int index, Object fieldValue) {
			throw new UnsupportedOperationException("TtlValue is immutable");
		}

		@Override
		protected Object getField(@Nonnull TtlValue v, int index) {
			return index == 0 ? v.getLastAccessTimestamp() : v.getUserValue();
		}

		@SuppressWarnings("unchecked")
		@Override
		protected CompositeSerializer> createSerializerInstance(
			PrecomputedParameters precomputed,
			TypeSerializer ... originalSerializers) {
			Preconditions.checkNotNull(originalSerializers);
			Preconditions.checkArgument(originalSerializers.length == 2);
			return new TtlSerializer<>(precomputed, originalSerializers);
		}

		@SuppressWarnings("unchecked")
		TypeSerializer getTimestampSerializer() {
			return (TypeSerializer) (TypeSerializer) fieldSerializers[0];
		}

		@SuppressWarnings("unchecked")
		TypeSerializer getValueSerializer() {
			return (TypeSerializer) fieldSerializers[1];
		}

		@Override
		public TypeSerializerSnapshot> snapshotConfiguration() {
			return new TtlSerializerSnapshot<>(this);
		}

		@Override
		public TypeSerializerSchemaCompatibility> resolveSchemaCompatibilityViaRedirectingToNewSnapshotClass(
				TypeSerializerConfigSnapshot> deprecatedConfigSnapshot) {

			if (deprecatedConfigSnapshot instanceof ConfigSnapshot) {
				ConfigSnapshot castedLegacyConfigSnapshot = (ConfigSnapshot) deprecatedConfigSnapshot;
				TtlSerializerSnapshot newSnapshot = new TtlSerializerSnapshot<>();

				return CompositeTypeSerializerUtil.delegateCompatibilityCheckToNewSnapshot(
					this,
					newSnapshot,
					castedLegacyConfigSnapshot.getNestedSerializerSnapshots());
			}

			return TypeSerializerSchemaCompatibility.incompatible();
		}

		public static boolean isTtlStateSerializer(TypeSerializer typeSerializer) {
			boolean ttlSerializer = typeSerializer instanceof TtlStateFactory.TtlSerializer;
			boolean ttlListSerializer = typeSerializer instanceof ListSerializer &&
				((ListSerializer) typeSerializer).getElementSerializer() instanceof TtlStateFactory.TtlSerializer;
			boolean ttlMapSerializer = typeSerializer instanceof MapSerializer &&
				((MapSerializer) typeSerializer).getValueSerializer() instanceof TtlStateFactory.TtlSerializer;
			return ttlSerializer || ttlListSerializer || ttlMapSerializer;
		}
	}

	/**
	 * A {@link TypeSerializerSnapshot} for TtlSerializer.
	 */
	public static final class TtlSerializerSnapshot extends CompositeTypeSerializerSnapshot, TtlSerializer> {

		private static final int VERSION = 2;

		@SuppressWarnings({"WeakerAccess", "unused"})
		public TtlSerializerSnapshot() {
			super(TtlSerializer.class);
		}

		TtlSerializerSnapshot(TtlSerializer serializerInstance) {
			super(serializerInstance);
		}

		@Override
		protected int getCurrentOuterSnapshotVersion() {
			return VERSION;
		}

		@Override
		protected TypeSerializer[] getNestedSerializers(TtlSerializer outerSerializer) {
			return new TypeSerializer[]{ outerSerializer.getTimestampSerializer(), outerSerializer.getValueSerializer()};
		}

		@Override
		@SuppressWarnings("unchecked")
		protected TtlSerializer createOuterSerializerWithNestedSerializers(TypeSerializer[] nestedSerializers) {
			TypeSerializer timestampSerializer = (TypeSerializer) nestedSerializers[0];
			TypeSerializer valueSerializer = (TypeSerializer) nestedSerializers[1];

			return new TtlSerializer<>(timestampSerializer, valueSerializer);
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy