All Downloads are FREE. Search and download functionality uses the official Maven repository.

io.weaviate.client.v1.batch.api.ReferencesBatcher Maven / Gradle / Ivy

There is a newer version: 4.9.0
Show newest version
// Generated by delombok at Mon Jan 08 10:19:26 UTC 2024
package io.weaviate.client.v1.batch.api;

import io.weaviate.client.v1.batch.model.BatchReference;
import io.weaviate.client.v1.batch.model.BatchReferenceResponse;
import io.weaviate.client.v1.batch.util.ReferencesPath;
import org.apache.commons.lang3.ObjectUtils;
import io.weaviate.client.Config;
import io.weaviate.client.base.BaseClient;
import io.weaviate.client.base.ClientResult;
import io.weaviate.client.base.Response;
import io.weaviate.client.base.Result;
import io.weaviate.client.base.WeaviateErrorMessage;
import io.weaviate.client.base.WeaviateErrorResponse;
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.util.Assert;
import java.io.Closeable;
import java.net.ConnectException;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public class ReferencesBatcher extends BaseClient implements ClientResult, Closeable {
  private final ReferencesPath referencesPath;
  private final BatchRetriesConfig batchRetriesConfig;
  private final AutoBatchConfig autoBatchConfig;
  private final boolean autoRunEnabled;
  private final ScheduledExecutorService executorService;
  private final DelayedExecutor delayedExecutor;
  private final List references;
  private String consistencyLevel;
  private final List>> undoneFutures;

  private ReferencesBatcher(HttpClient httpClient, Config config, ReferencesPath referencesPath, BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) {
    super(httpClient, config);
    this.referencesPath = referencesPath;
    this.references = new ArrayList<>();
    this.batchRetriesConfig = batchRetriesConfig;
    if (autoBatchConfig != null) {
      this.autoRunEnabled = true;
      this.autoBatchConfig = autoBatchConfig;
      this.executorService = Executors.newScheduledThreadPool(autoBatchConfig.poolSize);
      this.delayedExecutor = new ExecutorServiceDelayedExecutor(executorService);
      this.undoneFutures = Collections.synchronizedList(new ArrayList<>());
    } else {
      this.autoRunEnabled = false;
      this.autoBatchConfig = null;
      this.executorService = null;
      this.delayedExecutor = new SleepDelayedExecutor();
      this.undoneFutures = null;
    }
  }

  /**
   * Creates a batcher in manual mode: references are only sent when {@link #run()} (or
   * {@link #flush()}, which delegates to it) is called.
   *
   * @param httpClient         HTTP client passed to the base client
   * @param config             client configuration passed to the base client
   * @param referencesPath     builder for request paths
   * @param batchRetriesConfig retry policy; checked with {@code Assert.requiredNotNull}
   * @return a new batcher with auto run disabled
   */
  public static ReferencesBatcher create(HttpClient httpClient, Config config, ReferencesPath referencesPath, BatchRetriesConfig batchRetriesConfig) {
    Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig");
    // A null auto-batch config is what switches the constructor into manual mode.
    final AutoBatchConfig autoBatchDisabled = null;
    return new ReferencesBatcher(httpClient, config, referencesPath, batchRetriesConfig, autoBatchDisabled);
  }

  /**
   * Creates a batcher in auto-run mode: batches are dispatched to a background pool as soon as
   * enough references have been queued.
   *
   * @param httpClient         HTTP client passed to the base client
   * @param config             client configuration passed to the base client
   * @param referencesPath     builder for request paths
   * @param batchRetriesConfig retry policy; checked with {@code Assert.requiredNotNull}
   * @param autoBatchConfig    auto-run settings; checked with {@code Assert.requiredNotNull}
   * @return a new batcher with auto run enabled
   */
  public static ReferencesBatcher createAuto(HttpClient httpClient, Config config, ReferencesPath referencesPath, BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) {
    // Validation order is retries first, then auto-batch settings.
    Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig");
    Assert.requiredNotNull(autoBatchConfig, "autoBatchConfig");
    final ReferencesBatcher autoBatcher =
      new ReferencesBatcher(httpClient, config, referencesPath, batchRetriesConfig, autoBatchConfig);
    return autoBatcher;
  }

  /**
   * Queues a single reference; shorthand for {@link #withReferences(BatchReference...)} with one element.
   *
   * @param reference the reference to queue
   * @return this batcher, for chaining
   */
  public ReferencesBatcher withReference(BatchReference reference) {
    final BatchReference[] single = {reference};
    return withReferences(single);
  }

  /**
   * Appends the given references to the queue, then triggers the auto-run check (a no-op when auto
   * run is disabled).
   *
   * @param references references to queue, in order
   * @return this batcher, for chaining
   */
  public ReferencesBatcher withReferences(BatchReference... references) {
    for (BatchReference queued : references) {
      this.references.add(queued);
    }
    autoRun();
    return this;
  }

  /**
   * Stores the consistency level that is handed to the path builder for subsequent requests.
   *
   * @param consistencyLevel value forwarded as-is; not validated here
   * @return this batcher, for chaining
   */
  public ReferencesBatcher withConsistencyLevel(String consistencyLevel) {
    this.consistencyLevel = consistencyLevel;
    return this;
  }

  @Override
  public Result run() {
    if (autoRunEnabled) {
      flush(); // fallback to flush in auto run enabled
      return null;
    }
    if (references.isEmpty()) {
      return new Result<>(0, new BatchReferenceResponse[0], null);
    }
    List batch = extractBatch(references.size());
    return runRecursively(batch, 0, 0, (DelayedExecutor>) delayedExecutor);
  }

  /**
   * Sends whatever is still queued and waits for all in-flight auto-run batches to finish.
   *
   * <p>When auto run is disabled this simply delegates to {@link #run()} and discards its result.
   * In auto-run mode the remaining references (possibly fewer than the configured batch size) are
   * dispatched as one batch, after which the calling thread blocks until every tracked future has
   * completed. {@code join()} rethrows if any tracked future completed exceptionally.
   */
  public void flush() {
    if (!autoRunEnabled) {
      run(); // fallback to run if auto run disabled
      return;
    }
    if (!references.isEmpty()) {
      List<BatchReference> batch = extractBatch(references.size());
      runInThread(batch);
    }
    // Snapshot first: completed futures remove themselves from the live list concurrently.
    CompletableFuture<?>[] futures = undoneFutures.toArray(new CompletableFuture<?>[0]);
    if (futures.length == 0) {
      return;
    }
    CompletableFuture.allOf(futures).join();
  }

  /**
   * Shuts down the auto-run thread pool; does nothing when auto run is disabled.
   *
   * <p>An orderly shutdown is requested first and the caller waits up to
   * {@code autoBatchConfig.awaitTerminationMs} milliseconds for running tasks. If the pool has not
   * terminated by then, or the waiting thread is interrupted, remaining tasks are cancelled with
   * {@code shutdownNow()}. This method does not call {@link #flush()}, so references still queued
   * at this point are not sent.
   */
  @Override
  public void close() {
    if (!autoRunEnabled) {
      return;
    }
    executorService.shutdown();
    try {
      if (!executorService.awaitTermination(autoBatchConfig.awaitTerminationMs, TimeUnit.MILLISECONDS)) {
        executorService.shutdownNow();
      }
    } catch (InterruptedException e) {
      executorService.shutdownNow();
      // Restore the flag so callers further up the stack can still observe the interruption.
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Removes the first {@code batchSize} queued references and returns them as an independent list.
   *
   * @param batchSize number of references to take from the head of the queue; callers pass a value
   *                  between 0 and the current queue size
   * @return a new list holding the removed references in their original order
   */
  private List<BatchReference> extractBatch(int batchSize) {
    List<BatchReference> head = references.subList(0, batchSize);
    List<BatchReference> batch = new ArrayList<>(head);
    head.clear(); // clearing the sub-list view removes those elements from the backing queue
    return batch;
  }

  /**
   * In auto-run mode, dispatches full batches for as long as at least
   * {@code autoBatchConfig.batchSize} references are queued. Leftovers smaller than a full batch
   * stay queued. Does nothing when auto run is disabled.
   */
  private void autoRun() {
    if (!autoRunEnabled) {
      return;
    }
    while (references.size() >= autoBatchConfig.batchSize) {
      runInThread(extractBatch(autoBatchConfig.batchSize));
    }
  }

  /**
   * Submits a batch to the auto-run pool and tracks its future until it completes.
   *
   * <p>If a callback is configured it is attached with {@code whenComplete} and receives the
   * batch result; the tracked future is the one that includes the callback stage, so
   * {@link #flush()} also waits for the callback to finish.
   *
   * @param batch references to send as one request
   */
  private void runInThread(List<BatchReference> batch) {
    // supplyAsync yields a future-of-future (retries are themselves asynchronous); flatten it.
    CompletableFuture<Result<BatchReferenceResponse[]>> future = CompletableFuture
      .supplyAsync(() -> createRunFuture(batch), executorService)
      .thenCompose(f -> f);
    if (autoBatchConfig.callback != null) {
      future = future.whenComplete((result, e) -> autoBatchConfig.callback.accept(result));
    }
    CompletableFuture<Result<BatchReferenceResponse[]>> undoneFuture = future;
    // Register before attaching the removal stage so an already-completed future is still removed.
    undoneFutures.add(undoneFuture);
    undoneFuture.whenComplete((result, ex) -> undoneFutures.remove(undoneFuture));
  }

  private CompletableFuture> createRunFuture(List batch) {
    return runRecursively(batch, 0, 0, (DelayedExecutor>>) delayedExecutor);
  }

  /**
   * Sends a batch and retries it on connection or socket-timeout errors.
   *
   * <p>Only the first error message's throwable decides whether to retry. Connection errors and
   * timeouts are counted separately against their respective limits, and the delay before the next
   * attempt grows linearly: {@code attemptNumber * retriesIntervalMs}. When the request fails and
   * is not retried, the returned result additionally lists the references that were not created.
   *
   * @param batch                references to send
   * @param connectionErrorCount connection errors seen so far for this batch
   * @param timeoutErrorCount    timeout errors seen so far for this batch
   * @param delayedExecutor      strategy that either blocks or schedules the retry, and wraps the final result
   * @param <T>                  result wrapper produced by the executor (plain result or future)
   * @return the final result, wrapped as dictated by {@code delayedExecutor}
   */
  private <T> T runRecursively(List<BatchReference> batch, int connectionErrorCount, int timeoutErrorCount, DelayedExecutor<T> delayedExecutor) {
    Result<BatchReferenceResponse[]> result = internalRun(batch);
    if (result.hasErrors()) {
      List<WeaviateErrorMessage> messages = result.getError().getMessages();
      if (!messages.isEmpty()) {
        Throwable throwable = messages.get(0).getThrowable();
        boolean executeAgain = false;
        int delay = 0;
        if (throwable instanceof ConnectException) {
          if (connectionErrorCount++ < batchRetriesConfig.maxConnectionRetries) {
            executeAgain = true;
            delay = connectionErrorCount * batchRetriesConfig.retriesIntervalMs;
          }
        } else if (throwable instanceof SocketTimeoutException) {
          if (timeoutErrorCount++ < batchRetriesConfig.maxTimeoutRetries) {
            executeAgain = true;
            delay = timeoutErrorCount * batchRetriesConfig.retriesIntervalMs;
          }
        }
        if (executeAgain) {
          // Lambdas may only capture effectively-final locals, so copy the mutated values.
          int lambdaConnectionErrorCount = connectionErrorCount;
          int lambdaTimeoutErrorCount = timeoutErrorCount;
          List<BatchReference> lambdaBatch = batch;
          return delayedExecutor.delayed(delay, () -> runRecursively(lambdaBatch, lambdaConnectionErrorCount, lambdaTimeoutErrorCount, delayedExecutor));
        }
      }
    } else {
      // Success: no failed references to report in the final result.
      batch = null;
    }
    Result<BatchReferenceResponse[]> finalResult = createFinalResultFromLastResult(result, batch);
    return delayedExecutor.now(finalResult);
  }

  /**
   * Performs a single POST of the batch to the path built by {@code referencesPath}, including the
   * current consistency level.
   *
   * @param batch references to send as the request payload
   * @return the response wrapped in a {@code Result}
   */
  private Result<BatchReferenceResponse[]> internalRun(List<BatchReference> batch) {
    BatchReference[] payload = batch.toArray(new BatchReference[0]);
    String path = referencesPath.buildCreate(ReferencesPath.Params.builder().consistencyLevel(consistencyLevel).build());
    Response<BatchReferenceResponse[]> resp = sendPostRequest(path, payload, BatchReferenceResponse[].class);
    return new Result<>(resp);
  }

  /**
   * Builds the result handed back to the caller after the last attempt.
   *
   * <p>With no failed references the last result is returned untouched. Otherwise an error result
   * is produced that keeps the previous status code and error messages (if any) and appends one
   * extra message listing every failed reference as {@code from => to}.
   *
   * @param lastResult  result of the final request attempt
   * @param failedBatch references that were not created, or {@code null}/empty if none
   * @return {@code lastResult} itself, or a new error-only result describing the failed references
   */
  private Result<BatchReferenceResponse[]> createFinalResultFromLastResult(Result<BatchReferenceResponse[]> lastResult, List<BatchReference> failedBatch) {
    if (ObjectUtils.isEmpty(failedBatch)) {
      return lastResult;
    }
    String failedRefs = failedBatch.stream().map(ref -> ref.getFrom() + " => " + ref.getTo()).collect(Collectors.joining(", "));
    WeaviateErrorMessage failedRefsMessage = WeaviateErrorMessage.builder().message("Failed refs: " + failedRefs).build();
    List<WeaviateErrorMessage> messages;
    int statusCode = 0;
    if (lastResult.hasErrors()) {
      statusCode = lastResult.getError().getStatusCode();
      List<WeaviateErrorMessage> prevMessages = lastResult.getError().getMessages();
      messages = new ArrayList<>(prevMessages.size() + 1);
      messages.addAll(prevMessages);
      messages.add(failedRefsMessage);
    } else {
      messages = Collections.singletonList(failedRefsMessage);
    }
    return new Result<>(statusCode, null, WeaviateErrorResponse.builder().error(messages).code(statusCode).build());
  }


  /**
   * Strategy for running a retry after a delay and for wrapping a final result.
   *
   * @param <T> what callers receive: a plain result for blocking execution, or a future of it for
   *            asynchronous execution
   */
  private interface DelayedExecutor<T> {
    /**
     * Runs {@code supplier} after roughly {@code delay} milliseconds.
     *
     * @param delay    delay in milliseconds
     * @param supplier the retry to run
     * @return the supplier's outcome in this executor's wrapper type
     */
    T delayed(int delay, Supplier<T> supplier);

    /**
     * Wraps an already available result without any delay.
     *
     * @param result the final result
     * @return {@code result} in this executor's wrapper type
     */
    T now(Result<BatchReferenceResponse[]> result);
  }


  private static class ExecutorServiceDelayedExecutor implements DelayedExecutor>> {
    private final ScheduledExecutorService executorService;

    @Override
    public CompletableFuture> delayed(int delay, Supplier>> supplier) {
      Executor executor = runnable -> executorService.schedule(runnable, delay, TimeUnit.MILLISECONDS);
      return CompletableFuture.supplyAsync(supplier, executor).thenCompose(f -> f);
    }

    @Override
    public CompletableFuture> now(Result result) {
      return CompletableFuture.completedFuture(result);
    }

    @java.lang.SuppressWarnings("all")
    public ExecutorServiceDelayedExecutor(final ScheduledExecutorService executorService) {
      this.executorService = executorService;
    }
  }


  private static class SleepDelayedExecutor implements DelayedExecutor> {
    @Override
    public Result delayed(int delay, Supplier> supplier) {
      try {
        Thread.sleep(delay);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
      }
      return supplier.get();
    }

    @Override
    public Result now(Result result) {
      return result;
    }
  }


  /**
   * Immutable retry policy for batch requests: how many times to retry after a socket timeout or a
   * connection error, and the base interval (in milliseconds) between attempts.
   *
   * <p>Instances are created through {@link #builder()} or {@link #defaultConfig()}; the constructor
   * validates that both retry counts are at least 0 and the interval is greater than 0.
   */
  public static class BatchRetriesConfig {
    /** Timeout retries preset by {@link #defaultConfig()}. */
    public static final int MAX_TIMEOUT_RETRIES = 3;
    /** Connection retries preset by {@link #defaultConfig()}. */
    public static final int MAX_CONNECTION_RETRIES = 3;
    /** Retry interval in milliseconds preset by {@link #defaultConfig()}. */
    public static final int RETRIES_INTERVAL = 2000;

    private final int maxTimeoutRetries;
    private final int maxConnectionRetries;
    private final int retriesIntervalMs;

    private BatchRetriesConfig(int maxTimeoutRetries, int maxConnectionRetries, int retriesIntervalMs) {
      Assert.requireGreaterEqual(maxTimeoutRetries, 0, "maxTimeoutRetries");
      Assert.requireGreaterEqual(maxConnectionRetries, 0, "maxConnectionRetries");
      Assert.requireGreater(retriesIntervalMs, 0, "retriesIntervalMs");
      this.maxTimeoutRetries = maxTimeoutRetries;
      this.maxConnectionRetries = maxConnectionRetries;
      this.retriesIntervalMs = retriesIntervalMs;
    }

    /**
     * @return a builder pre-populated with {@link #MAX_TIMEOUT_RETRIES},
     *         {@link #MAX_CONNECTION_RETRIES} and {@link #RETRIES_INTERVAL}
     */
    public static BatchRetriesConfigBuilder defaultConfig() {
      final BatchRetriesConfigBuilder defaults = builder();
      defaults.maxTimeoutRetries(MAX_TIMEOUT_RETRIES);
      defaults.maxConnectionRetries(MAX_CONNECTION_RETRIES);
      defaults.retriesIntervalMs(RETRIES_INTERVAL);
      return defaults;
    }

    /**
     * @return an empty builder; all values start at 0 and must be set before {@code build()} passes validation
     */
    public static BatchRetriesConfigBuilder builder() {
      return new BatchRetriesConfigBuilder();
    }

    public int getMaxTimeoutRetries() {
      return maxTimeoutRetries;
    }

    public int getMaxConnectionRetries() {
      return maxConnectionRetries;
    }

    public int getRetriesIntervalMs() {
      return retriesIntervalMs;
    }

    @Override
    public String toString() {
      return "ReferencesBatcher.BatchRetriesConfig(maxTimeoutRetries=" + getMaxTimeoutRetries()
        + ", maxConnectionRetries=" + getMaxConnectionRetries()
        + ", retriesIntervalMs=" + getRetriesIntervalMs()
        + ")";
    }

    @Override
    public boolean equals(final Object o) {
      if (this == o) {
        return true;
      }
      if (!(o instanceof BatchRetriesConfig)) {
        return false;
      }
      final BatchRetriesConfig that = (BatchRetriesConfig) o;
      return that.canEqual(this)
        && getMaxTimeoutRetries() == that.getMaxTimeoutRetries()
        && getMaxConnectionRetries() == that.getMaxConnectionRetries()
        && getRetriesIntervalMs() == that.getRetriesIntervalMs();
    }

    /** Hook that lets subclasses opt out of equality with this type. */
    protected boolean canEqual(final Object other) {
      return other instanceof BatchRetriesConfig;
    }

    @Override
    public int hashCode() {
      final int prime = 59;
      final int[] components = {getMaxTimeoutRetries(), getMaxConnectionRetries(), getRetriesIntervalMs()};
      int hash = 1;
      for (int component : components) {
        hash = hash * prime + component;
      }
      return hash;
    }

    /** Fluent builder for {@link BatchRetriesConfig}; validation happens in {@link #build()}. */
    public static class BatchRetriesConfigBuilder {
      private int maxTimeoutRetries;
      private int maxConnectionRetries;
      private int retriesIntervalMs;

      BatchRetriesConfigBuilder() {
      }

      /**
       * @return {@code this}.
       */
      public BatchRetriesConfigBuilder maxTimeoutRetries(final int maxTimeoutRetries) {
        this.maxTimeoutRetries = maxTimeoutRetries;
        return this;
      }

      /**
       * @return {@code this}.
       */
      public BatchRetriesConfigBuilder maxConnectionRetries(final int maxConnectionRetries) {
        this.maxConnectionRetries = maxConnectionRetries;
        return this;
      }

      /**
       * @return {@code this}.
       */
      public BatchRetriesConfigBuilder retriesIntervalMs(final int retriesIntervalMs) {
        this.retriesIntervalMs = retriesIntervalMs;
        return this;
      }

      /**
       * @return a validated configuration built from the current values
       */
      public BatchRetriesConfig build() {
        return new BatchRetriesConfig(maxTimeoutRetries, maxConnectionRetries, retriesIntervalMs);
      }

      @Override
      public String toString() {
        return "ReferencesBatcher.BatchRetriesConfig.BatchRetriesConfigBuilder(maxTimeoutRetries=" + maxTimeoutRetries
          + ", maxConnectionRetries=" + maxConnectionRetries
          + ", retriesIntervalMs=" + retriesIntervalMs
          + ")";
      }
    }
  }


  public static class AutoBatchConfig {
    public static final int BATCH_SIZE = 100;
    public static final int POOL_SIZE = 1;
    public static final int AWAIT_TERMINATION_MS = 10000;
    private final int batchSize;
    private final int poolSize;
    private final int awaitTerminationMs;
    private final Consumer> callback;

    private AutoBatchConfig(int batchSize, int poolSize, int awaitTerminationMs, Consumer> callback) {
      Assert.requireGreaterEqual(batchSize, 1, "batchSize");
      Assert.requireGreaterEqual(poolSize, 1, "corePoolSize");
      Assert.requireGreater(awaitTerminationMs, 0, "awaitTerminationMs");
      this.batchSize = batchSize;
      this.poolSize = poolSize;
      this.awaitTerminationMs = awaitTerminationMs;
      this.callback = callback;
    }

    public static AutoBatchConfigBuilder defaultConfig() {
      return AutoBatchConfig.builder().batchSize(BATCH_SIZE).poolSize(POOL_SIZE).awaitTerminationMs(AWAIT_TERMINATION_MS).callback(null);
    }


    @java.lang.SuppressWarnings("all")
    public static class AutoBatchConfigBuilder {
      @java.lang.SuppressWarnings("all")
      private int batchSize;
      @java.lang.SuppressWarnings("all")
      private int poolSize;
      @java.lang.SuppressWarnings("all")
      private int awaitTerminationMs;
      @java.lang.SuppressWarnings("all")
      private Consumer> callback;

      @java.lang.SuppressWarnings("all")
      AutoBatchConfigBuilder() {
      }

      /**
       * @return {@code this}.
       */
      @java.lang.SuppressWarnings("all")
      public ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder batchSize(final int batchSize) {
        this.batchSize = batchSize;
        return this;
      }

      /**
       * @return {@code this}.
       */
      @java.lang.SuppressWarnings("all")
      public ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder poolSize(final int poolSize) {
        this.poolSize = poolSize;
        return this;
      }

      /**
       * @return {@code this}.
       */
      @java.lang.SuppressWarnings("all")
      public ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder awaitTerminationMs(final int awaitTerminationMs) {
        this.awaitTerminationMs = awaitTerminationMs;
        return this;
      }

      /**
       * @return {@code this}.
       */
      @java.lang.SuppressWarnings("all")
      public ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder callback(final Consumer> callback) {
        this.callback = callback;
        return this;
      }

      @java.lang.SuppressWarnings("all")
      public ReferencesBatcher.AutoBatchConfig build() {
        return new ReferencesBatcher.AutoBatchConfig(this.batchSize, this.poolSize, this.awaitTerminationMs, this.callback);
      }

      @java.lang.Override
      @java.lang.SuppressWarnings("all")
      public java.lang.String toString() {
        return "ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder(batchSize=" + this.batchSize + ", poolSize=" + this.poolSize + ", awaitTerminationMs=" + this.awaitTerminationMs + ", callback=" + this.callback + ")";
      }
    }

    @java.lang.SuppressWarnings("all")
    public static ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder builder() {
      return new ReferencesBatcher.AutoBatchConfig.AutoBatchConfigBuilder();
    }

    @java.lang.SuppressWarnings("all")
    public int getBatchSize() {
      return this.batchSize;
    }

    @java.lang.SuppressWarnings("all")
    public int getPoolSize() {
      return this.poolSize;
    }

    @java.lang.SuppressWarnings("all")
    public int getAwaitTerminationMs() {
      return this.awaitTerminationMs;
    }

    @java.lang.SuppressWarnings("all")
    public Consumer> getCallback() {
      return this.callback;
    }

    @java.lang.Override
    @java.lang.SuppressWarnings("all")
    public java.lang.String toString() {
      return "ReferencesBatcher.AutoBatchConfig(batchSize=" + this.getBatchSize() + ", poolSize=" + this.getPoolSize() + ", awaitTerminationMs=" + this.getAwaitTerminationMs() + ", callback=" + this.getCallback() + ")";
    }

    @java.lang.Override
    @java.lang.SuppressWarnings("all")
    public boolean equals(final java.lang.Object o) {
      if (o == this) return true;
      if (!(o instanceof ReferencesBatcher.AutoBatchConfig)) return false;
      final ReferencesBatcher.AutoBatchConfig other = (ReferencesBatcher.AutoBatchConfig) o;
      if (!other.canEqual((java.lang.Object) this)) return false;
      if (this.getBatchSize() != other.getBatchSize()) return false;
      if (this.getPoolSize() != other.getPoolSize()) return false;
      if (this.getAwaitTerminationMs() != other.getAwaitTerminationMs()) return false;
      final java.lang.Object this$callback = this.getCallback();
      final java.lang.Object other$callback = other.getCallback();
      if (this$callback == null ? other$callback != null : !this$callback.equals(other$callback)) return false;
      return true;
    }

    @java.lang.SuppressWarnings("all")
    protected boolean canEqual(final java.lang.Object other) {
      return other instanceof ReferencesBatcher.AutoBatchConfig;
    }

    @java.lang.Override
    @java.lang.SuppressWarnings("all")
    public int hashCode() {
      final int PRIME = 59;
      int result = 1;
      result = result * PRIME + this.getBatchSize();
      result = result * PRIME + this.getPoolSize();
      result = result * PRIME + this.getAwaitTerminationMs();
      final java.lang.Object $callback = this.getCallback();
      result = result * PRIME + ($callback == null ? 43 : $callback.hashCode());
      return result;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy