All downloads are free. Search and download functionality uses the official Maven repository.

tech.amikos.chromadb.Collection Maven / Gradle / Ivy

There is a newer version: 0.1.6
Show newest version
package tech.amikos.chromadb;

import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;
import com.google.gson.internal.LinkedTreeMap;
import tech.amikos.chromadb.embeddings.EmbeddingFunction;
import tech.amikos.chromadb.handler.ApiException;
import tech.amikos.chromadb.handler.DefaultApi;
import tech.amikos.chromadb.model.*;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static java.lang.Thread.sleep;

public class Collection {
    // Class-wide Gson instance.
    static Gson gson = new Gson();
    // Generated API client; every remote call made by this class goes through it.
    DefaultApi api;
    // Collection name: set by the constructor, overwritten by fetch() and update().
    String collectionName;

    // Collection id: stays null until fetch() reads it from the API response.
    String collectionId;

    // Collection metadata: starts empty and is replaced by fetch().
    LinkedTreeMap metadata = new LinkedTreeMap<>();

    // Used to embed documents/query texts when callers pass no embeddings.
    // It is null for instances created through getInstance().
    private EmbeddingFunction embeddingFunction;

    /**
     * Creates a handle for the collection called {@code collectionName}.
     * <p>
     * The arguments are only stored; no request is sent, so the id stays
     * {@code null} and the metadata stays empty until {@link #fetch()} runs.
     *
     * @param api               API client used for all later requests
     * @param collectionName    name of the collection this handle refers to
     * @param embeddingFunction function used to embed texts when no embeddings
     *                          are supplied; may be {@code null}
     */
    public Collection(DefaultApi api, String collectionName, EmbeddingFunction embeddingFunction) {
        this.embeddingFunction = embeddingFunction;
        this.collectionName = collectionName;
        this.api = api;
    }

    /**
     * @return the collection name currently held by this handle
     */
    public String getName() {
        return this.collectionName;
    }

    /**
     * @return the collection id, or {@code null} if {@link #fetch()} has not
     *         populated it yet
     */
    public String getId() {
        return this.collectionId;
    }

    /**
     * @return the metadata map held by this handle (the internal instance,
     *         not a copy)
     */
    public Map getMetadata() {
        return this.metadata;
    }

    /**
     * Loads the collection by its current name and refreshes the cached name,
     * id and metadata from the response.
     *
     * @return this instance, to allow chaining
     * @throws ApiException         if the underlying API call fails
     * @throws NullPointerException if the response has no {@code name} or
     *                              {@code id} entry
     */
    public Collection fetch() throws ApiException {
        LinkedTreeMap resp = (LinkedTreeMap) api.getCollection(collectionName);
        this.collectionName = resp.get("name").toString();
        this.collectionId = resp.get("id").toString();
        // A response without metadata must not replace the map with null:
        // the field starts out as an empty map and getMetadata() hands it out directly.
        Object fetchedMetadata = resp.get("metadata");
        if (fetchedMetadata != null) {
            this.metadata = (LinkedTreeMap) fetchedMetadata;
        } else {
            this.metadata = new LinkedTreeMap<>();
        }
        return this;
    }

    /**
     * Creates a handle without an embedding function.
     * <p>
     * No request is sent here; call {@link #fetch()} on the result to load the
     * collection id and metadata.
     *
     * @param api            API client used for all later requests
     * @param collectionName name of the collection
     * @return a new, not yet fetched handle
     * @throws ApiException declared for callers; not thrown by this body
     */
    public static Collection getInstance(DefaultApi api, String collectionName) throws ApiException {
        final EmbeddingFunction noEmbeddingFunction = null;
        return new Collection(api, collectionName, noEmbeddingFunction);
    }

    /**
     * @return a debug representation listing the name, id and metadata
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("Collection{");
        text.append("collectionName='").append(collectionName).append('\'');
        text.append(", collectionId='").append(collectionId).append('\'');
        text.append(", metadata=").append(metadata);
        text.append('}');
        return text.toString();
    }

    /**
     * Retrieves records from this collection.
     *
     * @param ids           ids to select; may be {@code null}
     * @param where         metadata filter; may be {@code null}
     * @param whereDocument document filter; may be {@code null}
     * @return the API response converted to a {@link GetResult}
     * @throws ApiException if the underlying API call fails
     */
    public GetResult get(List ids, Map where, Map whereDocument) throws ApiException {
        GetEmbedding req = new GetEmbedding();
        req.ids(ids).where(where).whereDocument(whereDocument);
        // The API response is loosely typed, so it is round-tripped through JSON
        // to obtain a GetResult. The class-level Gson is reused rather than
        // allocating two fresh instances on every call.
        String json = gson.toJson(api.get(req, this.collectionId));
        return gson.fromJson(json, GetResult.class);
    }

    /**
     * Retrieves records without any id or filter restriction by delegating to
     * {@link #get(List, Map, Map)} with all arguments {@code null}.
     *
     * @return the API response converted to a {@link GetResult}
     * @throws ApiException if the underlying API call fails
     */
    public GetResult get() throws ApiException {
        return get(null, null, null);
    }

    /**
     * Issues a delete request with no ids and no filters by delegating to
     * {@link #delete(List, Map, Map)} with all arguments {@code null}.
     *
     * @return whatever the API client returns for the delete call
     * @throws ApiException if the underlying API call fails
     */
    public Object delete() throws ApiException {
        return delete(null, null, null);
    }

    public Object upsert(List embeddings, List> metadatas, List documents, List ids) throws ChromaException {
        AddEmbedding req = new AddEmbedding();
        List _embeddings = embeddings;
        if (_embeddings == null) {
            _embeddings = this.embeddingFunction.embedDocuments(documents);
        }
        req.setEmbeddings(_embeddings.stream().map(Embedding::asArray).collect(Collectors.toList()));
        req.setMetadatas((List>) (Object) metadatas);
        req.setDocuments(documents);
        req.incrementIndex(true);
        req.setIds(ids);
        try {
            return api.upsert(req, this.collectionId);
        } catch (ApiException e) {
            throw new ChromaException(e);
        }
    }


    public Object add(List embeddings, List> metadatas, List documents, List ids) throws ChromaException {
        AddEmbedding req = new AddEmbedding();
        List _embeddings = embeddings;
        if (_embeddings == null) {
            _embeddings = this.embeddingFunction.embedDocuments(documents);
        }
        req.setEmbeddings(_embeddings.stream().map(Embedding::asArray).collect(Collectors.toList()));
        req.setMetadatas((List>) (Object) metadatas);
        req.setDocuments(documents);
        req.incrementIndex(true);
        req.setIds(ids);
        try {
            return api.add(req, this.collectionId);
        } catch (ApiException e) {
            throw new ChromaException(e);
        }
    }

    /**
     * Asks the API for the record count of this collection.
     *
     * @return the count reported by the API client
     * @throws ApiException if the underlying API call fails
     */
    public Integer count() throws ApiException {
        Integer total = this.api.count(collectionId);
        return total;
    }

    /**
     * Sends a delete request for this collection.
     * <p>
     * A filter that is {@code null} is left unset on the request; a non-null
     * filter is copied into a new map before being attached.
     *
     * @param ids           ids to delete; passed to the request as-is, may be {@code null}
     * @param where         metadata filter; may be {@code null}
     * @param whereDocument document filter; may be {@code null}
     * @return whatever the API client returns for the delete call
     * @throws ApiException if the underlying API call fails
     */
    public Object delete(List ids, Map where, Map whereDocument) throws ApiException {
        DeleteEmbedding req = new DeleteEmbedding();
        req.setIds(ids);
        if (where != null) {
            // Copy via Collectors.toMap; note that toMap rejects null values with a NullPointerException.
            req.where(where.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
        }
        if (whereDocument != null) {
            // Same copy (and same null-value restriction) for the document filter.
            req.whereDocument(whereDocument.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
        }
        return api.delete(req, this.collectionId);
    }

    /**
     * Deletes by id only; delegates to {@link #delete(List, Map, Map)} with
     * both filters {@code null}.
     *
     * @param ids ids to delete
     * @return whatever the API client returns for the delete call
     * @throws ApiException if the underlying API call fails
     */
    public Object deleteWithIds(List ids) throws ApiException {
        return this.delete(ids, null, null);
    }

    /**
     * Deletes by metadata filter only; delegates to
     * {@link #delete(List, Map, Map)} with no ids and no document filter.
     *
     * @param where metadata filter
     * @return whatever the API client returns for the delete call
     * @throws ApiException if the underlying API call fails
     */
    public Object deleteWhere(Map where) throws ApiException {
        return this.delete(null, where, null);
    }

    /**
     * Deletes by metadata filter and document filter; delegates to
     * {@link #delete(List, Map, Map)} with no ids.
     *
     * @param where         metadata filter
     * @param whereDocument document filter
     * @return whatever the API client returns for the delete call
     * @throws ApiException if the underlying API call fails
     */
    public Object deleteWhereWhereDocuments(Map where, Map whereDocument) throws ApiException {
        return this.delete(null, where, whereDocument);
    }

    /**
     * Deletes by document filter only; delegates to
     * {@link #delete(List, Map, Map)} with no ids and no metadata filter.
     *
     * @param whereDocument document filter
     * @return whatever the API client returns for the delete call
     * @throws ApiException if the underlying API call fails
     */
    public Object deleteWhereDocuments(Map whereDocument) throws ApiException {
        return this.delete(null, null, whereDocument);
    }


    /**
     * Renames this collection and/or replaces its metadata, then re-fetches it.
     * <p>
     * If an embedding function is configured and {@code newMetadata} has no
     * {@code "embedding_function"} key, that key is added to the map passed in
     * by the caller (the map is modified in place), with the embedding
     * function's class name as its value.
     *
     * @param newName     new collection name, or {@code null} to keep the current one
     * @param newMetadata new metadata, or {@code null} to leave metadata untouched
     * @return whatever the API client returns for the update call
     * @throws ApiException if the update or the subsequent fetch fails
     */
    public Object update(String newName, Map newMetadata) throws ApiException {
        UpdateCollection req = new UpdateCollection();
        if (newName != null) {
            req.setNewName(newName);
        }
        if (newMetadata != null) {
            // The embedding-function marker can only be added when a function exists,
            // but the metadata itself must be sent either way.
            if (embeddingFunction != null && !newMetadata.containsKey("embedding_function")) {
                newMetadata.put("embedding_function", embeddingFunction.getClass().getName());
            }
            req.setNewMetadata(newMetadata);
        }
        Object resp = api.updateCollection(req, this.collectionId);
        // Keep the current name on a metadata-only update; overwriting it with null
        // would make the fetch below look up a null collection name.
        if (newName != null) {
            this.collectionName = newName;
        }
        // Re-read the collection so the cached name, id and metadata reflect the update.
        this.fetch();
        return resp;
    }

    public Object updateEmbeddings(List embeddings, List> metadatas, List documents, List ids) throws ChromaException {
        UpdateEmbedding req = new UpdateEmbedding();
        List _embeddings = embeddings;
        if (_embeddings == null) {
            _embeddings = this.embeddingFunction.embedDocuments(documents);
        }
        req.setEmbeddings(_embeddings.stream().map(Embedding::asArray).collect(Collectors.toList()));
        req.setDocuments(documents);
        req.setMetadatas((List) (Object) metadatas);
        req.setIds(ids);
        try {
            return api.update(req, this.collectionId);
        } catch (ApiException e) {
            throw new ChromaException(e);
        }
    }


    /**
     * Embeds the query texts and runs a nearest-neighbour query against this
     * collection.
     *
     * @param queryTexts    texts to embed and search for
     * @param nResults      value passed to the request's {@code nResults}
     * @param where         metadata filter; may be {@code null}
     * @param whereDocument document filter; may be {@code null}
     * @param include       value passed to the request's {@code include}
     * @return the API response converted to a {@link QueryResponse}
     * @throws ChromaException       if the API call fails
     * @throws IllegalStateException if this collection has no embedding function
     */
    public QueryResponse query(List queryTexts, Integer nResults, Map where, Map whereDocument, List include) throws ChromaException {
        // Query texts are always embedded here, so an embedding function is mandatory;
        // fail with a clear message instead of a bare NullPointerException.
        if (this.embeddingFunction == null) {
            throw new IllegalStateException("Cannot embed query texts: the collection has no embedding function");
        }
        QueryEmbedding body = new QueryEmbedding();
        body.queryEmbeddings(this.embeddingFunction.embedDocuments(queryTexts).stream().map(Embedding::asArray).collect(Collectors.toList()));
        body.nResults(nResults);
        body.include(include);
        if (where != null) {
            // Copy via Collectors.toMap; note that toMap rejects null values with a NullPointerException.
            body.where(where.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
        }
        if (whereDocument != null) {
            body.whereDocument(whereDocument.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
        }
        try {
            // The API response is loosely typed, so it is round-tripped through JSON
            // to obtain a QueryResponse. The class-level Gson is reused rather than
            // allocating two fresh instances on every call.
            String json = gson.toJson(api.getNearestNeighbors(body, this.collectionId));
            return gson.fromJson(json, QueryResponse.class);
        } catch (ApiException e) {
            throw new ChromaException(e);
        }
    }

    public static class QueryResponse {
        @SerializedName("documents")
        private List> documents;
        @SerializedName("embeddings")
        private List> embeddings;
        @SerializedName("ids")
        private List> ids;
        @SerializedName("metadatas")
        private List>> metadatas;
        @SerializedName("distances")
        private List> distances;

        public List> getDocuments() {
            return documents;
        }

        public List> getEmbeddings() {
            return embeddings;
        }

        public List> getIds() {
            return ids;
        }

        public List>> getMetadatas() {
            return metadatas;
        }

        public List> getDistances() {
            return distances;
        }

        @Override
        public String toString() {
            return new Gson().toJson(this);
        }


    }

    public static class GetResult {
        @SerializedName("documents")
        private List documents;
        @SerializedName("embeddings")
        private List embeddings;
        @SerializedName("ids")
        private List ids;
        @SerializedName("metadatas")
        private List> metadatas;

        public List getDocuments() {
            return documents;
        }

        public List getEmbeddings() {
            return embeddings;
        }

        public List getIds() {
            return ids;
        }

        public List> getMetadatas() {
            return metadatas;
        }

        @Override
        public String toString() {
            return new Gson().toJson(this);
        }
    }
}