All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.beam.fn.harness.CombineRunners Maven / Gradle / Ivy

There is a newer version: 2.60.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.fn.harness;

import com.google.auto.service.AutoService;
import java.io.IOException;
import java.util.Map;
import java.util.function.Supplier;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.CombinePayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.function.ThrowingFunction;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
import org.apache.beam.sdk.util.construction.PTransformTranslation;
import org.apache.beam.sdk.util.construction.RehydratedComponents;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;

/** Executes different components of Combine PTransforms. */
@SuppressWarnings({
  "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
  "nullness",
  "keyfor"
}) // TODO(https://github.com/apache/beam/issues/20497)
public class CombineRunners {

  /** A registrar which provides a factory to handle combine component PTransforms. */
  @AutoService(PTransformRunnerFactory.Registrar.class)
  public static class Registrar implements PTransformRunnerFactory.Registrar {

    @Override
    public Map getPTransformRunnerFactories() {
      return ImmutableMap.of(
          PTransformTranslation.COMBINE_PER_KEY_PRECOMBINE_TRANSFORM_URN,
          new PrecombineFactory(),
          PTransformTranslation.COMBINE_PER_KEY_MERGE_ACCUMULATORS_TRANSFORM_URN,
          MapFnRunners.forValueMapFnFactory(CombineRunners::createMergeAccumulatorsMapFunction),
          PTransformTranslation.COMBINE_PER_KEY_EXTRACT_OUTPUTS_TRANSFORM_URN,
          MapFnRunners.forValueMapFnFactory(CombineRunners::createExtractOutputsMapFunction),
          PTransformTranslation.COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS_TRANSFORM_URN,
          MapFnRunners.forValueMapFnFactory(CombineRunners::createConvertToAccumulatorsMapFunction),
          PTransformTranslation.COMBINE_GROUPED_VALUES_TRANSFORM_URN,
          MapFnRunners.forValueMapFnFactory(CombineRunners::createCombineGroupedValuesMapFunction));
    }
  }

  private static class PrecombineRunner {
    private final PipelineOptions options;
    private final String ptransformId;
    private final Supplier> bundleCache;
    private final CombineFn combineFn;
    private final FnDataReceiver>> output;
    private final Coder keyCoder;
    private PrecombineGroupingTable groupingTable;
    private boolean isGloballyWindowed;

    PrecombineRunner(
        PipelineOptions options,
        String ptransformId,
        Supplier> bundleCache,
        CombineFn combineFn,
        FnDataReceiver>> output,
        Coder keyCoder) {
      this(options, ptransformId, bundleCache, combineFn, output, keyCoder, false);
    }

    PrecombineRunner(
        PipelineOptions options,
        String ptransformId,
        Supplier> bundleCache,
        CombineFn combineFn,
        FnDataReceiver>> output,
        Coder keyCoder,
        boolean isGloballyWindowed) {
      this.options = options;
      this.ptransformId = ptransformId;
      this.bundleCache = bundleCache;
      this.combineFn = combineFn;
      this.output = output;
      this.keyCoder = keyCoder;
      this.isGloballyWindowed = isGloballyWindowed;
    }

    void startBundle() {
      groupingTable =
          PrecombineGroupingTable.combiningAndSampling(
              options,
              Caches.subCache(bundleCache.get(), ptransformId),
              combineFn,
              keyCoder,
              0.001 /*sizeEstimatorSampleRate*/,
              isGloballyWindowed);
    }

    void processElement(WindowedValue> elem) throws Exception {
      groupingTable.put(elem, output::accept);
    }

    void finishBundle() throws Exception {
      groupingTable.flush(output::accept);
      groupingTable = null;
    }
  }

  /** A factory for {@link PrecombineRunner}s. */
  @VisibleForTesting
  public static class PrecombineFactory
      implements PTransformRunnerFactory> {

    @Override
    public PrecombineRunner createRunnerForPTransform(Context context)
        throws IOException {
      // Get objects needed to create the runner.
      RehydratedComponents rehydratedComponents =
          RehydratedComponents.forComponents(
              RunnerApi.Components.newBuilder()
                  .putAllCoders(context.getCoders())
                  .putAllWindowingStrategies(context.getWindowingStrategies())
                  .build());
      String mainInputTag =
          Iterables.getOnlyElement(context.getPTransform().getInputsMap().keySet());
      RunnerApi.PCollection mainInput =
          context.getPCollections().get(context.getPTransform().getInputsOrThrow(mainInputTag));

      // Input coder may sometimes be WindowedValueCoder depending on runner, instead of the
      // expected KvCoder.
      Coder uncastInputCoder = rehydratedComponents.getCoder(mainInput.getCoderId());
      KvCoder inputCoder;
      boolean isGloballyWindowed =
          rehydratedComponents
              .getWindowingStrategy(mainInput.getWindowingStrategyId())
              .getWindowFn()
              .equals(new GlobalWindows());
      if (uncastInputCoder instanceof WindowedValueCoder) {
        inputCoder =
            (KvCoder)
                ((WindowedValueCoder>) uncastInputCoder).getValueCoder();
      } else {
        inputCoder = (KvCoder) rehydratedComponents.getCoder(mainInput.getCoderId());
      }
      Coder keyCoder = inputCoder.getKeyCoder();

      CombinePayload combinePayload =
          CombinePayload.parseFrom(context.getPTransform().getSpec().getPayload());
      CombineFn combineFn =
          (CombineFn)
              SerializableUtils.deserializeFromByteArray(
                  combinePayload.getCombineFn().getPayload().toByteArray(), "CombineFn");

      FnDataReceiver>> consumer =
          (FnDataReceiver)
              context.getPCollectionConsumer(
                  Iterables.getOnlyElement(context.getPTransform().getOutputsMap().values()));

      PrecombineRunner runner =
          new PrecombineRunner<>(
              context.getPipelineOptions(),
              context.getPTransformId(),
              context.getBundleCacheSupplier(),
              combineFn,
              consumer,
              keyCoder,
              isGloballyWindowed);

      // Register the appropriate handlers.
      context.addStartBundleFunction(runner::startBundle);
      context.addPCollectionConsumer(
          Iterables.getOnlyElement(context.getPTransform().getInputsMap().values()),
          (FnDataReceiver)
              (FnDataReceiver>>) runner::processElement);
      context.addFinishBundleFunction(runner::finishBundle);

      return runner;
    }
  }

  static 
      ThrowingFunction>, KV>
          createMergeAccumulatorsMapFunction(String pTransformId, PTransform pTransform)
              throws IOException {
    CombinePayload combinePayload = CombinePayload.parseFrom(pTransform.getSpec().getPayload());
    CombineFn combineFn =
        (CombineFn)
            SerializableUtils.deserializeFromByteArray(
                combinePayload.getCombineFn().getPayload().toByteArray(), "CombineFn");

    return (KV> input) ->
        KV.of(input.getKey(), combineFn.mergeAccumulators(input.getValue()));
  }

  static 
      ThrowingFunction, KV> createExtractOutputsMapFunction(
          String pTransformId, PTransform pTransform) throws IOException {
    CombinePayload combinePayload = CombinePayload.parseFrom(pTransform.getSpec().getPayload());
    CombineFn combineFn =
        (CombineFn)
            SerializableUtils.deserializeFromByteArray(
                combinePayload.getCombineFn().getPayload().toByteArray(), "CombineFn");

    return (KV input) ->
        KV.of(input.getKey(), combineFn.extractOutput(input.getValue()));
  }

  static 
      ThrowingFunction, KV> createConvertToAccumulatorsMapFunction(
          String pTransformId, PTransform pTransform) throws IOException {
    CombinePayload combinePayload = CombinePayload.parseFrom(pTransform.getSpec().getPayload());
    CombineFn combineFn =
        (CombineFn)
            SerializableUtils.deserializeFromByteArray(
                combinePayload.getCombineFn().getPayload().toByteArray(), "CombineFn");

    return (KV input) ->
        KV.of(input.getKey(), combineFn.addInput(combineFn.createAccumulator(), input.getValue()));
  }

  static 
      ThrowingFunction>, KV>
          createCombineGroupedValuesMapFunction(String pTransformId, PTransform pTransform)
              throws IOException {
    CombinePayload combinePayload = CombinePayload.parseFrom(pTransform.getSpec().getPayload());
    CombineFn combineFn =
        (CombineFn)
            SerializableUtils.deserializeFromByteArray(
                combinePayload.getCombineFn().getPayload().toByteArray(), "CombineFn");

    return (KV> input) -> {
      return KV.of(input.getKey(), combineFn.apply(input.getValue()));
    };
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy