All downloads are free. Search and download functionality uses the official Maven repository.

com.aerospike.vector.client.dbclient.Client Maven / Gradle / Ivy

/* (C)2024 */
package com.aerospike.vector.client.dbclient;

import com.aerospike.vector.client.*;
import com.aerospike.vector.client.Projection;
import com.aerospike.vector.client.VectorSearchQuery;
import com.aerospike.vector.client.internal.*;
import com.aerospike.vector.client.proto.*;
import com.google.protobuf.Empty;
import io.grpc.Deadline;
import io.grpc.stub.StreamObserver;
import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * gRPC-backed implementation of {@link IClient}.
 *
 * <p>Record operations (put, get, exists, delete, isIndexed, vector search) are sent through the
 * transact service stubs obtained from the {@link ClusterTenderer}; index status is read through
 * the index service stub. Asynchronous calls and index-completion polling run on a fixed-size
 * executor owned by this client, which is shut down by {@link #close()}.
 */
public class Client implements IClient {
    private static final Logger log = LoggerFactory.getLogger(Client.class);

    /** Pool used for async gRPC callbacks and for index-completion polling tasks. */
    private final ExecutorService clientExecutor =
            Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);

    /** Source of the gRPC stubs used by every operation. */
    private final ClusterTenderer clusterTenderer;

    /**
     * Constructs a Client to manage database operations.
     *
     * @param connectionConfig the configuration settings for connecting to the cluster. This
     *     includes parameters such as the cluster URL, credentials, TLS and other required
     *     information.
     */
    public Client(ConnectionConfig connectionConfig) {
        this.clusterTenderer = new ClusterTenderer(connectionConfig, "dbclient");
    }

    /** {@inheritDoc} */
    // NOTE(review): `bins` is a raw Map in this view of the file; confirm the declaration in
    // IClient and restore the type arguments here to match.
    @Override
    public void put(
            String namespace,
            @Nullable String set,
            Object key,
            Map bins,
            boolean ignoreMemQueueFull,
            WriteType writeType) {
        TransactServiceGrpc.TransactServiceBlockingStub transactService =
                clusterTenderer.getTransactBlockingStub();
        transactService.put(
                buildPutRequest(namespace, set, key, bins, ignoreMemQueueFull, writeType));
    }

    /**
     * {@inheritDoc}
     *
     * <p>The request is sent on the non-blocking stub and this method returns without waiting for
     * the response. A failure is reported by throwing from the response observer on the callback
     * thread; it is not delivered to the caller of this method.
     */
    // NOTE(review): `fields` is a raw Map in this view of the file; confirm against IClient.
    @Override
    public void putAsync(
            String namespace,
            @Nullable String set,
            Object key,
            Map fields,
            boolean ignoreMemQueueFull,
            WriteType writeType) {
        TransactServiceGrpc.TransactServiceStub transactService =
                clusterTenderer.getTransactNonBlockingStub().withExecutor(clientExecutor);

        transactService.put(
                buildPutRequest(namespace, set, key, fields, ignoreMemQueueFull, writeType),
                new StreamObserver<>() {
                    @Override
                    public void onNext(Empty empty) {}

                    @Override
                    public void onError(Throwable throwable) {
                        // Keep the gRPC failure as the cause so the root error is not lost.
                        throw new RuntimeException("Error in putAsync", throwable);
                    }

                    @Override
                    public void onCompleted() {}
                });
    }

    /**
     * Builds the gRPC put request for a record.
     *
     * @param namespace namespace of the record key
     * @param set optional set of the record key
     * @param key user key of the record
     * @param fields field name to value mapping; each entry is converted with {@link
     *     Conversions#buildField}
     * @param ignoreMemQueueFull currently not applied to the request (see note in the body)
     * @param writeType write type whose numeric value is set on the request
     * @return the assembled {@link PutRequest}
     */
    private PutRequest buildPutRequest(
            String namespace,
            @Nullable String set,
            Object key,
            Map<String, Object> fields,
            boolean ignoreMemQueueFull,
            WriteType writeType) {
        Key grpcKey = Conversions.buildKey(namespace, set, key);
        List<Field> grpcFields =
                fields.entrySet().stream()
                        .map(entry -> Conversions.buildField(entry.getKey(), entry.getValue()))
                        .collect(Collectors.toList());

        // NOTE(review): ignoreMemQueueFull is accepted but never written into the request.
        // Confirm whether PutRequest exposes a matching field and set it here if so.
        return PutRequest.newBuilder()
                .setKey(grpcKey)
                .setWriteTypeValue(writeType.getNumber())
                .addAllFields(grpcFields)
                .build();
    }

    /**
     * {@inheritDoc}
     *
     * <p>The returned map is keyed by field name, with the field values as returned by the
     * server response.
     */
    @Override
    public Map get(
            String namespace, @Nullable String set, Object key, Projection projection) {
        TransactServiceGrpc.TransactServiceBlockingStub transactService =
                clusterTenderer.getTransactBlockingStub();
        GetRequest getRequest =
                GetRequest.newBuilder()
                        .setKey(Conversions.buildKey(namespace, set, key))
                        .setProjection(projection.toProjectionSpec())
                        .build();

        return transactService.get(getRequest).getFieldsList().stream()
                .collect(Collectors.toMap(Field::getName, Field::getValue));
    }

    /** {@inheritDoc} */
    @Override
    public boolean exists(String namespace, @Nullable String set, Object key) {
        TransactServiceGrpc.TransactServiceBlockingStub transactService =
                clusterTenderer.getTransactBlockingStub();
        Key grpcKey = Conversions.buildKey(namespace, set, key);
        return transactService
                .exists(ExistsRequest.newBuilder().setKey(grpcKey).build())
                .getValue();
    }

    /** {@inheritDoc} */
    @Override
    public boolean isIndexed(
            String namespace, String set, Object key, String indexNamespace, String indexName) {
        TransactServiceGrpc.TransactServiceBlockingStub transactService =
                clusterTenderer.getTransactBlockingStub();
        IsIndexedRequest request =
                IsIndexedRequest.newBuilder()
                        .setKey(Conversions.buildKey(namespace, set, key))
                        .setIndexId(
                                IndexId.newBuilder()
                                        .setNamespace(indexNamespace)
                                        .setName(indexName)
                                        .build())
                        .build();

        return transactService.isIndexed(request).getValue();
    }

    /**
     * Fetches the current status of an index through the blocking index service stub.
     *
     * @param indexId identifier of the index to query
     * @return the status response returned by the server
     */
    private IndexStatusResponse indexStatus(IndexId indexId) {
        IndexServiceGrpc.IndexServiceBlockingStub indexService =
                clusterTenderer.getIndexServiceBlockingStub();
        IndexStatusRequest request = IndexStatusRequest.newBuilder().setIndexId(indexId).build();
        return indexService.getStatus(request);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Polls the index status every 20 seconds on the client executor and returns once the
     * unmerged record count reaches zero. The first status check happens only after the first
     * 20-second sleep, so a {@code timeoutMillis} shorter than that always fails.
     *
     * @throws RuntimeException if the wait times out, the status call fails, or the calling
     *     thread is interrupted; the underlying exception is attached as the cause
     */
    @Override
    public void waitForIndexCompletion(IndexId indexId, long timeoutMillis) {
        Callable<Void> task =
                () -> {
                    long waitInterval = 20_000L; // 20 seconds
                    long endTime = System.currentTimeMillis() + timeoutMillis;
                    while (System.currentTimeMillis() < endTime) {
                        Thread.sleep(waitInterval);
                        IndexStatusResponse status = indexStatus(indexId);
                        if (status.getUnmergedRecordCount() == 0) {
                            return null; // Indexing completed
                        }
                    }
                    throw new TimeoutException(
                            "Indexing did not complete within the allotted time.");
                };

        Future<Void> future = clientExecutor.submit(task);
        try {
            future.get(timeoutMillis, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can observe the interruption.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Error waiting for index completion", e);
        } catch (ExecutionException | TimeoutException e) {
            throw new RuntimeException("Error waiting for index completion", e);
        } finally {
            // Stop the poller if the caller gave up, so it does not keep a pool thread busy
            // sleeping and issuing status calls. This is a no-op once the task has finished.
            future.cancel(true);
        }
    }

    /**
     * {@inheritDoc} Performs a vector search against the specified index using the given query
     * vector.
     *
     * <p>When the query timeout is not {@code Integer.MAX_VALUE} it is applied as a gRPC deadline
     * in milliseconds. All streamed results are collected before returning.
     */
    @Override
    public List vectorSearch(VectorSearchQuery query) {
        Objects.requireNonNull(query, "query can't be null");
        TransactServiceGrpc.TransactServiceBlockingStub transactService =
                clusterTenderer.getTransactBlockingStub();
        if (query.getTimeout() != Integer.MAX_VALUE) {
            transactService =
                    transactService.withDeadline(
                            Deadline.after(query.getTimeout(), TimeUnit.MILLISECONDS));
        }

        IndexId indexId =
                IndexId.newBuilder()
                        .setNamespace(query.getNamespace())
                        .setName(query.getIndexName())
                        .build();

        // The projection is taken from the request the query builds for itself.
        VectorSearchRequest.Builder requestBuilder =
                VectorSearchRequest.newBuilder()
                        .setIndex(indexId)
                        .setQueryVector(query.getVector())
                        .setLimit(query.getLimit())
                        .setProjection(query.toVectorSearchRequest().getProjection());

        if (query.getSearchParams() != null) {
            requestBuilder.setHnswSearchParams(query.getSearchParams());
        }
        VectorSearchRequest request = requestBuilder.build();

        // Drain the server stream into a list.
        Iterator<Neighbor> response = transactService.vectorSearch(request);
        List<Neighbor> neighbors = new ArrayList<>();
        response.forEachRemaining(neighbors::add);
        return neighbors;
    }

    /** {@inheritDoc} */
    @Override
    public void delete(String namespace, @Nullable String set, Object key) {
        TransactServiceGrpc.TransactServiceBlockingStub transactService =
                clusterTenderer.getTransactBlockingStub();
        Key grpcKey = Conversions.buildKey(namespace, set, key);
        DeleteRequest deleteRequest = DeleteRequest.newBuilder().setKey(grpcKey).build();
        transactService.delete(deleteRequest);
    }

    /**
     * Performs an asynchronous vector search, passing results back through the provided listener.
     *
     * <p>When the query timeout is not {@code Integer.MAX_VALUE} it is applied as a gRPC deadline
     * in milliseconds. Listener callbacks are invoked on the client executor.
     *
     * @param query the vector search query to execute; must not be null.
     * @param listener listener to handle asynchronous results; must not be null.
     */
    @Override
    public void vectorSearchAsync(VectorSearchQuery query, VectorSearchListener listener) {
        // Fail fast on the caller's thread, matching vectorSearch.
        Objects.requireNonNull(query, "query can't be null");
        Objects.requireNonNull(listener, "listener can't be null");
        TransactServiceGrpc.TransactServiceStub transactService =
                clusterTenderer.getTransactNonBlockingStub().withExecutor(clientExecutor);

        if (query.getTimeout() != Integer.MAX_VALUE) {
            transactService =
                    transactService.withDeadline(
                            Deadline.after(query.getTimeout(), TimeUnit.MILLISECONDS));
        }

        transactService.vectorSearch(
                query.toVectorSearchRequest(),
                new StreamObserver<>() {
                    @Override
                    public void onNext(Neighbor result) {
                        listener.onNext(result);
                    }

                    @Override
                    public void onError(Throwable t) {
                        listener.onError(t);
                    }

                    @Override
                    public void onCompleted() {
                        listener.onComplete();
                    }
                });
    }

    /**
     * Closes the cluster tenderer and shuts down the client executor.
     *
     * <p>Waits up to 30 seconds for submitted tasks to finish before forcing shutdown. If the
     * calling thread is interrupted while waiting, the executor is shut down immediately, the
     * interrupt flag is restored and the interruption is logged.
     */
    @Override
    public void close() {
        try {
            try {
                clusterTenderer.close();
            } finally {
                // Always stop accepting work, even if closing the tenderer failed.
                clientExecutor.shutdown();
            }
            if (!clientExecutor.awaitTermination(30, TimeUnit.SECONDS)) {
                clientExecutor.shutdownNow();
            }
        } catch (InterruptedException e) {
            // Do not leave pool threads running when the wait is cut short.
            clientExecutor.shutdownNow();
            Thread.currentThread().interrupt();
            log.error("Interrupted during close", e);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy