// org.apache.beam.sdk.transforms.DoFnTester (Maven / Gradle / Ivy artifact listing)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.transforms;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.base.Function;
import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.ValueInSingleWindow;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.DoFn.OnTimerContext;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.Timer;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.state.InMemoryStateInternals;
import org.apache.beam.sdk.util.state.StateInternals;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TimestampedValue;
import org.apache.beam.sdk.values.TupleTag;
import org.joda.time.Instant;
/**
* A harness for unit-testing a {@link DoFn}.
*
* For example:
*
*
{@code
* DoFn fn = ...;
*
* DoFnTester fnTester = DoFnTester.of(fn);
*
* // Set arguments shared across all bundles:
* fnTester.setSideInputs(...); // If fn takes side inputs.
* fnTester.setSideOutputTags(...); // If fn writes to side outputs.
*
* // Process a bundle containing a single input element:
* Input testInput = ...;
* List testOutputs = fnTester.processBundle(testInput);
* Assert.assertThat(testOutputs, Matchers.hasItems(...));
*
* // Process a bigger bundle:
* Assert.assertThat(fnTester.processBundle(i1, i2, ...), Matchers.hasItems(...));
* }
*
* @param the type of the {@link DoFn}'s (main) input elements
* @param the type of the {@link DoFn}'s (main) output elements
*/
public class DoFnTester implements AutoCloseable {
/**
* Returns a {@code DoFnTester} supporting unit-testing of the given
* {@link DoFn}. By default, uses {@link CloningBehavior#CLONE_ONCE}.
*
* The only supported extra parameter of the {@link DoFn.ProcessElement} method is
* {@link BoundedWindow}.
*/
@SuppressWarnings("unchecked")
public static DoFnTester of(DoFn fn) {
checkNotNull(fn, "fn can't be null");
return new DoFnTester<>(fn);
}
/**
* Registers the tuple of values of the side input {@link PCollectionView}s to
* pass to the {@link DoFn} under test.
*
* Resets the state of this {@link DoFnTester}.
*
*
If this isn't called, {@code DoFnTester} assumes the
* {@link DoFn} takes no side inputs.
*/
public void setSideInputs(Map, Map> sideInputs) {
checkState(
state == State.UNINITIALIZED,
"Can't add side inputs: DoFnTester is already initialized, in state %s",
state);
this.sideInputs = sideInputs;
}
/**
* Registers the values of a side input {@link PCollectionView} to pass to the {@link DoFn}
* under test.
*
* The provided value is the final value of the side input in the specified window, not
* the value of the input PCollection in that window.
*
*
If this isn't called, {@code DoFnTester} will return the default value for any side input
* that is used.
*/
public void setSideInput(PCollectionView sideInput, BoundedWindow window, T value) {
checkState(
state == State.UNINITIALIZED,
"Can't add side inputs: DoFnTester is already initialized, in state %s",
state);
Map windowValues = (Map) sideInputs.get(sideInput);
if (windowValues == null) {
windowValues = new HashMap<>();
sideInputs.put(sideInput, windowValues);
}
windowValues.put(window, value);
}
@SuppressWarnings("unchecked")
public StateInternals getStateInternals() {
return (StateInternals) stateInternals;
}
/**
* When a {@link DoFnTester} should clone the {@link DoFn} under test and how it should manage
* the lifecycle of the {@link DoFn}.
*/
public enum CloningBehavior {
/**
* Clone the {@link DoFn} and call {@link DoFn.Setup} every time a bundle starts; call {@link
* DoFn.Teardown} every time a bundle finishes.
*/
CLONE_PER_BUNDLE,
/**
* Clone the {@link DoFn} and call {@link DoFn.Setup} on the first access; call {@link
* DoFn.Teardown} only explicitly.
*/
CLONE_ONCE,
/**
* Do not clone the {@link DoFn}; call {@link DoFn.Setup} on the first access; call {@link
* DoFn.Teardown} only explicitly.
*/
DO_NOT_CLONE
}
/**
* Instruct this {@link DoFnTester} whether or not to clone the {@link DoFn} under test.
*/
public void setCloningBehavior(CloningBehavior newValue) {
checkState(state == State.UNINITIALIZED, "Wrong state: %s", state);
this.cloningBehavior = newValue;
}
/**
* Indicates whether this {@link DoFnTester} will clone the {@link DoFn} under test.
*/
public CloningBehavior getCloningBehavior() {
return cloningBehavior;
}
/**
* A convenience operation that first calls {@link #startBundle},
* then calls {@link #processElement} on each of the input elements, then
* calls {@link #finishBundle}, then returns the result of
* {@link #takeOutputElements}.
*/
public List processBundle(Iterable inputElements) throws Exception {
startBundle();
for (InputT inputElement : inputElements) {
processElement(inputElement);
}
finishBundle();
return takeOutputElements();
}
/**
* A convenience method for testing {@link DoFn DoFns} with bundles of elements.
* Logic proceeds as follows:
*
*
* - Calls {@link #startBundle}.
* - Calls {@link #processElement} on each of the arguments.
* - Calls {@link #finishBundle}.
* - Returns the result of {@link #takeOutputElements}.
*
*/
@SafeVarargs
public final List processBundle(InputT... inputElements) throws Exception {
return processBundle(Arrays.asList(inputElements));
}
/**
* Calls the {@link DoFn.StartBundle} method on the {@link DoFn} under test.
*
* If needed, first creates a fresh instance of the {@link DoFn} under test and calls
* {@link DoFn.Setup}.
*/
public void startBundle() throws Exception {
checkState(
state == State.UNINITIALIZED || state == State.BUNDLE_FINISHED,
"Wrong state during startBundle: %s",
state);
if (state == State.UNINITIALIZED) {
initializeState();
}
TestContext context = new TestContext();
context.setupDelegateAggregators();
// State and timer internals are per-bundle.
stateInternals = InMemoryStateInternals.forKey(new Object());
try {
fnInvoker.invokeStartBundle(context);
} catch (UserCodeException e) {
unwrapUserCodeException(e);
}
state = State.BUNDLE_STARTED;
}
private static void unwrapUserCodeException(UserCodeException e) throws Exception {
if (e.getCause() instanceof Exception) {
throw (Exception) e.getCause();
} else if (e.getCause() instanceof Error) {
throw (Error) e.getCause();
} else {
throw e;
}
}
/**
* Calls the {@link DoFn.ProcessElement} method on the {@link DoFn} under test, in a
* context where {@link DoFn.ProcessContext#element} returns the
* given element and the element is in the global window.
*
*
Will call {@link #startBundle} automatically, if it hasn't
* already been called.
*
* @throws IllegalStateException if the {@code DoFn} under test has already
* been finished
*/
public void processElement(InputT element) throws Exception {
processTimestampedElement(TimestampedValue.atMinimumTimestamp(element));
}
/**
* Calls {@link DoFn.ProcessElement} on the {@code DoFn} under test, in a
* context where {@link DoFn.ProcessContext#element} returns the
* given element and timestamp and the element is in the global window.
*
*
Will call {@link #startBundle} automatically, if it hasn't
* already been called.
*/
public void processTimestampedElement(TimestampedValue element) throws Exception {
checkNotNull(element, "Timestamped element cannot be null");
processWindowedElement(
element.getValue(), element.getTimestamp(), GlobalWindow.INSTANCE);
}
/**
* Calls {@link DoFn.ProcessElement} on the {@code DoFn} under test, in a
* context where {@link DoFn.ProcessContext#element} returns the
* given element and timestamp and the element is in the given window.
*
* Will call {@link #startBundle} automatically, if it hasn't
* already been called.
*/
public void processWindowedElement(
InputT element, Instant timestamp, final BoundedWindow window) throws Exception {
if (state != State.BUNDLE_STARTED) {
startBundle();
}
try {
final TestProcessContext processContext =
new TestProcessContext(
ValueInSingleWindow.of(element, timestamp, window, PaneInfo.NO_FIRING));
fnInvoker.invokeProcessElement(
new DoFnInvoker.ArgumentProvider() {
@Override
public BoundedWindow window() {
return window;
}
@Override
public DoFn.Context context(DoFn doFn) {
throw new UnsupportedOperationException(
"Not expected to access DoFn.Context from @ProcessElement");
}
@Override
public DoFn.ProcessContext processContext(DoFn doFn) {
return processContext;
}
@Override
public OnTimerContext onTimerContext(DoFn doFn) {
throw new UnsupportedOperationException(
"DoFnTester doesn't support timers yet.");
}
@Override
public DoFn.InputProvider inputProvider() {
throw new UnsupportedOperationException(
"Not expected to access InputProvider from DoFnTester");
}
@Override
public DoFn.OutputReceiver outputReceiver() {
throw new UnsupportedOperationException(
"Not expected to access OutputReceiver from DoFnTester");
}
@Override
public RestrictionTracker restrictionTracker() {
throw new UnsupportedOperationException(
"Not expected to access RestrictionTracker from a regular DoFn in DoFnTester");
}
@Override
public org.apache.beam.sdk.util.state.State state(String stateId) {
throw new UnsupportedOperationException("DoFnTester doesn't support state yet");
}
@Override
public Timer timer(String timerId) {
throw new UnsupportedOperationException("DoFnTester doesn't support timers yet");
}
});
} catch (UserCodeException e) {
unwrapUserCodeException(e);
}
}
/**
* Calls the {@link DoFn.FinishBundle} method of the {@link DoFn} under test.
*
* If {@link #setCloningBehavior} was called with {@link CloningBehavior#CLONE_PER_BUNDLE},
* then also calls {@link DoFn.Teardown} on the {@link DoFn}, and it will be cloned and
* {@link DoFn.Setup} again when processing the next bundle.
*
* @throws IllegalStateException if {@link DoFn.FinishBundle} has already been called
* for this bundle.
*/
public void finishBundle() throws Exception {
checkState(
state == State.BUNDLE_STARTED,
"Must be inside bundle to call finishBundle, but was: %s",
state);
try {
fnInvoker.invokeFinishBundle(new TestContext());
} catch (UserCodeException e) {
unwrapUserCodeException(e);
}
if (cloningBehavior == CloningBehavior.CLONE_PER_BUNDLE) {
fnInvoker.invokeTeardown();
fn = null;
fnInvoker = null;
state = State.UNINITIALIZED;
} else {
state = State.BUNDLE_FINISHED;
}
}
/**
* Returns the elements output so far to the main output. Does not
* clear them, so subsequent calls will continue to include these
* elements.
*
* @see #takeOutputElements
* @see #clearOutputElements
*
*/
public List peekOutputElements() {
return Lists.transform(
peekOutputElementsWithTimestamp(),
new Function, OutputT>() {
@Override
@SuppressWarnings("unchecked")
public OutputT apply(TimestampedValue input) {
return input.getValue();
}
});
}
/**
* Returns the elements output so far to the main output with associated timestamps. Does not
* clear them, so subsequent calls will continue to include these.
* elements.
*
* @see #takeOutputElementsWithTimestamp
* @see #clearOutputElements
*/
@Experimental
public List> peekOutputElementsWithTimestamp() {
// TODO: Should we return an unmodifiable list?
return Lists.transform(getImmutableOutput(mainOutputTag),
new Function, TimestampedValue>() {
@Override
@SuppressWarnings("unchecked")
public TimestampedValue apply(ValueInSingleWindow input) {
return TimestampedValue.of(input.getValue(), input.getTimestamp());
}
});
}
/**
* Returns the elements output so far to the main output in the provided window with associated
* timestamps.
*/
public List> peekOutputElementsInWindow(BoundedWindow window) {
return peekOutputElementsInWindow(mainOutputTag, window);
}
/**
* Returns the elements output so far to the specified output in the provided window with
* associated timestamps.
*/
public List> peekOutputElementsInWindow(
TupleTag tag,
BoundedWindow window) {
ImmutableList.Builder> valuesBuilder = ImmutableList.builder();
for (ValueInSingleWindow value : getImmutableOutput(tag)) {
if (value.getWindow().equals(window)) {
valuesBuilder.add(TimestampedValue.of(value.getValue(), value.getTimestamp()));
}
}
return valuesBuilder.build();
}
/**
* Clears the record of the elements output so far to the main output.
*
* @see #peekOutputElements
*/
public void clearOutputElements() {
getMutableOutput(mainOutputTag).clear();
}
/**
* Returns the elements output so far to the main output.
* Clears the list so these elements don't appear in future calls.
*
* @see #peekOutputElements
*/
public List takeOutputElements() {
List resultElems = new ArrayList<>(peekOutputElements());
clearOutputElements();
return resultElems;
}
/**
* Returns the elements output so far to the main output with associated timestamps.
* Clears the list so these elements don't appear in future calls.
*
* @see #peekOutputElementsWithTimestamp
* @see #takeOutputElements
* @see #clearOutputElements
*/
@Experimental
public List> takeOutputElementsWithTimestamp() {
List> resultElems =
new ArrayList<>(peekOutputElementsWithTimestamp());
clearOutputElements();
return resultElems;
}
/**
* Returns the elements output so far to the side output with the
* given tag. Does not clear them, so subsequent calls will
* continue to include these elements.
*
* @see #takeSideOutputElements
* @see #clearSideOutputElements
*/
public List peekSideOutputElements(TupleTag tag) {
// TODO: Should we return an unmodifiable list?
return Lists.transform(getImmutableOutput(tag),
new Function, T>() {
@SuppressWarnings("unchecked")
@Override
public T apply(ValueInSingleWindow input) {
return input.getValue();
}});
}
/**
* Clears the record of the elements output so far to the side
* output with the given tag.
*
* @see #peekSideOutputElements
*/
public void clearSideOutputElements(TupleTag tag) {
getMutableOutput(tag).clear();
}
/**
* Returns the elements output so far to the side output with the given tag.
* Clears the list so these elements don't appear in future calls.
*
* @see #peekSideOutputElements
*/
public List takeSideOutputElements(TupleTag tag) {
List resultElems = new ArrayList<>(peekSideOutputElements(tag));
clearSideOutputElements(tag);
return resultElems;
}
/**
* Returns the value of the provided {@link Aggregator}.
*/
public AggregateT getAggregatorValue(Aggregator agg) {
return extractAggregatorValue(agg.getName(), agg.getCombineFn());
}
private AggregateT extractAggregatorValue(
String name, CombineFn combiner) {
@SuppressWarnings("unchecked")
AccumT accumulator = (AccumT) accumulators.get(name);
if (accumulator == null) {
accumulator = combiner.createAccumulator();
}
return combiner.extractOutput(accumulator);
}
private List> getImmutableOutput(TupleTag tag) {
@SuppressWarnings({"unchecked", "rawtypes"})
List> elems = (List) outputs.get(tag);
return ImmutableList.copyOf(
MoreObjects.firstNonNull(elems, Collections.>emptyList()));
}
@SuppressWarnings({"unchecked", "rawtypes"})
public List> getMutableOutput(TupleTag tag) {
List> outputList = (List) outputs.get(tag);
if (outputList == null) {
outputList = new ArrayList<>();
outputs.put(tag, (List) outputList);
}
return outputList;
}
public TupleTag getMainOutputTag() {
return mainOutputTag;
}
private class TestContext extends DoFn.Context {
TestContext() {
fn.super();
}
@Override
public PipelineOptions getPipelineOptions() {
return options;
}
@Override
public void output(OutputT output) {
throwUnsupportedOutputFromBundleMethods();
}
@Override
public void outputWithTimestamp(OutputT output, Instant timestamp) {
throwUnsupportedOutputFromBundleMethods();
}
@Override
public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) {
throwUnsupportedOutputFromBundleMethods();
}
@Override
public void sideOutput(TupleTag tag, T output) {
throwUnsupportedOutputFromBundleMethods();
}
private void throwUnsupportedOutputFromBundleMethods() {
throw new UnsupportedOperationException(
"DoFnTester doesn't support output from bundle methods");
}
@Override
protected Aggregator createAggregator(
final String name, final CombineFn combiner) {
return aggregator(name, combiner);
}
private Aggregator aggregator(
final String name,
final CombineFn combiner) {
Aggregator aggregator = new Aggregator() {
@Override
public void addValue(AinT value) {
AccT accum = (AccT) accumulators.get(name);
AccT newAccum = combiner.addInput(accum, value);
accumulators.put(name, newAccum);
}
@Override
public String getName() {
return name;
}
@Override
public CombineFn getCombineFn() {
return combiner;
}
};
// Aggregator instantiation is idempotent
if (accumulators.containsKey(name)) {
Class currentAccumClass = accumulators.get(name).getClass();
Class createAccumClass = combiner.createAccumulator().getClass();
checkState(
currentAccumClass.isAssignableFrom(createAccumClass),
"Aggregator %s already initialized with accumulator type %s "
+ "but was re-initialized with accumulator type %s",
name,
currentAccumClass,
createAccumClass);
} else {
accumulators.put(name, combiner.createAccumulator());
}
return aggregator;
}
}
private class TestProcessContext extends DoFn.ProcessContext {
private final TestContext context;
private final ValueInSingleWindow element;
private TestProcessContext(ValueInSingleWindow element) {
fn.super();
this.context = new TestContext();
this.element = element;
}
@Override
public InputT element() {
return element.getValue();
}
@Override
public T sideInput(PCollectionView view) {
Map viewValues = sideInputs.get(view);
if (viewValues != null) {
BoundedWindow sideInputWindow =
view.getWindowingStrategyInternal()
.getWindowFn()
.getSideInputWindow(element.getWindow());
@SuppressWarnings("unchecked")
T windowValue = (T) viewValues.get(sideInputWindow);
if (windowValue != null) {
return windowValue;
}
}
return view.getViewFn().apply(Collections.>emptyList());
}
@Override
public Instant timestamp() {
return element.getTimestamp();
}
@Override
public PaneInfo pane() {
return element.getPane();
}
@Override
public PipelineOptions getPipelineOptions() {
return context.getPipelineOptions();
}
@Override
public void output(OutputT output) {
sideOutput(mainOutputTag, output);
}
@Override
public void outputWithTimestamp(OutputT output, Instant timestamp) {
sideOutputWithTimestamp(mainOutputTag, output, timestamp);
}
@Override
public void sideOutput(TupleTag tag, T output) {
sideOutputWithTimestamp(tag, output, element.getTimestamp());
}
@Override
public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) {
getMutableOutput(tag)
.add(ValueInSingleWindow.of(output, timestamp, element.getWindow(), element.getPane()));
}
@Override
protected Aggregator createAggregator(
String name, CombineFn combiner) {
throw new IllegalStateException("Aggregators should not be created within ProcessContext. "
+ "Instead, create an aggregator at DoFn construction time with"
+ " createAggregator, and ensure they are set up by the time startBundle is"
+ " called with setupDelegateAggregators.");
}
}
@Override
public void close() throws Exception {
if (state == State.BUNDLE_STARTED) {
finishBundle();
}
if (state == State.BUNDLE_FINISHED) {
fnInvoker.invokeTeardown();
fn = null;
fnInvoker = null;
}
state = State.TORN_DOWN;
}
/////////////////////////////////////////////////////////////////////////////
/** The possible states of processing a {@link DoFn}. */
private enum State {
UNINITIALIZED,
BUNDLE_STARTED,
BUNDLE_FINISHED,
TORN_DOWN
}
private final PipelineOptions options = PipelineOptionsFactory.create();
/** The original {@link DoFn} under test. */
private final DoFn origFn;
/**
* Whether to clone the original {@link DoFn} or just use it as-is.
*
* Worker-side {@link DoFn DoFns} may not be serializable, and are not required to be.
*/
private CloningBehavior cloningBehavior = CloningBehavior.CLONE_ONCE;
/** The side input values to provide to the {@link DoFn} under test. */
private Map, Map> sideInputs =
new HashMap<>();
private Map accumulators;
/** The output tags used by the {@link DoFn} under test. */
private TupleTag mainOutputTag = new TupleTag<>();
/** The original DoFn under test, if started. */
private DoFn fn;
private DoFnInvoker fnInvoker;
/** The outputs from the {@link DoFn} under test. */
private Map, List>> outputs;
private InMemoryStateInternals stateInternals;
/** The state of processing of the {@link DoFn} under test. */
private State state = State.UNINITIALIZED;
private DoFnTester(DoFn origFn) {
this.origFn = origFn;
DoFnSignature signature = DoFnSignatures.signatureForDoFn(origFn);
for (DoFnSignature.Parameter param : signature.processElement().extraParameters()) {
param.match(
new DoFnSignature.Parameter.Cases.WithDefault() {
@Override
public Void dispatch(DoFnSignature.Parameter.ProcessContextParameter p) {
// ProcessContext parameter is obviously supported.
return null;
}
@Override
public Void dispatch(DoFnSignature.Parameter.WindowParameter p) {
// We also support the BoundedWindow parameter.
return null;
}
@Override
protected Void dispatchDefault(DoFnSignature.Parameter p) {
throw new UnsupportedOperationException(
"Parameter " + p + " not supported by DoFnTester");
}
});
}
}
@SuppressWarnings("unchecked")
private void initializeState() throws Exception {
checkState(state == State.UNINITIALIZED, "Already initialized");
checkState(fn == null, "Uninitialized but fn != null");
if (cloningBehavior.equals(CloningBehavior.DO_NOT_CLONE)) {
fn = origFn;
} else {
fn = (DoFn)
SerializableUtils.deserializeFromByteArray(
SerializableUtils.serializeToByteArray(origFn),
origFn.toString());
}
fnInvoker = DoFnInvokers.invokerFor(fn);
fnInvoker.invokeSetup();
outputs = new HashMap<>();
accumulators = new HashMap<>();
}
}