All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.aerospike.vector.client.dbclient.Client Maven / Gradle / Ivy

package com.aerospike.vector.client.dbclient;

import com.aerospike.vector.client.*;
import com.aerospike.vector.client.internal.*;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.protobuf.Empty;
import io.grpc.Channel;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
import java.util.stream.Collectors;

/**
 * gRPC-backed implementation of {@link IClient}.
 *
 * <p>Record operations (put, get, exists, delete, isIndexed, vector search) are sent through
 * {@code TransactGrpc} stubs and index status queries through {@code IndexServiceGrpc} stubs.
 * Every call obtains its channel from the {@link ChannelProvider} created in the constructor.
 * Put and vector search are offered in both a blocking and a non-blocking variant.
 */
public class Client implements IClient {
    private static final Logger log = LoggerFactory.getLogger(Client.class);

    /** Pause between two index status polls in {@link #waitForIndexCompletion}. */
    private static final long INDEX_STATUS_POLL_INTERVAL_MILLIS = 20_000L;

    /** How long {@link #close()} waits for in-flight client tasks before cancelling them. */
    private static final long CLOSE_TIMEOUT_SECONDS = 30L;

    /** Thread-per-task executor backed by virtual threads named {@code avs-client-0}, {@code avs-client-1}, ... */
    private final ExecutorService clientExecutor =
            Executors.newThreadPerTaskExecutor(Thread.ofVirtual().name("avs-client-", 0L).factory());
    private final ChannelProvider channelProvider;

    /**
     * Constructs a client and the {@link ChannelProvider} it uses for all database operations.
     *
     * @param seeds List of host and port pairs for initializing database connections.
     * @param listenerName The listener name for the gRPC service.
     * @param isLoadBalancer Indicates if load balancing should be enabled.
     */
    public Client(List seeds, String listenerName, boolean isLoadBalancer) {
        this.channelProvider = new ChannelProvider(seeds, listenerName, isLoadBalancer);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Blocks until the server has answered the put request.
     */
    @Override
    public void put(String namespace, @Nullable String set, Object key, Map bins, int writeType) {
        TransactGrpc.TransactBlockingStub transactService = blockingTransactStub();
        transactService.put(buildPutRequest(namespace, set, key, bins, writeType));
    }

    /**
     * {@inheritDoc}
     *
     * <p>Sends the same request as {@link #put} through a future stub and returns the pending
     * result without waiting for it.
     */
    @Override
    public ListenableFuture putAsync(String namespace, @Nullable String set, Object key, Map fields, int writeType) {
        TransactGrpc.TransactFutureStub transactService =
                TransactGrpc.newFutureStub(channelProvider.getChannel()).withExecutor(clientExecutor);
        return transactService.put(buildPutRequest(namespace, set, key, fields, writeType));
    }

    /**
     * {@inheritDoc}
     *
     * <p>The returned map is keyed by field name and holds the field values exactly as they
     * appear in the gRPC response.
     */
    @Override
    public Map get(String namespace, @Nullable String set, Object key, Projection projection) {
        TransactGrpc.TransactBlockingStub transactService = blockingTransactStub();
        GetRequest getRequest = GetRequest.newBuilder()
                .setKey(Conversions.buildKey(namespace, set, key))
                .setProjectionSpec(projection.toProjectionSpec())
                .build();

        return transactService.get(getRequest).getFieldsList().stream()
                .collect(Collectors.toMap(Field::getName, Field::getValue));
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public boolean exists(String namespace, @Nullable String set, Object key) {
        TransactGrpc.TransactBlockingStub transactService = blockingTransactStub();
        Key grpcKey = Conversions.buildKey(namespace, set, key);
        return transactService.exists(ExistsRequest.newBuilder().setKey(grpcKey).build()).getValue();
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public boolean isIndexed(String namespace, String set, Object key, String indexNamespace, String indexName) {
        TransactGrpc.TransactBlockingStub transactService = blockingTransactStub();
        IsIndexedRequest request = IsIndexedRequest.newBuilder()
                .setKey(Conversions.buildKey(namespace, set, key))
                .setIndexId(IndexId.newBuilder().setNamespace(indexNamespace).setName(indexName).build())
                .build();

        return transactService.isIndexed(request).getValue();
    }

    /**
     * {@inheritDoc}
     *
     * <p>A helper thread sleeps for one poll interval (20 seconds), then queries the index status,
     * and repeats until the unmerged record count is zero or the timeout elapses. Because the
     * first query happens only after a full interval, a timeout shorter than the interval always
     * ends in a timeout failure.
     *
     * @throws RuntimeException wrapping a {@link TimeoutException}, {@link ExecutionException} or
     *         {@link InterruptedException} when the wait does not finish successfully; on
     *         interruption the calling thread's interrupt flag is restored first.
     */
    @Override
    public void waitForIndexCompletion(IndexId indexId, long timeoutMillis) {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        Callable<Void> task = () -> {
            // nanoTime is monotonic, and comparing the difference against zero stays correct even
            // when a very large timeout makes the sum wrap around.
            long deadlineNanos = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeoutMillis);
            while (deadlineNanos - System.nanoTime() > 0) {
                Thread.sleep(INDEX_STATUS_POLL_INTERVAL_MILLIS);
                if (indexStatus(indexId).getUnmergedRecordCount() == 0) {
                    return null;  // Indexing completed
                }
            }
            throw new TimeoutException("Indexing did not complete within the allotted time.");
        };

        Future<Void> future = executor.submit(task);
        try {
            future.get(timeoutMillis, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            // Keep the interrupt visible to callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Error waiting for index completion", e);
        } catch (ExecutionException | TimeoutException e) {
            throw new RuntimeException("Error waiting for index completion", e);
        } finally {
            // shutdownNow() interrupts a poller that is still sleeping or querying, so it does not
            // keep running after this method has returned or thrown.
            executor.shutdownNow();
        }
    }

    /**
     * {@inheritDoc}
     * Performs a vector search against the specified index using the given query vector.
     *
     * <p>Blocks until the server has streamed every neighbor and returns them in arrival order.
     *
     * @throws NullPointerException if {@code query} is null.
     */
    @Override
    public List vectorSearch(VectorSearchQuery query) {
        Objects.requireNonNull(query, "query can't be null");

        TransactGrpc.TransactBlockingStub transactStub = blockingTransactStub();

        IndexId indexId = IndexId.newBuilder()
                .setNamespace(query.getNamespace())
                .setName(query.getIndexName())
                .build();

        // NOTE(review): the request produced by query.toVectorSearchRequest() is used here only
        // for its projection, while vectorSearchAsync sends that request as-is; confirm whether
        // the two paths are meant to differ.
        VectorSearchRequest.Builder requestBuilder = VectorSearchRequest.newBuilder()
                .setIndex(indexId)
                .setQueryVector(query.getVector())
                .setLimit(query.getLimit())
                .setProjection(query.toVectorSearchRequest().getProjection());

        // HNSW search parameters are optional on the query.
        if (query.getSearchParams() != null) {
            requestBuilder.setHnswSearchParams(query.getSearchParams());
        }

        Iterator<Neighbor> response = transactStub.vectorSearch(requestBuilder.build());
        List<Neighbor> neighbors = new ArrayList<>();
        response.forEachRemaining(neighbors::add);
        return neighbors;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void delete(String namespace, @Nullable String set, Object key) {
        // Uses the same stub configuration as every other Transact call in this class.
        TransactGrpc.TransactBlockingStub transactService = blockingTransactStub();
        Key grpcKey = Conversions.buildKey(namespace, set, key);
        DeleteRequest deleteRequest = DeleteRequest.newBuilder().setKey(grpcKey).build();
        transactService.delete(deleteRequest);
    }

    /**
     * Performs an asynchronous vector search, passing results back through the provided listener.
     *
     * <p>Each neighbor received from the server is forwarded to {@code listener.onNext}, a failure
     * to {@code listener.onError}, and the end of the stream to {@code listener.onComplete}.
     *
     * @param query Search query; it is converted with {@code toVectorSearchRequest()} and sent unchanged.
     * @param listener Listener to handle asynchronous results.
     * @throws NullPointerException if {@code query} or {@code listener} is null.
     */
    @Override
    public void vectorSearchAsync(VectorSearchQuery query, VectorSearchListener listener) {
        // Fail in the caller's thread rather than later inside a gRPC callback.
        Objects.requireNonNull(query, "query can't be null");
        Objects.requireNonNull(listener, "listener can't be null");

        TransactGrpc.TransactStub transactService =
                TransactGrpc.newStub(channelProvider.getChannel()).withExecutor(clientExecutor);

        transactService.vectorSearch(query.toVectorSearchRequest(), new StreamObserver<Neighbor>() {
            @Override
            public void onNext(Neighbor result) {
                listener.onNext(result);
            }

            @Override
            public void onError(Throwable t) {
                listener.onError(t);
            }

            @Override
            public void onCompleted() {
                listener.onComplete();
            }
        });
    }

    /**
     * Closes the channel provider and shuts down the client executor.
     *
     * <p>The executor is asked to shut down even when closing the channel provider fails. Tasks
     * still running after 30 seconds, or at the moment this thread is interrupted, are cancelled.
     * An interruption is logged and the thread's interrupt flag is restored; it is not rethrown.
     */
    @Override
    public void close() {
        try {
            try {
                channelProvider.close();
            } finally {
                // Must run even if the channel provider throws, otherwise the executor leaks.
                clientExecutor.shutdown();
            }
            if (!clientExecutor.awaitTermination(CLOSE_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                clientExecutor.shutdownNow();
            }
        } catch (InterruptedException e) {
            clientExecutor.shutdownNow();
            Thread.currentThread().interrupt();
            log.error("Interrupted during close", e);
        }
    }

    /** Queries the index service for the current status of {@code indexId}. */
    private IndexStatusResponse indexStatus(IndexId indexId) {
        IndexServiceGrpc.IndexServiceBlockingStub indexService =
                IndexServiceGrpc.newBlockingStub(channelProvider.getChannel());
        return indexService.getStatus(indexId);
    }

    /** Creates a blocking Transact stub on a channel from the provider, bound to the client executor. */
    private TransactGrpc.TransactBlockingStub blockingTransactStub() {
        return TransactGrpc.newBlockingStub(channelProvider.getChannel()).withExecutor(clientExecutor);
    }

    /** Builds the put request shared by {@link #put} and {@link #putAsync}. */
    private static PutRequest buildPutRequest(
            String namespace, @Nullable String set, Object key, Map<?, ?> fields, int writeType) {
        PutRequest.Builder request = PutRequest.newBuilder()
                .setKey(Conversions.buildKey(namespace, set, key))
                .setWriteTypeValue(writeType);
        for (Map.Entry<?, ?> entry : fields.entrySet()) {
            // NOTE(review): assumes field names are Strings, matching the String names read back
            // via Field::getName in get(); confirm against Conversions.buildField.
            request.addFields(Conversions.buildField((String) entry.getKey(), entry.getValue()));
        }
        return request.build();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy