io.weaviate.client.v1.batch.api.ReferencesBatcher Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of client Show documentation
A Java client for Weaviate Vector Search Engine
// Generated by delombok at Mon Jan 08 10:19:26 UTC 2024
package io.weaviate.client.v1.batch.api;
import io.weaviate.client.v1.batch.model.BatchReference;
import io.weaviate.client.v1.batch.model.BatchReferenceResponse;
import io.weaviate.client.v1.batch.util.ReferencesPath;
import org.apache.commons.lang3.ObjectUtils;
import io.weaviate.client.Config;
import io.weaviate.client.base.BaseClient;
import io.weaviate.client.base.ClientResult;
import io.weaviate.client.base.Response;
import io.weaviate.client.base.Result;
import io.weaviate.client.base.WeaviateErrorMessage;
import io.weaviate.client.base.WeaviateErrorResponse;
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.util.Assert;
import java.io.Closeable;
import java.net.ConnectException;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
public class ReferencesBatcher extends BaseClient implements ClientResult, Closeable {
private final ReferencesPath referencesPath;
private final BatchRetriesConfig batchRetriesConfig;
private final AutoBatchConfig autoBatchConfig;
private final boolean autoRunEnabled;
private final ScheduledExecutorService executorService;
private final DelayedExecutor> delayedExecutor;
private final List references;
private String consistencyLevel;
private final List>> undoneFutures;
private ReferencesBatcher(HttpClient httpClient, Config config, ReferencesPath referencesPath, BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) {
super(httpClient, config);
this.referencesPath = referencesPath;
this.references = new ArrayList<>();
this.batchRetriesConfig = batchRetriesConfig;
if (autoBatchConfig != null) {
this.autoRunEnabled = true;
this.autoBatchConfig = autoBatchConfig;
this.executorService = Executors.newScheduledThreadPool(autoBatchConfig.poolSize);
this.delayedExecutor = new ExecutorServiceDelayedExecutor(executorService);
this.undoneFutures = Collections.synchronizedList(new ArrayList<>());
} else {
this.autoRunEnabled = false;
this.autoBatchConfig = null;
this.executorService = null;
this.delayedExecutor = new SleepDelayedExecutor();
this.undoneFutures = null;
}
}
public static ReferencesBatcher create(HttpClient httpClient, Config config, ReferencesPath referencesPath, BatchRetriesConfig batchRetriesConfig) {
Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig");
return new ReferencesBatcher(httpClient, config, referencesPath, batchRetriesConfig, null);
}
public static ReferencesBatcher createAuto(HttpClient httpClient, Config config, ReferencesPath referencesPath, BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) {
Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig");
Assert.requiredNotNull(autoBatchConfig, "autoBatchConfig");
return new ReferencesBatcher(httpClient, config, referencesPath, batchRetriesConfig, autoBatchConfig);
}
public ReferencesBatcher withReference(BatchReference reference) {
return withReferences(reference);
}
public ReferencesBatcher withReferences(BatchReference... references) {
this.references.addAll(Arrays.asList(references));
autoRun();
return this;
}
public ReferencesBatcher withConsistencyLevel(String consistencyLevel) {
this.consistencyLevel = consistencyLevel;
return this;
}
@Override
public Result run() {
if (autoRunEnabled) {
flush(); // fallback to flush in auto run enabled
return null;
}
if (references.isEmpty()) {
return new Result<>(0, new BatchReferenceResponse[0], null);
}
List batch = extractBatch(references.size());
return runRecursively(batch, 0, 0, (DelayedExecutor>) delayedExecutor);
}
public void flush() {
if (!autoRunEnabled) {
run(); // fallback to run if auto run disabled
return;
}
if (!references.isEmpty()) {
List batch = extractBatch(references.size());
runInThread(batch);
}
CompletableFuture>[] futures = undoneFutures.toArray(new CompletableFuture[0]);
if (futures.length == 0) {
return;
}
CompletableFuture.allOf(futures).join();
}
@Override
public void close() {
if (!autoRunEnabled) {
return;
}
executorService.shutdown();
try {
if (!executorService.awaitTermination(autoBatchConfig.awaitTerminationMs, TimeUnit.MILLISECONDS)) {
executorService.shutdownNow();
}
} catch (InterruptedException e) {
executorService.shutdownNow();
}
}
private List extractBatch(int batchSize) {
List batch = new ArrayList<>(batchSize);
List sublist = references.subList(0, batchSize);
batch.addAll(sublist);
sublist.clear();
return batch;
}
private void autoRun() {
if (!autoRunEnabled) {
return;
}
while (references.size() >= autoBatchConfig.batchSize) {
List batch = extractBatch(autoBatchConfig.batchSize);
runInThread(batch);
}
}
private void runInThread(List batch) {
CompletableFuture> future = CompletableFuture.supplyAsync(() -> createRunFuture(batch), executorService).thenCompose(f -> f);
if (autoBatchConfig.callback != null) {
future = future.whenComplete((result, e) -> autoBatchConfig.callback.accept(result));
}
CompletableFuture> undoneFuture = future;
undoneFutures.add(undoneFuture);
undoneFuture.whenComplete((result, ex) -> undoneFutures.remove(undoneFuture));
}
private CompletableFuture> createRunFuture(List batch) {
return runRecursively(batch, 0, 0, (DelayedExecutor>>) delayedExecutor);
}
private T runRecursively(List batch, int connectionErrorCount, int timeoutErrorCount, DelayedExecutor delayedExecutor) {
Result result = internalRun(batch);
if (result.hasErrors()) {
List messages = result.getError().getMessages();
if (!messages.isEmpty()) {
Throwable throwable = messages.get(0).getThrowable();
boolean executeAgain = false;
int delay = 0;
if (throwable instanceof ConnectException) {
if (connectionErrorCount++ < batchRetriesConfig.maxConnectionRetries) {
executeAgain = true;
delay = connectionErrorCount * batchRetriesConfig.retriesIntervalMs;
}
} else if (throwable instanceof SocketTimeoutException) {
if (timeoutErrorCount++ < batchRetriesConfig.maxTimeoutRetries) {
executeAgain = true;
delay = timeoutErrorCount * batchRetriesConfig.retriesIntervalMs;
}
}
if (executeAgain) {
int lambdaConnectionErrorCount = connectionErrorCount;
int lambdaTimeoutErrorCount = timeoutErrorCount;
List lambdaBatch = batch;
return delayedExecutor.delayed(delay, () -> runRecursively(lambdaBatch, lambdaConnectionErrorCount, lambdaTimeoutErrorCount, delayedExecutor));
}
}
} else {
batch = null;
}
Result finalResult = createFinalResultFromLastResult(result, batch);
return delayedExecutor.now(finalResult);
}
private Result internalRun(List batch) {
BatchReference[] payload = batch.toArray(new BatchReference[0]);
String path = referencesPath.buildCreate(ReferencesPath.Params.builder().consistencyLevel(consistencyLevel).build());
Response resp = sendPostRequest(path, payload, BatchReferenceResponse[].class);
return new Result<>(resp);
}
private Result createFinalResultFromLastResult(Result lastResult, List failedBatch) {
if (ObjectUtils.isEmpty(failedBatch)) {
return lastResult;
}
String failedRefs = failedBatch.stream().map(ref -> ref.getFrom() + " => " + ref.getTo()).collect(Collectors.joining(", "));
WeaviateErrorMessage failedRefsMessage = WeaviateErrorMessage.builder().message("Failed refs: " + failedRefs).build();
List messages;
int statusCode = 0;
if (lastResult.hasErrors()) {
statusCode = lastResult.getError().getStatusCode();
List prevMessages = lastResult.getError().getMessages();
messages = new ArrayList<>(prevMessages.size() + 1);
messages.addAll(prevMessages);
messages.add(failedRefsMessage);
} else {
messages = Collections.singletonList(failedRefsMessage);
}
return new Result<>(statusCode, null, WeaviateErrorResponse.builder().error(messages).code(statusCode).build());
}
private interface DelayedExecutor {
T delayed(int delay, Supplier supplier);
T now(Result result);
}
private static class ExecutorServiceDelayedExecutor implements DelayedExecutor>> {
private final ScheduledExecutorService executorService;
@Override
public CompletableFuture> delayed(int delay, Supplier>> supplier) {
Executor executor = runnable -> executorService.schedule(runnable, delay, TimeUnit.MILLISECONDS);
return CompletableFuture.supplyAsync(supplier, executor).thenCompose(f -> f);
}
@Override
public CompletableFuture> now(Result result) {
return CompletableFuture.completedFuture(result);
}
@java.lang.SuppressWarnings("all")
public ExecutorServiceDelayedExecutor(final ScheduledExecutorService executorService) {
this.executorService = executorService;
}
}
private static class SleepDelayedExecutor implements DelayedExecutor> {
@Override
public Result delayed(int delay, Supplier> supplier) {
try {
Thread.sleep(delay);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
return supplier.get();
}
@Override
public Result now(Result result) {
return result;
}
}
public static class BatchRetriesConfig {
public static final int MAX_TIMEOUT_RETRIES = 3;
public static final int MAX_CONNECTION_RETRIES = 3;
public static final int RETRIES_INTERVAL = 2000;
private final int maxTimeoutRetries;
private final int maxConnectionRetries;
private final int retriesIntervalMs;
private BatchRetriesConfig(int maxTimeoutRetries, int maxConnectionRetries, int retriesIntervalMs) {
Assert.requireGreaterEqual(maxTimeoutRetries, 0, "maxTimeoutRetries");
Assert.requireGreaterEqual(maxConnectionRetries, 0, "maxConnectionRetries");
Assert.requireGreater(retriesIntervalMs, 0, "retriesIntervalMs");
this.maxTimeoutRetries = maxTimeoutRetries;
this.maxConnectionRetries = maxConnectionRetries;
this.retriesIntervalMs = retriesIntervalMs;
}
public static BatchRetriesConfigBuilder defaultConfig() {
return BatchRetriesConfig.builder().maxTimeoutRetries(MAX_TIMEOUT_RETRIES).maxConnectionRetries(MAX_CONNECTION_RETRIES).retriesIntervalMs(RETRIES_INTERVAL);
}
@java.lang.SuppressWarnings("all")
public static class BatchRetriesConfigBuilder {
@java.lang.SuppressWarnings("all")
private int maxTimeoutRetries;
@java.lang.SuppressWarnings("all")
private int maxConnectionRetries;
@java.lang.SuppressWarnings("all")
private int retriesIntervalMs;
@java.lang.SuppressWarnings("all")
BatchRetriesConfigBuilder() {
}
/**
* @return {@code this}.
*/
@java.lang.SuppressWarnings("all")
public ReferencesBatcher.BatchRetriesConfig.BatchRetriesConfigBuilder maxTimeoutRetries(final int maxTimeoutRetries) {
this.maxTimeoutRetries = maxTimeoutRetries;
return this;
}
/**
* @return {@code this}.
*/
@java.lang.SuppressWarnings("all")
public ReferencesBatcher.BatchRetriesConfig.BatchRetriesConfigBuilder maxConnectionRetries(final int maxConnectionRetries) {
this.maxConnectionRetries = maxConnectionRetries;
return this;
}
/**
* @return {@code this}.
*/
@java.lang.SuppressWarnings("all")
public ReferencesBatcher.BatchRetriesConfig.BatchRetriesConfigBuilder retriesIntervalMs(final int retriesIntervalMs) {
this.retriesIntervalMs = retriesIntervalMs;
return this;
}
@java.lang.SuppressWarnings("all")
public ReferencesBatcher.BatchRetriesConfig build() {
return new ReferencesBatcher.BatchRetriesConfig(this.maxTimeoutRetries, this.maxConnectionRetries, this.retriesIntervalMs);
}
@java.lang.Override
@java.lang.SuppressWarnings("all")
public java.lang.String toString() {
return "ReferencesBatcher.BatchRetriesConfig.BatchRetriesConfigBuilder(maxTimeoutRetries=" + this.maxTimeoutRetries + ", maxConnectionRetries=" + this.maxConnectionRetries + ", retriesIntervalMs=" + this.retriesIntervalMs + ")";
}
}
@java.lang.SuppressWarnings("all")
public static ReferencesBatcher.BatchRetriesConfig.BatchRetriesConfigBuilder builder() {
return new ReferencesBatcher.BatchRetriesConfig.BatchRetriesConfigBuilder();
}
@java.lang.SuppressWarnings("all")
public int getMaxTimeoutRetries() {
return this.maxTimeoutRetries;
}
@java.lang.SuppressWarnings("all")
public int getMaxConnectionRetries() {
return this.maxConnectionRetries;
}
@java.lang.SuppressWarnings("all")
public int getRetriesIntervalMs() {
return this.retriesIntervalMs;
}
@java.lang.Override
@java.lang.SuppressWarnings("all")
public java.lang.String toString() {
return "ReferencesBatcher.BatchRetriesConfig(maxTimeoutRetries=" + this.getMaxTimeoutRetries() + ", maxConnectionRetries=" + this.getMaxConnectionRetries() + ", retriesIntervalMs=" + this.getRetriesIntervalMs() + ")";
}
@java.lang.Override
@java.lang.SuppressWarnings("all")
public boolean equals(final java.lang.Object o) {
if (o == this) return true;
if (!(o instanceof ReferencesBatcher.BatchRetriesConfig)) return false;
final ReferencesBatcher.BatchRetriesConfig other = (ReferencesBatcher.BatchRetriesConfig) o;
if (!other.canEqual((java.lang.Object) this)) return false;
if (this.getMaxTimeoutRetries() != other.getMaxTimeoutRetries()) return false;
if (this.getMaxConnectionRetries() != other.getMaxConnectionRetries()) return false;
if (this.getRetriesIntervalMs() != other.getRetriesIntervalMs()) return false;
return true;
}
@java.lang.SuppressWarnings("all")
protected boolean canEqual(final java.lang.Object other) {
return other instanceof ReferencesBatcher.BatchRetriesConfig;
}
@java.lang.Override
@java.lang.SuppressWarnings("all")
public int hashCode() {
final int PRIME = 59;
int result = 1;
result = result * PRIME + this.getMaxTimeoutRetries();
result = result * PRIME + this.getMaxConnectionRetries();
result = result * PRIME + this.getRetriesIntervalMs();
return result;
}
}
public static class AutoBatchConfig {
public static final int BATCH_SIZE = 100;
public static final int POOL_SIZE = 1;
public static final int AWAIT_TERMINATION_MS = 10000;
private final int batchSize;
private final int poolSize;
private final int awaitTerminationMs;
private final Consumer> callback;
private AutoBatchConfig(int batchSize, int poolSize, int awaitTerminationMs, Consumer> callback) {
Assert.requireGreaterEqual(batchSize, 1, "batchSize");
Assert.requireGreaterEqual(poolSize, 1, "corePoolSize");
Assert.requireGreater(awaitTerminationMs, 0, "awaitTerminationMs");
this.batchSize = batchSize;
this.poolSize = poolSize;
this.awaitTerminationMs = awaitTerminationMs;
this.callback = callback;
}
public static AutoBatchConfigBuilder defaultConfig() {
return AutoBatchConfig.builder().batchSize(BATCH_SIZE).poolSize(POOL_SIZE).awaitTerminationMs(AWAIT_TERMINATION_MS).callback(null);
}
@java.lang.SuppressWarnings("all")
public static class AutoBatchConfigBuilder {
@java.lang.SuppressWarnings("all")
private int batchSize;
@java.lang.SuppressWarnings("all")
private int poolSize;
@java.lang.SuppressWarnings("all")
private int awaitTerminationMs;
@java.lang.SuppressWarnings("all")
private Consumer> callback;
@java.lang.SuppressWarnings("all")
AutoBatchConfigBuilder() {
}
/**
* @return {@code this}.
*/
@java.lang.SuppressWarnings("all")
public ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder batchSize(final int batchSize) {
this.batchSize = batchSize;
return this;
}
/**
* @return {@code this}.
*/
@java.lang.SuppressWarnings("all")
public ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder poolSize(final int poolSize) {
this.poolSize = poolSize;
return this;
}
/**
* @return {@code this}.
*/
@java.lang.SuppressWarnings("all")
public ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder awaitTerminationMs(final int awaitTerminationMs) {
this.awaitTerminationMs = awaitTerminationMs;
return this;
}
/**
* @return {@code this}.
*/
@java.lang.SuppressWarnings("all")
public ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder callback(final Consumer> callback) {
this.callback = callback;
return this;
}
@java.lang.SuppressWarnings("all")
public ReferencesBatcher.AutoBatchConfig build() {
return new ReferencesBatcher.AutoBatchConfig(this.batchSize, this.poolSize, this.awaitTerminationMs, this.callback);
}
@java.lang.Override
@java.lang.SuppressWarnings("all")
public java.lang.String toString() {
return "ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder(batchSize=" + this.batchSize + ", poolSize=" + this.poolSize + ", awaitTerminationMs=" + this.awaitTerminationMs + ", callback=" + this.callback + ")";
}
}
@java.lang.SuppressWarnings("all")
public static ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder builder() {
return new ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder();
}
@java.lang.SuppressWarnings("all")
public int getBatchSize() {
return this.batchSize;
}
@java.lang.SuppressWarnings("all")
public int getPoolSize() {
return this.poolSize;
}
@java.lang.SuppressWarnings("all")
public int getAwaitTerminationMs() {
return this.awaitTerminationMs;
}
@java.lang.SuppressWarnings("all")
public Consumer> getCallback() {
return this.callback;
}
@java.lang.Override
@java.lang.SuppressWarnings("all")
public java.lang.String toString() {
return "ReferencesBatcher.AutoBatchConfig(batchSize=" + this.getBatchSize() + ", poolSize=" + this.getPoolSize() + ", awaitTerminationMs=" + this.getAwaitTerminationMs() + ", callback=" + this.getCallback() + ")";
}
@java.lang.Override
@java.lang.SuppressWarnings("all")
public boolean equals(final java.lang.Object o) {
if (o == this) return true;
if (!(o instanceof ReferencesBatcher.AutoBatchConfig)) return false;
final ReferencesBatcher.AutoBatchConfig other = (ReferencesBatcher.AutoBatchConfig) o;
if (!other.canEqual((java.lang.Object) this)) return false;
if (this.getBatchSize() != other.getBatchSize()) return false;
if (this.getPoolSize() != other.getPoolSize()) return false;
if (this.getAwaitTerminationMs() != other.getAwaitTerminationMs()) return false;
final java.lang.Object this$callback = this.getCallback();
final java.lang.Object other$callback = other.getCallback();
if (this$callback == null ? other$callback != null : !this$callback.equals(other$callback)) return false;
return true;
}
@java.lang.SuppressWarnings("all")
protected boolean canEqual(final java.lang.Object other) {
return other instanceof ReferencesBatcher.AutoBatchConfig;
}
@java.lang.Override
@java.lang.SuppressWarnings("all")
public int hashCode() {
final int PRIME = 59;
int result = 1;
result = result * PRIME + this.getBatchSize();
result = result * PRIME + this.getPoolSize();
result = result * PRIME + this.getAwaitTerminationMs();
final java.lang.Object $callback = this.getCallback();
result = result * PRIME + ($callback == null ? 43 : $callback.hashCode());
return result;
}
}
}