org.apache.beam.sdk.io.elasticsearch.ElasticsearchIO Maven / Gradle / Ivy
Scio add-on for writing to Elasticsearch
/*
* Copyright 2017 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.beam.sdk.io.elasticsearch;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.collect.Iterables;
import com.twitter.jsr166e.ThreadLocalRandom;
import java.io.IOException;
import java.io.Serializable;
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Repeatedly;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.InetSocketTransportAddress;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
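/**
 * Transforms for writing to Elasticsearch.
 *
 * <p>A minimal usage sketch. The {@code Document} element type, cluster name, server
 * address and request function below are hypothetical, not part of this module:
 *
 * <pre>{@code
 * PCollection<Document> docs = ...;
 * docs.apply(ElasticsearchIO.Write
 *     .<Document>withClusterName("my-cluster")
 *     .withServers(new InetSocketAddress[] {new InetSocketAddress("localhost", 9300)})
 *     .withFunction(toActionRequests) // Document -> Iterable<ActionRequest<?>>
 *     .withNumOfShard(4L));
 * }</pre>
 */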
public class ElasticsearchIO {
public static class Write {
/**
 * Returns a transform for writing to the Elasticsearch cluster with the given name.
 *
 * @param clusterName name of the Elasticsearch cluster
 */
public static <T> Bound<T> withClusterName(String clusterName) {
return new Bound<T>().withClusterName(clusterName);
}
/**
 * Returns a transform for writing to the Elasticsearch cluster at the given servers.
 *
 * @param servers endpoints for the Elasticsearch cluster
 */
public static <T> Bound<T> withServers(InetSocketAddress[] servers) {
return new Bound<T>().withServers(servers);
}
/**
 * Returns a transform for writing to the Elasticsearch cluster, buffering elements
 * and flushing them after the given interval.
 *
 * @param flushInterval delay applied while buffering elements. Defaults to 1 second.
 */
public static <T> Bound<T> withFlushInterval(Duration flushInterval) {
return new Bound<T>().withFlushInterval(flushInterval);
}
/**
 * Returns a transform for writing to the Elasticsearch cluster.
 *
 * @param function creates the ActionRequests (e.g. IndexRequest) to be sent by
 * the Elasticsearch client
 */
public static <T> Bound<T> withFunction(
SerializableFunction<T, Iterable<ActionRequest<?>>> function) {
return new Bound<T>().withFunction(function);
}
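// For illustration only: a hypothetical request function mapping a String document
// to a single index request ("my-index" and "my-type" are made-up names):
//
//   SerializableFunction<String, Iterable<ActionRequest<?>>> toRequests =
//       doc -> Collections.singletonList(
//           new IndexRequest("my-index", "my-type").source(doc));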
/**
 * Returns a transform for writing to the Elasticsearch cluster.
 * Note: it is recommended to set this to the number of workers in your pipeline.
 *
 * @param numOfShard number of shards used to batch elements for bulk writes
 * to Elasticsearch.
 */
public static <T> Bound<T> withNumOfShard(Long numOfShard) {
return new Bound<T>().withNumOfShard(numOfShard);
}
/**
 * Returns a transform for writing to the Elasticsearch cluster.
 *
 * @param error handler invoked when a bulk write to Elasticsearch fails.
 * The default behavior rethrows the error as an IOException.
 */
public static <T> Bound<T> withError(ThrowingConsumer<BulkExecutionException> error) {
return new Bound<T>().withError(error);
}
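// For illustration only: a hypothetical error handler that logs each bulk failure
// instead of failing the bundle (the logger is assumed to exist):
//
//   ThrowingConsumer<BulkExecutionException> logAndContinue =
//       e -> e.getFailures().forEach(t -> logger.warn("bulk write failed", t));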
public static class Bound<T> extends PTransform<PCollection<T>, PDone> {
private final String clusterName;
private final InetSocketAddress[] servers;
private final Duration flushInterval;
private final SerializableFunction<T, Iterable<ActionRequest<?>>> toActionRequests;
private final Long numOfShard;
private final ThrowingConsumer<BulkExecutionException> error;
private Bound(final String clusterName,
final InetSocketAddress[] servers,
final Duration flushInterval,
final SerializableFunction<T, Iterable<ActionRequest<?>>> toActionRequests,
final Long numOfShard,
final ThrowingConsumer<BulkExecutionException> error) {
this.clusterName = clusterName;
this.servers = servers;
this.flushInterval = flushInterval;
this.toActionRequests = toActionRequests;
this.numOfShard = numOfShard;
this.error = error;
}
Bound() {
this(null, null, null, null, null, defaultErrorHandler());
}
public Bound<T> withClusterName(String clusterName) {
return new Bound<>(clusterName, servers, flushInterval, toActionRequests, numOfShard, error);
}
public Bound<T> withServers(InetSocketAddress[] servers) {
return new Bound<>(clusterName, servers, flushInterval, toActionRequests, numOfShard, error);
}
public Bound<T> withFlushInterval(Duration flushInterval) {
return new Bound<>(clusterName, servers, flushInterval, toActionRequests, numOfShard, error);
}
public Bound<T> withFunction(SerializableFunction<T, Iterable<ActionRequest<?>>> toIndexRequest) {
return new Bound<>(clusterName, servers, flushInterval, toIndexRequest, numOfShard, error);
}
public Bound<T> withNumOfShard(Long numOfShard) {
return new Bound<>(clusterName, servers, flushInterval, toActionRequests, numOfShard, error);
}
public Bound<T> withError(ThrowingConsumer<BulkExecutionException> error) {
return new Bound<>(clusterName, servers, flushInterval, toActionRequests, numOfShard, error);
}
@Override
public PDone expand(final PCollection<T> input) {
checkNotNull(clusterName);
checkNotNull(servers);
checkNotNull(toActionRequests);
checkNotNull(numOfShard);
checkNotNull(flushInterval);
input
// assign each element to one of numOfShard random keys
.apply("Assign To Shard", ParDo.of(new AssignToShard<>(numOfShard)))
// re-window into the global window, firing a pane per shard every flushInterval
.apply("Re-Window to Global Window", Window.<KV<Long, T>>into(new GlobalWindows())
.triggering(Repeatedly.forever(
AfterProcessingTime
.pastFirstElementInPane()
.plusDelayOf(flushInterval)))
.discardingFiredPanes())
// collect each shard's buffered elements into a single batch
.apply(GroupByKey.create())
// convert each batch to ActionRequests and bulk-write them
.apply("Write to Elasticsearch",
ParDo.of(new ElasticsearchWriter<>(clusterName, servers, toActionRequests, error)));
return PDone.in(input.getPipeline());
}
}
private static class AssignToShard<T> extends DoFn<T, KV<Long, T>> {
private final Long numOfShard;
public AssignToShard(Long numOfShard) {
this.numOfShard = numOfShard;
}
@ProcessElement
public void processElement(ProcessContext c) throws Exception {
// assign this element to a random shard
final long shard = ThreadLocalRandom.current().nextLong(numOfShard);
c.output(KV.of(shard, c.element()));
}
}
private static class ElasticsearchWriter<T> extends DoFn<KV<Long, Iterable<T>>, Void> {
private final Logger LOG = LoggerFactory.getLogger(ElasticsearchWriter.class);
private final ClientSupplier clientSupplier;
private final SerializableFunction<T, Iterable<ActionRequest<?>>> toActionRequests;
private final ThrowingConsumer<BulkExecutionException> error;
public ElasticsearchWriter(String clusterName,
InetSocketAddress[] servers,
SerializableFunction<T, Iterable<ActionRequest<?>>> toActionRequests,
ThrowingConsumer<BulkExecutionException> error) {
this.clientSupplier = new ClientSupplier(clusterName, servers);
this.toActionRequests = toActionRequests;
this.error = error;
}
@ProcessElement
public void processElement(ProcessContext c) throws Exception {
// map each element in this batch to its ActionRequests
final List<Iterable<ActionRequest<?>>> actionRequests =
StreamSupport.stream(c.element().getValue().spliterator(), false)
.map(toActionRequests::apply)
.collect(Collectors.toList());
// Elasticsearch throws ActionRequestValidationException if bulk request is empty,
// so do nothing if number of actions is zero.
if (actionRequests.isEmpty()) {
LOG.info("ElasticsearchWriter: no requests to send");
return;
}
final BulkRequest bulkRequest = new BulkRequest().add(Iterables.concat(actionRequests));
final BulkResponse bulkResponse = clientSupplier.get().bulk(bulkRequest).get();
if (bulkResponse.hasFailures()) {
error.accept(new BulkExecutionException(bulkResponse));
}
}
}
private static class ClientSupplier implements Supplier<Client>, Serializable {
private final AtomicReference<Client> CLIENT = new AtomicReference<>();
private final String clusterName;
private final InetSocketAddress[] addresses;
public ClientSupplier(final String clusterName, final InetSocketAddress[] addresses) {
this.clusterName = clusterName;
this.addresses = addresses;
}
@Override
public Client get() {
// lazily create and cache a single client (double-checked locking)
if (CLIENT.get() == null) {
synchronized (CLIENT) {
if (CLIENT.get() == null) {
CLIENT.set(create(clusterName, addresses));
}
}
}
return CLIENT.get();
}
private TransportClient create(String clusterName, InetSocketAddress[] addresses) {
final Settings settings = Settings.settingsBuilder()
.put("cluster.name", clusterName)
.build();
InetSocketTransportAddress[] transportAddresses = Arrays.stream(addresses)
.map(InetSocketTransportAddress::new)
.toArray(InetSocketTransportAddress[]::new);
return TransportClient.builder()
.settings(settings)
.build()
.addTransportAddresses(transportAddresses);
}
}
private static ThrowingConsumer<BulkExecutionException> defaultErrorHandler() {
return throwable -> {
throw throwable;
};
}
/**
 * An exception that carries information about the failures in a bulk execution.
 */
public static class BulkExecutionException extends IOException {
private final Iterable<Throwable> failures;
BulkExecutionException(BulkResponse bulkResponse) {
super(bulkResponse.buildFailureMessage());
this.failures = Arrays.stream(bulkResponse.getItems())
.map(BulkItemResponse::getFailure)
.filter(Objects::nonNull)
.map(BulkItemResponse.Failure::getCause)
.collect(Collectors.toList());
}
public Iterable<Throwable> getFailures() {
return failures;
}
}
}
}