// Scraped page title (not part of the original source file): org.apache.beam.runners.fnexecution.state.StateRequestHandlers Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.fnexecution.state;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey.TypeCase;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.RequestCase;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.BagUserStateSpec;
import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor;
import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.SideInputSpec;
import org.apache.beam.runners.fnexecution.wire.ByteStringCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.fn.stream.DataStreams;
import org.apache.beam.sdk.fn.stream.DataStreams.ElementDelimitedOutputStream;
import org.apache.beam.sdk.transforms.Materializations;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.common.Reiterable;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
/**
 * A set of utility methods which construct {@link StateRequestHandler}s.
 *
 * <p>TODO: Add a variant which works on {@link ByteString}s to remove encoding/decoding overhead.
 */
@SuppressWarnings({
"rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
public class StateRequestHandlers {
/**
 * Marker interface that denotes some type of side input handler. The access pattern defines the
 * underlying type.
 *
 * <ul>
 *   <li>access pattern: Java type
 *   <li>{@code beam:side_input:iterable:v1}: {@link IterableSideInputHandler}
 *   <li>{@code beam:side_input:multimap:v1}: {@link MultimapSideInputHandler}
 * </ul>
 */
@ThreadSafe
public interface SideInputHandler {}
/**
 * A handler for iterable side inputs.
 *
 * <p>Note that this handler is expected to be thread safe as it will be invoked concurrently.
 *
 * @param <V> the type of the elements of the side input
 * @param <W> the type of window used to look up the side input
 */
@ThreadSafe
public interface IterableSideInputHandler<V, W extends BoundedWindow> extends SideInputHandler {
  /**
   * Returns an {@link Iterable} of values representing the side input for the given window.
   *
   * <p>TODO: Add support for side input chunking and caching if a {@link Reiterable} is returned.
   */
  Iterable<V> get(W window);

  /** Returns the {@link Coder} to use for the elements of the resulting values iterable. */
  Coder<V> elementCoder();
}
/**
 * A handler for multimap side inputs.
 *
 * <p>Note that this handler is expected to be thread safe as it will be invoked concurrently.
 *
 * @param <K> the type of the keys of the side input
 * @param <V> the type of the values of the side input
 * @param <W> the type of window used to look up the side input
 */
@ThreadSafe
public interface MultimapSideInputHandler<K, V, W extends BoundedWindow>
    extends SideInputHandler {
  /**
   * Returns an {@link Iterable} of keys representing the side input for the given window.
   *
   * <p>TODO: Add support for side input chunking and caching if a {@link Reiterable} is returned.
   */
  Iterable<K> get(W window);

  /**
   * Returns an {@link Iterable} of values representing the side input for the given key and
   * window.
   *
   * <p>TODO: Add support for side input chunking and caching if a {@link Reiterable} is returned.
   */
  Iterable<V> get(K key, W window);

  /** Returns the {@link Coder} to use for the elements of the resulting keys iterable. */
  Coder<K> keyCoder();

  /** Returns the {@link Coder} to use for the elements of the resulting values iterable. */
  Coder<V> valueCoder();
}
/**
 * A factory which constructs {@link IterableSideInputHandler}s and {@link
 * MultimapSideInputHandler}s.
 *
 * <p>Note that this factory should be thread safe because it will be invoked concurrently.
 */
@ThreadSafe
public interface SideInputHandlerFactory {
  /**
   * Returns an {@link IterableSideInputHandler} for the given {@code pTransformId}, {@code
   * sideInputId}. The supplied {@code elementCoder} and {@code windowCoder} should be used to
   * encode/decode their respective values.
   */
  <V, W extends BoundedWindow> IterableSideInputHandler<V, W> forIterableSideInput(
      String pTransformId, String sideInputId, Coder<V> elementCoder, Coder<W> windowCoder);

  /**
   * Returns a {@link MultimapSideInputHandler} for the given {@code pTransformId}, {@code
   * sideInputId}. The supplied {@code elementCoder} and {@code windowCoder} should be used to
   * encode/decode their respective values.
   */
  <K, V, W extends BoundedWindow> MultimapSideInputHandler<K, V, W> forMultimapSideInput(
      String pTransformId, String sideInputId, KvCoder<K, V> elementCoder, Coder<W> windowCoder);

  /**
   * Returns a factory whose methods always throw an {@link UnsupportedOperationException}, i.e.
   * the failure happens on the first attempt to create a handler rather than when this method is
   * called.
   */
  static SideInputHandlerFactory unsupported() {
    return new SideInputHandlerFactory() {
      @Override
      public <V, W extends BoundedWindow> IterableSideInputHandler<V, W> forIterableSideInput(
          String pTransformId, String sideInputId, Coder<V> elementCoder, Coder<W> windowCoder) {
        throw new UnsupportedOperationException(
            String.format(
                "The %s does not support handling sides inputs for PTransform %s with side "
                    + "input id %s.",
                SideInputHandlerFactory.class.getSimpleName(), pTransformId, sideInputId));
      }

      @Override
      public <K, V, W extends BoundedWindow>
          MultimapSideInputHandler<K, V, W> forMultimapSideInput(
              String pTransformId,
              String sideInputId,
              KvCoder<K, V> elementCoder,
              Coder<W> windowCoder) {
        throw new UnsupportedOperationException(
            String.format(
                "The %s does not support handling sides inputs for PTransform %s with side "
                    + "input id %s.",
                SideInputHandlerFactory.class.getSimpleName(), pTransformId, sideInputId));
      }
    };
  }
}
/**
 * A handler for bag user state.
 *
 * <p>Note that this handler is expected to be thread safe as it will be invoked concurrently.
 *
 * @param <K> the type of the key the state is scoped to
 * @param <V> the type of the values stored in the bag
 * @param <W> the type of window the state is scoped to
 */
@ThreadSafe
public interface BagUserStateHandler<K, V, W extends BoundedWindow> {
  /**
   * Returns an {@link Iterable} of values representing the bag user state for the given key and
   * window.
   *
   * <p>TODO: Add support for bag user state chunking and caching if a {@link Reiterable} is
   * returned.
   */
  Iterable<V> get(K key, W window);

  /** Appends the values to the bag user state for the given key and window. */
  void append(K key, W window, Iterator<V> values);

  /** Clears the bag user state for the given key and window. */
  void clear(K key, W window);
}
/**
 * A factory which constructs {@link BagUserStateHandler}s.
 *
 * <p>Note that this factory should be thread safe.
 *
 * @param <K> the type of the key the state is scoped to
 * @param <V> the type of the values stored in the bag
 * @param <W> the type of window the state is scoped to
 */
@ThreadSafe
public interface BagUserStateHandlerFactory<K, V, W extends BoundedWindow> {
  /**
   * Returns a {@link BagUserStateHandler} for the given {@code pTransformId} and {@code
   * userStateId}, together with the coders for the state's keys, values and windows.
   */
  BagUserStateHandler<K, V, W> forUserState(
      String pTransformId,
      String userStateId,
      Coder<K> keyCoder,
      Coder<V> valueCoder,
      Coder<W> windowCoder);

  /**
   * Returns a factory which throws an {@link UnsupportedOperationException} on the first attempt
   * to create a handler.
   */
  static <K, V, W extends BoundedWindow> BagUserStateHandlerFactory<K, V, W> unsupported() {
    return (pTransformId, userStateId, keyCoder, valueCoder, windowCoder) -> {
      // The message names this factory and talks about user state; the previous text was copied
      // from the side input factory and referred to "sides inputs" and the handler class.
      throw new UnsupportedOperationException(
          String.format(
              "The %s does not support handling user state for PTransform %s with user state "
                  + "id %s.",
              BagUserStateHandlerFactory.class.getSimpleName(), pTransformId, userStateId));
    };
  }
}
/**
 * Returns a {@link StateRequestHandler} which delegates to the supplied handler depending on the
 * {@link StateRequest}s {@link StateKey.TypeCase type}.
 *
 * <p>If no handler is registered for a request's type, the {@link CompletionStage} returned for
 * that request is completed exceptionally with an {@link IllegalStateException}.
 *
 * @param handlers the handler to delegate to for each supported {@link StateKey.TypeCase}
 */
public static StateRequestHandler delegateBasedUponType(
    EnumMap<StateKey.TypeCase, StateRequestHandler> handlers) {
  return new StateKeyTypeDelegatingStateRequestHandler(handlers);
}
/**
* A {@link StateRequestHandler} which delegates to the supplied handler depending on the {@link
* StateRequest}s {@link StateKey.TypeCase type}.
*
* An exception is thrown if a corresponding handler is not found.
*/
static class StateKeyTypeDelegatingStateRequestHandler implements StateRequestHandler {
private final EnumMap handlers;
StateKeyTypeDelegatingStateRequestHandler(
EnumMap handlers) {
this.handlers = handlers;
}
@Override
public CompletionStage handle(StateRequest request) throws Exception {
return handlers
.getOrDefault(request.getStateKey().getTypeCase(), this::handlerNotFound)
.handle(request);
}
@Override
public Iterable getCacheTokens() {
// Use loops here due to the horrible performance of Java Streams:
// https://medium.com/@milan.mimica/slow-like-a-stream-fast-like-a-loop-524f70391182
Set cacheTokens = new HashSet<>();
for (StateRequestHandler handler : handlers.values()) {
for (BeamFnApi.ProcessBundleRequest.CacheToken cacheToken : handler.getCacheTokens()) {
cacheTokens.add(cacheToken);
}
}
return cacheTokens;
}
private CompletionStage handlerNotFound(StateRequest request) {
CompletableFuture rval = new CompletableFuture<>();
rval.completeExceptionally(new IllegalStateException());
return rval;
}
}
/**
* Returns an adapter which converts a {@link SideInputHandlerFactory} to a {@link
* StateRequestHandler}.
*
* The {@link SideInputHandlerFactory} is required to handle all side inputs contained within
* the {@link ExecutableProcessBundleDescriptor}. See {@link
* ExecutableProcessBundleDescriptor#getSideInputSpecs} for the set of side inputs that are
* contained.
*
*
Instances of {@link MultimapSideInputHandler}s returned by the {@link
* SideInputHandlerFactory} are cached.
*/
public static StateRequestHandler forSideInputHandlerFactory(
Map> sideInputSpecs,
SideInputHandlerFactory sideInputHandlerFactory) {
return new StateRequestHandlerToSideInputHandlerFactoryAdapter(
sideInputSpecs, sideInputHandlerFactory);
}
/** An adapter which converts {@link SideInputHandlerFactory} to {@link StateRequestHandler}. */
static class StateRequestHandlerToSideInputHandlerFactoryAdapter implements StateRequestHandler {
private final Map> sideInputSpecs;
private final SideInputHandlerFactory sideInputHandlerFactory;
private final ConcurrentHashMap handlerCache;
StateRequestHandlerToSideInputHandlerFactoryAdapter(
Map> sideInputSpecs,
SideInputHandlerFactory sideInputHandlerFactory) {
this.sideInputSpecs = sideInputSpecs;
this.sideInputHandlerFactory = sideInputHandlerFactory;
this.handlerCache = new ConcurrentHashMap<>();
}
@Override
public CompletionStage handle(StateRequest request) throws Exception {
checkState(
RequestCase.GET.equals(request.getRequestCase()),
String.format("Unsupported request type %s for side input.", request.getRequestCase()));
try {
switch (request.getStateKey().getTypeCase()) {
case MULTIMAP_SIDE_INPUT:
{
StateKey.MultimapSideInput stateKey = request.getStateKey().getMultimapSideInput();
SideInputSpec, ?> referenceSpec =
sideInputSpecs.get(stateKey.getTransformId()).get(stateKey.getSideInputId());
MultimapSideInputHandler handler =
(MultimapSideInputHandler)
handlerCache.computeIfAbsent(referenceSpec, this::createHandler);
return handleGetMultimapValuesRequest(request, handler);
}
case MULTIMAP_KEYS_SIDE_INPUT:
{
StateKey.MultimapKeysSideInput stateKey =
request.getStateKey().getMultimapKeysSideInput();
SideInputSpec, ?> referenceSpec =
sideInputSpecs.get(stateKey.getTransformId()).get(stateKey.getSideInputId());
MultimapSideInputHandler handler =
(MultimapSideInputHandler)
handlerCache.computeIfAbsent(referenceSpec, this::createHandler);
return handleGetMultimapKeysRequest(request, handler);
}
case ITERABLE_SIDE_INPUT:
{
StateKey.IterableSideInput stateKey = request.getStateKey().getIterableSideInput();
SideInputSpec, ?> referenceSpec =
sideInputSpecs.get(stateKey.getTransformId()).get(stateKey.getSideInputId());
IterableSideInputHandler handler =
(IterableSideInputHandler)
handlerCache.computeIfAbsent(referenceSpec, this::createHandler);
return handleGetIterableValuesRequest(request, handler);
}
default:
throw new IllegalStateException(
String.format(
"Unsupported %s type %s, expected %s or %s",
StateRequest.class.getSimpleName(),
request.getStateKey().getTypeCase(),
TypeCase.MULTIMAP_SIDE_INPUT,
TypeCase.MULTIMAP_KEYS_SIDE_INPUT));
}
} catch (Exception e) {
CompletableFuture f = new CompletableFuture();
f.completeExceptionally(e);
return f;
}
}
private
CompletionStage handleGetMultimapKeysRequest(
StateRequest request, MultimapSideInputHandler handler) throws Exception {
// TODO: Add support for continuation tokens when handling state if the handler
// returned a {@link Reiterable}.
checkState(
request.getGet().getContinuationToken().isEmpty(),
"Continuation tokens are unsupported.");
StateKey.MultimapKeysSideInput stateKey = request.getStateKey().getMultimapKeysSideInput();
SideInputSpec, W> sideInputReferenceSpec =
sideInputSpecs.get(stateKey.getTransformId()).get(stateKey.getSideInputId());
W window = sideInputReferenceSpec.windowCoder().decode(stateKey.getWindow().newInput());
Iterable keys = handler.get(window);
List encodedValues = new ArrayList<>();
ElementDelimitedOutputStream outputStream = DataStreams.outbound(encodedValues::add);
for (K key : keys) {
handler.keyCoder().encode(key, outputStream);
outputStream.delimitElement();
}
outputStream.close();
StateResponse.Builder response = StateResponse.newBuilder();
response.setId(request.getId());
response.setGet(
StateGetResponse.newBuilder().setData(ByteString.copyFrom(encodedValues)).build());
return CompletableFuture.completedFuture(response);
}
private
CompletionStage handleGetMultimapValuesRequest(
StateRequest request, MultimapSideInputHandler handler) throws Exception {
// TODO: Add support for continuation tokens when handling state if the handler
// returned a {@link Reiterable}.
checkState(
request.getGet().getContinuationToken().isEmpty(),
"Continuation tokens are unsupported.");
StateKey.MultimapSideInput stateKey = request.getStateKey().getMultimapSideInput();
SideInputSpec, W> sideInputReferenceSpec =
sideInputSpecs.get(stateKey.getTransformId()).get(stateKey.getSideInputId());
W window = sideInputReferenceSpec.windowCoder().decode(stateKey.getWindow().newInput());
Iterable values =
handler.get(handler.keyCoder().decode(stateKey.getKey().newInput()), window);
List encodedValues = new ArrayList<>();
ElementDelimitedOutputStream outputStream = DataStreams.outbound(encodedValues::add);
for (V value : values) {
handler.valueCoder().encode(value, outputStream);
outputStream.delimitElement();
}
outputStream.close();
StateResponse.Builder response = StateResponse.newBuilder();
response.setId(request.getId());
response.setGet(
StateGetResponse.newBuilder().setData(ByteString.copyFrom(encodedValues)).build());
return CompletableFuture.completedFuture(response);
}
private
CompletionStage handleGetIterableValuesRequest(
StateRequest request, IterableSideInputHandler handler) throws Exception {
// TODO: Add support for continuation tokens when handling state if the handler
// returned a {@link Reiterable}.
checkState(
request.getGet().getContinuationToken().isEmpty(),
"Continuation tokens are unsupported.");
StateKey.IterableSideInput stateKey = request.getStateKey().getIterableSideInput();
SideInputSpec sideInputReferenceSpec =
sideInputSpecs.get(stateKey.getTransformId()).get(stateKey.getSideInputId());
W window = sideInputReferenceSpec.windowCoder().decode(stateKey.getWindow().newInput());
Iterable values = handler.get(window);
List encodedValues = new ArrayList<>();
ElementDelimitedOutputStream outputStream = DataStreams.outbound(encodedValues::add);
for (V value : values) {
handler.elementCoder().encode(value, outputStream);
outputStream.delimitElement();
}
outputStream.close();
StateResponse.Builder response = StateResponse.newBuilder();
response.setId(request.getId());
response.setGet(
StateGetResponse.newBuilder().setData(ByteString.copyFrom(encodedValues)).build());
return CompletableFuture.completedFuture(response);
}
private SideInputHandler createHandler(SideInputSpec, ?> cacheKey) {
switch (cacheKey.accessPattern().getUrn()) {
case Materializations.ITERABLE_MATERIALIZATION_URN:
return sideInputHandlerFactory.forIterableSideInput(
cacheKey.transformId(),
cacheKey.sideInputId(),
cacheKey.elementCoder(),
cacheKey.windowCoder());
case Materializations.MULTIMAP_MATERIALIZATION_URN:
return sideInputHandlerFactory.forMultimapSideInput(
cacheKey.transformId(),
cacheKey.sideInputId(),
(KvCoder) cacheKey.elementCoder(),
cacheKey.windowCoder());
default:
throw new IllegalStateException(
String.format("Unsupported access pattern for side input %s", cacheKey));
}
}
}
/**
 * Returns an adapter which converts a {@link BagUserStateHandlerFactory} to a {@link
 * StateRequestHandler}.
 *
 * <p>The {@link BagUserStateHandlerFactory} is required to handle all bag user states contained
 * within the {@link ExecutableProcessBundleDescriptor}.
 *
 * <p>Instances of {@link BagUserStateHandler}s returned by the {@link
 * BagUserStateHandlerFactory} are expected to be cached by the returned adapter, which holds a
 * handler cache. NOTE(review): the previous text described side inputs and {@link
 * MultimapSideInputHandler}s and appears to have been copied from {@link
 * #forSideInputHandlerFactory}; confirm the caching behavior against the adapter.
 *
 * <p>In case of any failures, this handler must be discarded. Otherwise, the contained state
 * cache token would be reused which would corrupt the state cache.
 */
public static StateRequestHandler forBagUserStateHandlerFactory(
ExecutableProcessBundleDescriptor processBundleDescriptor,
BagUserStateHandlerFactory bagUserStateHandlerFactory) {
return new ByteStringStateRequestHandlerToBagUserStateHandlerFactoryAdapter(
processBundleDescriptor, bagUserStateHandlerFactory);
}
/**
* An adapter which converts {@link BagUserStateHandlerFactory} to {@link StateRequestHandler}.
*/
static class ByteStringStateRequestHandlerToBagUserStateHandlerFactoryAdapter
implements StateRequestHandler {
private final ExecutableProcessBundleDescriptor processBundleDescriptor;
private final BagUserStateHandlerFactory handlerFactory;
private final ConcurrentHashMap handlerCache;
private final BeamFnApi.ProcessBundleRequest.CacheToken cacheToken;
ByteStringStateRequestHandlerToBagUserStateHandlerFactoryAdapter(
ExecutableProcessBundleDescriptor processBundleDescriptor,
BagUserStateHandlerFactory handlerFactory) {
this.processBundleDescriptor = processBundleDescriptor;
this.handlerFactory = handlerFactory;
this.handlerCache = new ConcurrentHashMap<>();
this.cacheToken = createCacheToken();
}
@Override
public CompletionStage handle(StateRequest request) throws Exception {
try {
checkState(
TypeCase.BAG_USER_STATE.equals(request.getStateKey().getTypeCase()),
"Unsupported %s type %s, expected %s",
StateRequest.class.getSimpleName(),
request.getStateKey().getTypeCase(),
TypeCase.BAG_USER_STATE);
StateKey.BagUserState stateKey = request.getStateKey().getBagUserState();
BagUserStateSpec