/*
 * Copyright 2018 Spotify AB.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.beam.sdk.io.elasticsearch;

import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;

import java.io.IOException;
import java.io.Serializable;
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.Objects;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ProcessFunction;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Repeatedly;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.BackOffUtils;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.sdk.util.Sleeper;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.delete.DeleteRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.transport.client.PreBuiltTransportClient;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

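/**
 * Transforms for writing data to an Elasticsearch cluster via the Elasticsearch transport client.
 *
 * <p>A minimal usage sketch; the {@code Document} type, its accessors, and the cluster, index,
 * and field names below are illustrative assumptions, not part of this API:
 *
 * <pre>{@code
 * PCollection<Document> docs = ...;
 * docs.apply(
 *     ElasticsearchIO.Write.<Document>withClusterName("my-cluster")
 *         .withServers(new InetSocketAddress[] {new InetSocketAddress("localhost", 9300)})
 *         .withFunction(
 *             doc ->
 *                 Collections.<DocWriteRequest<?>>singletonList(
 *                     new IndexRequest("my-index", "_doc", doc.getId())
 *                         .source(Collections.singletonMap("payload", doc.getPayload()))))
 *         .withFlushInterval(Duration.standardSeconds(1))
 *         .withNumOfShard(2));
 * }</pre>
 */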
public class ElasticsearchIO {

  public static class Write {

    private static final Logger LOG = LoggerFactory.getLogger(Write.class);
    private static final String RETRY_ATTEMPT_LOG =
        "Error writing to Elasticsearch. Retry attempt[%d]";
    private static final String RETRY_FAILED_LOG =
        "Error writing to ES after %d attempt(s). No more attempts allowed";

    /**
     * Returns a transform for writing to the Elasticsearch cluster with the given cluster name.
     *
     * @param clusterName name of the Elasticsearch cluster
     */
    public static <T> Bound<T> withClusterName(String clusterName) {
      return new Bound<T>().withClusterName(clusterName);
    }

    /**
     * Returns a transform for writing to the Elasticsearch cluster at the given server addresses.
     *
     * @param servers endpoints for the Elasticsearch cluster
     */
    public static <T> Bound<T> withServers(InetSocketAddress[] servers) {
      return new Bound<T>().withServers(servers);
    }

    /**
     * Returns a transform for writing to the Elasticsearch cluster, buffering writes for the
     * delay specified by flushInterval.
     *
     * @param flushInterval delay applied to buffered elements. Defaults to 1 second.
     */
    public static <T> Bound<T> withFlushInterval(Duration flushInterval) {
      return new Bound<T>().withFlushInterval(flushInterval);
    }

    /**
     * Returns a transform for writing to the Elasticsearch cluster.
     *
     * @param function maps an input element to the DocWriteRequests sent to the Elasticsearch
     *     client
     */
    public static <T> Bound<T> withFunction(
        SerializableFunction<T, Iterable<DocWriteRequest<?>>> function) {
      return new Bound<T>().withFunction(function);
    }

    /**
     * Returns a transform for writing to the Elasticsearch cluster. Note: it is recommended to
     * set this number to the number of workers in your pipeline.
     *
     * @param numOfShard number of shards elements are grouped into before bulk writing to
     *     Elasticsearch.
     */
    public static <T> Bound<T> withNumOfShard(long numOfShard) {
      return new Bound<T>().withNumOfShard(numOfShard);
    }

    /**
     * Returns a transform for writing to the Elasticsearch cluster.
     *
     * @param error handler applied, if specified, when a bulk write to Elasticsearch fails. The
     *     default behavior throws an IOException.
     */
    public static <T> Bound<T> withError(ThrowingConsumer<BulkExecutionException> error) {
      return new Bound<T>().withError(error);
    }

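    /**
     * Returns a transform for writing to the Elasticsearch cluster.
     *
     * @param maxBulkRequestSize maximum number of actions allowed in a single bulk request.
     */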
    public static <T> Bound<T> withMaxBulkRequestSize(int maxBulkRequestSize) {
      return new Bound<T>().withMaxBulkRequestSize(maxBulkRequestSize);
    }

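    /**
     * Returns a transform for writing to the Elasticsearch cluster.
     *
     * @param maxBulkRequestBytes maximum payload size, in bytes, of a single bulk request.
     */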
    public static <T> Bound<T> withMaxBulkRequestBytes(long maxBulkRequestBytes) {
      return new Bound<T>().withMaxBulkRequestBytes(maxBulkRequestBytes);
    }

    /**
     * Returns a transform for writing to the Elasticsearch cluster.
     *
     * @param maxRetries Maximum number of retries to attempt for saving any single chunk of bulk
     *     requests to the Elasticsearch cluster.
     */
    public static <T> Bound<T> withMaxRetries(int maxRetries) {
      return new Bound<T>().withMaxRetries(maxRetries);
    }

    /**
     * Returns a transform for writing to the Elasticsearch cluster.
     *
     * @param retryPause Duration to wait between successive retry attempts.
     */
    public static <T> Bound<T> withRetryPause(Duration retryPause) {
      return new Bound<T>().withRetryPause(retryPause);
    }

    public static class Bound<T> extends PTransform<PCollection<T>, PDone> {

      private static final int CHUNK_SIZE = 3000;

      // 5 megabytes - recommended as a sensible default payload size (see
      // https://www.elastic.co/guide/en/elasticsearch/reference/6.8/getting-started-index.html#getting-started-batch-processing)
      private static final long CHUNK_BYTES = 5L * 1024L * 1024L;

      private static final int DEFAULT_RETRIES = 3;
      private static final Duration DEFAULT_RETRY_PAUSE = Duration.millis(35000);

      private final String clusterName;
      private final InetSocketAddress[] servers;
      private final Duration flushInterval;
      private final SerializableFunction<T, Iterable<DocWriteRequest<?>>> toDocWriteRequests;
      private final long numOfShard;
      private final int maxBulkRequestSize;
      private final long maxBulkRequestBytes;
      private final int maxRetries;
      private final Duration retryPause;
      private final ThrowingConsumer<BulkExecutionException> error;

      private Bound(
          final String clusterName,
          final InetSocketAddress[] servers,
          final Duration flushInterval,
          final SerializableFunction<T, Iterable<DocWriteRequest<?>>> toDocWriteRequests,
          final long numOfShard,
          final int maxBulkRequestSize,
          final long maxBulkRequestBytes,
          int maxRetries,
          Duration retryPause,
          final ThrowingConsumer<BulkExecutionException> error) {
        this.clusterName = clusterName;
        this.servers = servers;
        this.flushInterval = flushInterval;
        this.toDocWriteRequests = toDocWriteRequests;
        this.numOfShard = numOfShard;
        this.maxBulkRequestSize = maxBulkRequestSize;
        this.maxBulkRequestBytes = maxBulkRequestBytes;
        this.maxRetries = maxRetries;
        this.retryPause = retryPause;
        this.error = error;
      }

      Bound() {
        this(
            null,
            null,
            null,
            null,
            0,
            CHUNK_SIZE,
            CHUNK_BYTES,
            DEFAULT_RETRIES,
            DEFAULT_RETRY_PAUSE,
            defaultErrorHandler());
      }

      public Bound<T> withClusterName(String clusterName) {
        return new Bound<>(
            clusterName,
            servers,
            flushInterval,
            toDocWriteRequests,
            numOfShard,
            maxBulkRequestSize,
            maxBulkRequestBytes,
            maxRetries,
            retryPause,
            error);
      }

      public Bound<T> withServers(InetSocketAddress[] servers) {
        return new Bound<>(
            clusterName,
            servers,
            flushInterval,
            toDocWriteRequests,
            numOfShard,
            maxBulkRequestSize,
            maxBulkRequestBytes,
            maxRetries,
            retryPause,
            error);
      }

      public Bound<T> withFlushInterval(Duration flushInterval) {
        return new Bound<>(
            clusterName,
            servers,
            flushInterval,
            toDocWriteRequests,
            numOfShard,
            maxBulkRequestSize,
            maxBulkRequestBytes,
            maxRetries,
            retryPause,
            error);
      }

      public Bound<T> withFunction(
          SerializableFunction<T, Iterable<DocWriteRequest<?>>> toIndexRequest) {
        return new Bound<>(
            clusterName,
            servers,
            flushInterval,
            toIndexRequest,
            numOfShard,
            maxBulkRequestSize,
            maxBulkRequestBytes,
            maxRetries,
            retryPause,
            error);
      }

      public Bound<T> withNumOfShard(long numOfShard) {
        return new Bound<>(
            clusterName,
            servers,
            flushInterval,
            toDocWriteRequests,
            numOfShard,
            maxBulkRequestSize,
            maxBulkRequestBytes,
            maxRetries,
            retryPause,
            error);
      }

      public Bound<T> withError(ThrowingConsumer<BulkExecutionException> error) {
        return new Bound<>(
            clusterName,
            servers,
            flushInterval,
            toDocWriteRequests,
            numOfShard,
            maxBulkRequestSize,
            maxBulkRequestBytes,
            maxRetries,
            retryPause,
            error);
      }

      public Bound<T> withMaxBulkRequestSize(int maxBulkRequestSize) {
        return new Bound<>(
            clusterName,
            servers,
            flushInterval,
            toDocWriteRequests,
            numOfShard,
            maxBulkRequestSize,
            maxBulkRequestBytes,
            maxRetries,
            retryPause,
            error);
      }

      public Bound<T> withMaxBulkRequestBytes(long maxBulkRequestBytes) {
        return new Bound<>(
            clusterName,
            servers,
            flushInterval,
            toDocWriteRequests,
            numOfShard,
            maxBulkRequestSize,
            maxBulkRequestBytes,
            maxRetries,
            retryPause,
            error);
      }

      public Bound<T> withMaxRetries(int maxRetries) {
        return new Bound<>(
            clusterName,
            servers,
            flushInterval,
            toDocWriteRequests,
            numOfShard,
            maxBulkRequestSize,
            maxBulkRequestBytes,
            maxRetries,
            retryPause,
            error);
      }

      public Bound<T> withRetryPause(Duration retryPause) {
        return new Bound<>(
            clusterName,
            servers,
            flushInterval,
            toDocWriteRequests,
            numOfShard,
            maxBulkRequestSize,
            maxBulkRequestBytes,
            maxRetries,
            retryPause,
            error);
      }

      @Override
      public PDone expand(final PCollection<T> input) {
        checkNotNull(clusterName);
        checkNotNull(servers);
        checkNotNull(toDocWriteRequests);
        checkNotNull(flushInterval);
        checkArgument(numOfShard >= 0);
        checkArgument(maxBulkRequestSize > 0);
        checkArgument(maxBulkRequestBytes > 0);
        checkArgument(maxRetries >= 0);
        checkArgument(retryPause.getMillis() >= 0);
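        // numOfShard == 0 writes elements directly as they arrive; a positive value groups
        // elements into shards and flushes each shard on the flushInterval trigger.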
        if (numOfShard == 0) {
          input.apply(
              ParDo.of(
                  new ElasticsearchWriter<>(
                      clusterName,
                      servers,
                      maxBulkRequestSize,
                      maxBulkRequestBytes,
                      toDocWriteRequests,
                      error,
                      maxRetries,
                      retryPause)));
        } else {
          input
              .apply("Assign To Shard", ParDo.of(new AssignToShard<>(numOfShard)))
              .apply(
                  "Re-Window to Global Window",
                  Window.<KV<Long, T>>into(new GlobalWindows())
                      .triggering(
                          Repeatedly.forever(
                              AfterProcessingTime.pastFirstElementInPane()
                                  .plusDelayOf(flushInterval)))
                      .discardingFiredPanes()
                      .withTimestampCombiner(TimestampCombiner.END_OF_WINDOW))
              .apply(GroupByKey.create())
              .apply(
                  "Write to Elasticsearch",
                  ParDo.of(
                      new ElasticsearchShardWriter<>(
                          clusterName,
                          servers,
                          maxBulkRequestSize,
                          maxBulkRequestBytes,
                          toDocWriteRequests,
                          error,
                          maxRetries,
                          retryPause)));
        }
        return PDone.in(input.getPipeline());
      }
    }

    private static class AssignToShard<T> extends DoFn<T, KV<Long, T>> {

      private final long numOfShard;

      public AssignToShard(long numOfShard) {
        this.numOfShard = numOfShard;
      }

      @ProcessElement
      public void processElement(@Element T element, OutputReceiver<KV<Long, T>> out)
          throws Exception {
        // assign this element to a random shard
        final long shard = ThreadLocalRandom.current().nextLong(numOfShard);
        out.output(KV.of(shard, element));
      }
    }

    private static class ElasticsearchWriter<T> extends DoFn<T, Void> {

      private BulkRequest chunk;
      private long currentSize;
      private long currentBytes;

      private final ClientSupplier clientSupplier;
      private final SerializableFunction<T, Iterable<DocWriteRequest<?>>> toDocWriteRequests;
      private final ThrowingConsumer<BulkExecutionException> error;
      private final int maxBulkRequestSize;
      private final long maxBulkRequestBytes;
      private final int maxRetries;
      private final Duration retryPause;

      private ProcessFunction<BulkRequest, BulkResponse> requestFn;
      private ProcessFunction<BulkRequest, BulkResponse> retryFn;

      public ElasticsearchWriter(
          String clusterName,
          InetSocketAddress[] servers,
          int maxBulkRequestSize,
          long maxBulkRequestBytes,
          SerializableFunction<T, Iterable<DocWriteRequest<?>>> toDocWriteRequests,
          ThrowingConsumer<BulkExecutionException> error,
          int maxRetries,
          Duration retryPause) {
        this.maxBulkRequestSize = maxBulkRequestSize;
        this.maxBulkRequestBytes = maxBulkRequestBytes;
        this.clientSupplier = new ClientSupplier(clusterName, servers);
        this.toDocWriteRequests = toDocWriteRequests;
        this.error = error;
        this.maxRetries = maxRetries;
        this.retryPause = retryPause;
      }

      @Setup
      public void setup() throws Exception {
        final FluentBackoff backoffConfig =
            FluentBackoff.DEFAULT
                .withMaxRetries(this.maxRetries)
                .withInitialBackoff(this.retryPause);
        this.requestFn = request(clientSupplier, error);
        this.retryFn = retry(requestFn, backoffConfig);
      }

      @StartBundle
      public void startBundle(StartBundleContext c) {
        chunk = new BulkRequest();
        currentSize = 0;
        currentBytes = 0;
      }

      @FinishBundle
      public void finishBundle() throws Exception {
        flush();
      }

      @ProcessElement
      public void processElement(@Element T element) throws Exception {
        final Iterable<DocWriteRequest<?>> requests = toDocWriteRequests.apply(element);
        for (DocWriteRequest<?> request : requests) {
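          // Keep adding to the current bulk request while it stays under both the action-count
          // and byte-size limits; otherwise flush it and start a new request with this element.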
          long requestBytes = documentSize(request);
          if (currentSize < maxBulkRequestSize
              && (currentBytes + requestBytes) < maxBulkRequestBytes) {
            chunk.add(request);
            currentSize += 1;
            currentBytes += requestBytes;
          } else {
            flush();
            chunk = new BulkRequest().add(request);
            currentSize = 1;
            currentBytes = requestBytes;
          }
        }
      }

      private void flush() throws Exception {
        if (chunk.numberOfActions() < 1) {
          return;
        }
        try {
          requestFn.apply(chunk);
        } catch (Exception e) {
          retryFn.apply(chunk);
        }
      }
    }

    private static class ElasticsearchShardWriter<T> extends DoFn<KV<Long, Iterable<T>>, Void> {

      private final ClientSupplier clientSupplier;
      private final SerializableFunction<T, Iterable<DocWriteRequest<?>>> toDocWriteRequests;
      private final ThrowingConsumer<BulkExecutionException> error;
      private final int maxBulkRequestSize;
      private final long maxBulkRequestBytes;
      private final int maxRetries;
      private final Duration retryPause;

      private ProcessFunction<BulkRequest, BulkResponse> requestFn;
      private ProcessFunction<BulkRequest, BulkResponse> retryFn;

      public ElasticsearchShardWriter(
          String clusterName,
          InetSocketAddress[] servers,
          int maxBulkRequestSize,
          long maxBulkRequestBytes,
          SerializableFunction<T, Iterable<DocWriteRequest<?>>> toDocWriteRequests,
          ThrowingConsumer<BulkExecutionException> error,
          int maxRetries,
          Duration retryPause) {
        this.maxBulkRequestSize = maxBulkRequestSize;
        this.maxBulkRequestBytes = maxBulkRequestBytes;
        this.clientSupplier = new ClientSupplier(clusterName, servers);
        this.toDocWriteRequests = toDocWriteRequests;
        this.error = error;
        this.maxRetries = maxRetries;
        this.retryPause = retryPause;
      }

      @Setup
      public void setup() throws Exception {
        final FluentBackoff backoffConfig =
            FluentBackoff.DEFAULT
                .withMaxRetries(this.maxRetries)
                .withInitialBackoff(this.retryPause);
        this.requestFn = request(clientSupplier, error);
        this.retryFn = retry(requestFn, backoffConfig);
      }

      @SuppressWarnings("Duplicates")
      @ProcessElement
      public void processElement(@Element KV<Long, Iterable<T>> element) throws Exception {
        final Iterable<T> values = element.getValue();

        // Elasticsearch throws ActionRequestValidationException if bulk request is empty,
        // so do nothing if number of actions is zero.
        if (!values.iterator().hasNext()) {
          LOG.info("ElasticsearchWriter: no requests to send");
          return;
        }

        final Stream<DocWriteRequest<?>> docWriteRequests =
            StreamSupport.stream(values.spliterator(), false)
                .map(toDocWriteRequests::apply)
                .flatMap(ar -> StreamSupport.stream(ar.spliterator(), false));

        int currentSize = 0;
        long currentBytes = 0L;
        BulkRequest chunk = new BulkRequest();

        for (DocWriteRequest<?> request :
            (Iterable<DocWriteRequest<?>>) docWriteRequests::iterator) {
          long requestBytes = documentSize(request);
          if (currentSize < maxBulkRequestSize
              && (currentBytes + requestBytes) < maxBulkRequestBytes) {
            chunk.add(request);
            currentSize += 1;
            currentBytes += requestBytes;
          } else {
            flush(chunk);
            chunk = new BulkRequest().add(request);
            currentSize = 1;
            currentBytes = requestBytes;
          }
        }

        flush(chunk);
      }

      private void flush(BulkRequest chunk) throws Exception {
        if (chunk.numberOfActions() < 1) {
          return;
        }

        try {
          requestFn.apply(chunk);
        } catch (Exception e) {
          retryFn.apply(chunk);
        }
      }
    }

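    /**
     * Sends one bulk request and blocks for its response; item-level failures are passed to the
     * supplied error handler as a {@link BulkExecutionException}.
     */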
    private static ProcessFunction<BulkRequest, BulkResponse> request(
        final ClientSupplier clientSupplier,
        final ThrowingConsumer<BulkExecutionException> bulkErrorHandler) {
      return chunk -> {
        final BulkResponse bulkItemResponse = clientSupplier.get().bulk(chunk).get();

        if (bulkItemResponse.hasFailures()) {
          bulkErrorHandler.accept(new BulkExecutionException(bulkItemResponse));
        }

        return bulkItemResponse;
      };
    }

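    /**
     * Wraps {@code requestFn} in a retry loop driven by the given {@link FluentBackoff},
     * rethrowing the last failure once all attempts are exhausted.
     */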
    private static ProcessFunction<BulkRequest, BulkResponse> retry(
        final ProcessFunction<BulkRequest, BulkResponse> requestFn,
        final FluentBackoff backoffConfig) {
      return chunk -> {
        final BackOff backoff = backoffConfig.backoff();
        int attempt = 0;
        BulkResponse response = null;
        Exception exception = null;

        while (response == null && BackOffUtils.next(Sleeper.DEFAULT, backoff)) {
          LOG.warn(String.format(RETRY_ATTEMPT_LOG, ++attempt));
          try {
            response = requestFn.apply(chunk);
            exception = null;
          } catch (Exception e) {
            exception = e;
          }
        }

        if (exception != null) {
          throw new Exception(String.format(RETRY_FAILED_LOG, attempt), exception);
        }

        return response;
      };
    }

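    /**
     * Serializable, lazily-initialized supplier of a {@link TransportClient}, so a single client
     * instance is created per writer and shared across bundles.
     */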
    private static class ClientSupplier implements Supplier<Client>, Serializable {

      private final AtomicReference<Client> CLIENT = new AtomicReference<>();
      private final String clusterName;
      private final InetSocketAddress[] addresses;

      public ClientSupplier(final String clusterName, final InetSocketAddress[] addresses) {
        this.clusterName = clusterName;
        this.addresses = addresses;
      }

      @Override
      public Client get() {
        if (CLIENT.get() == null) {
          synchronized (CLIENT) {
            if (CLIENT.get() == null) {
              CLIENT.set(create(clusterName, addresses));
            }
          }
        }
        return CLIENT.get();
      }

      private TransportClient create(String clusterName, InetSocketAddress[] addresses) {
        final Settings settings = Settings.builder().put("cluster.name", clusterName).build();

        TransportAddress[] transportAddresses =
            Arrays.stream(addresses).map(TransportAddress::new).toArray(TransportAddress[]::new);

        return new PreBuiltTransportClient(settings).addTransportAddresses(transportAddresses);
      }
    }

    private static ThrowingConsumer<BulkExecutionException> defaultErrorHandler() {
      return throwable -> {
        throw throwable;
      };
    }

    /** An exception that carries information about the failures in a bulk execution. */
    public static class BulkExecutionException extends IOException {

      private final Iterable<Throwable> failures;

      BulkExecutionException(BulkResponse bulkResponse) {
        super(bulkResponse.buildFailureMessage());
        this.failures =
            Arrays.stream(bulkResponse.getItems())
                .map(BulkItemResponse::getFailure)
                .filter(Objects::nonNull)
                .map(BulkItemResponse.Failure::getCause)
                .collect(Collectors.toList());
      }

      public Iterable<Throwable> getFailures() {
        return failures;
      }
    }
  }

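  /** Estimates the payload size of a single request in bytes; delete requests count as zero. */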
  private static long documentSize(DocWriteRequest<?> request) {
    if (request instanceof IndexRequest) {
      return ((IndexRequest) request).source().length();
    } else if (request instanceof UpdateRequest) {
      return ((UpdateRequest) request).doc().source().length();
    } else if (request instanceof DeleteRequest) {
      return 0;
    }
    throw new IllegalArgumentException("Encountered unknown subclass of DocWriteRequest");
  }
}



