All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.runtime.state.ttl.TtlStateFactory Maven / Gradle / Ivy

There is a newer version: 1.13.6
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.state.ttl;

import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.CompositeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.state.KeyedStateFactory;
import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.SupplierWithException;

import javax.annotation.Nonnull;

import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * This state factory wraps state objects, produced by backends, with TTL logic.
 */
public class TtlStateFactory {
	public static  IS createStateAndWrapWithTtlIfEnabled(
		TypeSerializer namespaceSerializer,
		StateDescriptor stateDesc,
		KeyedStateFactory originalStateFactory,
		TtlTimeProvider timeProvider) throws Exception {
		Preconditions.checkNotNull(namespaceSerializer);
		Preconditions.checkNotNull(stateDesc);
		Preconditions.checkNotNull(originalStateFactory);
		Preconditions.checkNotNull(timeProvider);
		return  stateDesc.getTtlConfig().isEnabled() ?
			new TtlStateFactory(
				namespaceSerializer, stateDesc, originalStateFactory, timeProvider)
				.createState() :
			originalStateFactory.createInternalState(namespaceSerializer, stateDesc);
	}

	private final Map, SupplierWithException> stateFactories;

	private final TypeSerializer namespaceSerializer;
	private final StateDescriptor stateDesc;
	private final KeyedStateFactory originalStateFactory;
	private final StateTtlConfig ttlConfig;
	private final TtlTimeProvider timeProvider;
	private final long ttl;

	private TtlStateFactory(
		TypeSerializer namespaceSerializer,
		StateDescriptor stateDesc,
		KeyedStateFactory originalStateFactory,
		TtlTimeProvider timeProvider) {
		this.namespaceSerializer = namespaceSerializer;
		this.stateDesc = stateDesc;
		this.originalStateFactory = originalStateFactory;
		this.ttlConfig = stateDesc.getTtlConfig();
		this.timeProvider = timeProvider;
		this.ttl = ttlConfig.getTtl().toMilliseconds();
		this.stateFactories = createStateFactories();
	}

	@SuppressWarnings("deprecation")
	private Map, SupplierWithException> createStateFactories() {
		return Stream.of(
			Tuple2.of(ValueStateDescriptor.class, (SupplierWithException) this::createValueState),
			Tuple2.of(ListStateDescriptor.class, (SupplierWithException) this::createListState),
			Tuple2.of(MapStateDescriptor.class, (SupplierWithException) this::createMapState),
			Tuple2.of(ReducingStateDescriptor.class, (SupplierWithException) this::createReducingState),
			Tuple2.of(AggregatingStateDescriptor.class, (SupplierWithException) this::createAggregatingState),
			Tuple2.of(FoldingStateDescriptor.class, (SupplierWithException) this::createFoldingState)
		).collect(Collectors.toMap(t -> t.f0, t -> t.f1));
	}

	private IS createState() throws Exception {
		SupplierWithException stateFactory = stateFactories.get(stateDesc.getClass());
		if (stateFactory == null) {
			String message = String.format("State %s is not supported by %s",
				stateDesc.getClass(), TtlStateFactory.class);
			throw new FlinkRuntimeException(message);
		}
		return stateFactory.get();
	}

	@SuppressWarnings("unchecked")
	private IS createValueState() throws Exception {
		ValueStateDescriptor> ttlDescriptor = new ValueStateDescriptor<>(
			stateDesc.getName(), new TtlSerializer<>(stateDesc.getSerializer()));
		return (IS) new TtlValueState<>(
			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
			ttlConfig, timeProvider, stateDesc.getSerializer());
	}

	@SuppressWarnings("unchecked")
	private  IS createListState() throws Exception {
		ListStateDescriptor listStateDesc = (ListStateDescriptor) stateDesc;
		ListStateDescriptor> ttlDescriptor = new ListStateDescriptor<>(
			stateDesc.getName(), new TtlSerializer<>(listStateDesc.getElementSerializer()));
		return (IS) new TtlListState<>(
			originalStateFactory.createInternalState(
				namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
			ttlConfig, timeProvider, listStateDesc.getSerializer());
	}

	@SuppressWarnings("unchecked")
	private  IS createMapState() throws Exception {
		MapStateDescriptor mapStateDesc = (MapStateDescriptor) stateDesc;
		MapStateDescriptor> ttlDescriptor = new MapStateDescriptor<>(
			stateDesc.getName(),
			mapStateDesc.getKeySerializer(),
			new TtlSerializer<>(mapStateDesc.getValueSerializer()));
		return (IS) new TtlMapState<>(
			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
			ttlConfig, timeProvider, mapStateDesc.getSerializer());
	}

	@SuppressWarnings("unchecked")
	private IS createReducingState() throws Exception {
		ReducingStateDescriptor reducingStateDesc = (ReducingStateDescriptor) stateDesc;
		ReducingStateDescriptor> ttlDescriptor = new ReducingStateDescriptor<>(
			stateDesc.getName(),
			new TtlReduceFunction<>(reducingStateDesc.getReduceFunction(), ttlConfig, timeProvider),
			new TtlSerializer<>(stateDesc.getSerializer()));
		return (IS) new TtlReducingState<>(
			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
			ttlConfig, timeProvider, stateDesc.getSerializer());
	}

	@SuppressWarnings("unchecked")
	private  IS createAggregatingState() throws Exception {
		AggregatingStateDescriptor aggregatingStateDescriptor =
			(AggregatingStateDescriptor) stateDesc;
		TtlAggregateFunction ttlAggregateFunction = new TtlAggregateFunction<>(
			aggregatingStateDescriptor.getAggregateFunction(), ttlConfig, timeProvider);
		AggregatingStateDescriptor, OUT> ttlDescriptor = new AggregatingStateDescriptor<>(
			stateDesc.getName(), ttlAggregateFunction, new TtlSerializer<>(stateDesc.getSerializer()));
		return (IS) new TtlAggregatingState<>(
			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
			ttlConfig, timeProvider, stateDesc.getSerializer(), ttlAggregateFunction);
	}

	@SuppressWarnings({"deprecation", "unchecked"})
	private  IS createFoldingState() throws Exception {
		FoldingStateDescriptor foldingStateDescriptor = (FoldingStateDescriptor) stateDesc;
		SV initAcc = stateDesc.getDefaultValue();
		TtlValue ttlInitAcc = initAcc == null ? null : new TtlValue<>(initAcc, Long.MAX_VALUE);
		FoldingStateDescriptor> ttlDescriptor = new FoldingStateDescriptor<>(
			stateDesc.getName(),
			ttlInitAcc,
			new TtlFoldFunction<>(foldingStateDescriptor.getFoldFunction(), ttlConfig, timeProvider, initAcc),
			new TtlSerializer<>(stateDesc.getSerializer()));
		return (IS) new TtlFoldingState<>(
			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
			ttlConfig, timeProvider, stateDesc.getSerializer());
	}

	private StateSnapshotTransformFactory getSnapshotTransformFactory() {
		if (!ttlConfig.getCleanupStrategies().inFullSnapshot()) {
			return StateSnapshotTransformFactory.noTransform();
		} else {
			return new TtlStateSnapshotTransformer.Factory<>(timeProvider, ttl);
		}
	}

	/**
	 * Serializer for user state value with TTL. Visibility is public for usage with external tools.
	 */
	public static class TtlSerializer extends CompositeSerializer> {
		private static final long serialVersionUID = 131020282727167064L;

		public TtlSerializer(TypeSerializer userValueSerializer) {
			super(true, LongSerializer.INSTANCE, userValueSerializer);
		}

		public TtlSerializer(PrecomputedParameters precomputed, TypeSerializer ... fieldSerializers) {
			super(precomputed, fieldSerializers);
		}

		@SuppressWarnings("unchecked")
		@Override
		public TtlValue createInstance(@Nonnull Object ... values) {
			Preconditions.checkArgument(values.length == 2);
			return new TtlValue<>((T) values[1], (long) values[0]);
		}

		@Override
		protected void setField(@Nonnull TtlValue v, int index, Object fieldValue) {
			throw new UnsupportedOperationException("TtlValue is immutable");
		}

		@Override
		protected Object getField(@Nonnull TtlValue v, int index) {
			return index == 0 ? v.getLastAccessTimestamp() : v.getUserValue();
		}

		@SuppressWarnings("unchecked")
		@Override
		protected CompositeSerializer> createSerializerInstance(
			PrecomputedParameters precomputed,
			TypeSerializer ... originalSerializers) {
			Preconditions.checkNotNull(originalSerializers);
			Preconditions.checkArgument(originalSerializers.length == 2);
			return new TtlSerializer<>(precomputed, originalSerializers);
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy