All downloads are FREE. Search and download functionality uses the official Maven repository.
Please wait. This can take a few minutes...
Downloading a project requires significant server resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download or modify it as often as you want.
tech.amikos.chromadb.Collection Maven / Gradle / Ivy
package tech.amikos.chromadb;
import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;
import com.google.gson.internal.LinkedTreeMap;
import tech.amikos.chromadb.embeddings.EmbeddingFunction;
import tech.amikos.chromadb.handler.ApiException;
import tech.amikos.chromadb.handler.DefaultApi;
import tech.amikos.chromadb.model.*;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static java.lang.Thread.sleep;
/**
 * Client-side handle for one ChromaDB collection, backed by a {@link DefaultApi} client.
 *
 * <p>Holds the collection's name, id and metadata and exposes get/add/upsert/update/delete/query
 * operations scoped to that collection id. When an operation needs embeddings the caller did not
 * supply, they are computed with the {@link EmbeddingFunction} given at construction.
 *
 * <p>Within this class the collection id is assigned only by {@link #fetch()}; operations issued
 * before it has been set send a {@code null} id to the server.
 *
 * <p>NOTE(review): the generic type arguments in this listing (for example
 * {@code List> metadatas}) look like they were stripped when the file was extracted from a web
 * page. Member signatures are kept exactly as they appear here so the public interface is not
 * guessed at; restore the type arguments from the upstream source before compiling.
 */
public class Collection {
    /** Shared Gson instance used to re-bind untyped API responses onto the result classes. */
    static Gson gson = new Gson();
    DefaultApi api;
    String collectionName;
    String collectionId;
    LinkedTreeMap metadata = new LinkedTreeMap<>();
    private EmbeddingFunction embeddingFunction;

    /**
     * Creates a handle. Only fields are assigned; no server call is made here.
     *
     * @param api               client used for every server call
     * @param collectionName    name that {@link #fetch()} uses to look the collection up
     * @param embeddingFunction computes embeddings the caller does not supply; may be
     *                          {@code null}, in which case operations that need it throw
     *                          {@link IllegalStateException}
     */
    public Collection(DefaultApi api, String collectionName, EmbeddingFunction embeddingFunction) {
        this.api = api;
        this.collectionName = collectionName;
        this.embeddingFunction = embeddingFunction;
    }

    /** Returns the collection name as last set locally or loaded by {@link #fetch()}. */
    public String getName() {
        return collectionName;
    }

    /** Returns the collection id, or {@code null} if it has not been assigned yet. */
    public String getId() {
        return collectionId;
    }

    /** Returns the collection metadata. This is the internal map itself, not a copy. */
    public Map getMetadata() {
        return metadata;
    }

    /**
     * Loads this collection's name, id and metadata from the server, looking it up by the current
     * name.
     *
     * @return this instance, for chaining
     * @throws ApiException if the server call fails
     */
    public Collection fetch() throws ApiException {
        // The response is handled as a Gson tree map keyed by field name.
        LinkedTreeMap resp = (LinkedTreeMap) api.getCollection(collectionName);
        this.collectionName = resp.get("name").toString();
        this.collectionId = resp.get("id").toString();
        this.metadata = (LinkedTreeMap) resp.get("metadata");
        return this;
    }

    /**
     * Creates a handle with no embedding function, without contacting the server.
     *
     * <p>Operations that must compute embeddings (e.g. {@link #query}) throw
     * {@link IllegalStateException} on the returned instance.
     *
     * @param api            client used for every server call
     * @param collectionName name of the collection
     * @return a new, un-fetched handle
     * @throws ApiException declared for API compatibility; not thrown by this implementation
     */
    public static Collection getInstance(DefaultApi api, String collectionName) throws ApiException {
        return new Collection(api, collectionName, null);
    }

    @Override
    public String toString() {
        return "Collection{" +
                "collectionName='" + collectionName + '\'' +
                ", collectionId='" + collectionId + '\'' +
                ", metadata=" + metadata +
                '}';
    }

    /**
     * Fetches records from this collection.
     *
     * @param ids           ids to fetch, or {@code null}
     * @param where         metadata filter, or {@code null}
     * @param whereDocument document-content filter, or {@code null}
     * @return the server response bound onto {@link GetResult}
     * @throws ApiException if the server call fails
     */
    public GetResult get(List ids, Map where, Map whereDocument) throws ApiException {
        GetEmbedding req = new GetEmbedding();
        req.ids(ids).where(where).whereDocument(whereDocument);
        // The API returns an untyped object; round-trip it through JSON to bind it to GetResult.
        String json = gson.toJson(api.get(req, this.collectionId));
        return gson.fromJson(json, GetResult.class);
    }

    /** Fetches records with no id, metadata or document filter. */
    public GetResult get() throws ApiException {
        return this.get(null, null, null);
    }

    /** Sends a delete request with no id, metadata or document filter. */
    public Object delete() throws ApiException {
        return this.delete(null, null, null);
    }

    /**
     * Inserts or updates records in this collection.
     *
     * @param embeddings embeddings to store, or {@code null} to compute them from
     *                   {@code documents} with the configured embedding function
     * @param metadatas  per-record metadata
     * @param documents  per-record documents
     * @param ids        per-record ids
     * @return the raw server response
     * @throws ChromaException       if the server call fails (wrapping the {@link ApiException})
     * @throws IllegalStateException if embeddings must be computed but no embedding function is set
     */
    public Object upsert(List embeddings, List> metadatas, List documents, List ids) throws ChromaException {
        AddEmbedding req = new AddEmbedding();
        req.setEmbeddings(resolveEmbeddings(embeddings, documents).stream()
                .map(Embedding::asArray)
                .collect(Collectors.toList()));
        req.setMetadatas((List>) (Object) metadatas);
        req.setDocuments(documents);
        req.incrementIndex(true);
        req.setIds(ids);
        try {
            return api.upsert(req, this.collectionId);
        } catch (ApiException e) {
            throw new ChromaException(e);
        }
    }

    /**
     * Adds records to this collection.
     *
     * @param embeddings embeddings to store, or {@code null} to compute them from
     *                   {@code documents} with the configured embedding function
     * @param metadatas  per-record metadata
     * @param documents  per-record documents
     * @param ids        per-record ids
     * @return the raw server response
     * @throws ChromaException       if the server call fails (wrapping the {@link ApiException})
     * @throws IllegalStateException if embeddings must be computed but no embedding function is set
     */
    public Object add(List embeddings, List> metadatas, List documents, List ids) throws ChromaException {
        AddEmbedding req = new AddEmbedding();
        req.setEmbeddings(resolveEmbeddings(embeddings, documents).stream()
                .map(Embedding::asArray)
                .collect(Collectors.toList()));
        req.setMetadatas((List>) (Object) metadatas);
        req.setDocuments(documents);
        req.incrementIndex(true);
        req.setIds(ids);
        try {
            return api.add(req, this.collectionId);
        } catch (ApiException e) {
            throw new ChromaException(e);
        }
    }

    /**
     * Returns {@code embeddings} when supplied; otherwise embeds {@code texts}.
     */
    private List<Embedding> resolveEmbeddings(List<Embedding> embeddings, List<String> texts) throws ChromaException {
        return embeddings != null ? embeddings : embedTexts(texts);
    }

    /**
     * Embeds {@code texts} with the configured embedding function.
     *
     * @throws IllegalStateException if no embedding function was configured; previously this
     *                               surfaced as a bare {@link NullPointerException}
     */
    private List<Embedding> embedTexts(List<String> texts) throws ChromaException {
        if (embeddingFunction == null) {
            throw new IllegalStateException("Collection '" + collectionName
                    + "' has no EmbeddingFunction configured, so embeddings cannot be computed");
        }
        return embeddingFunction.embedDocuments(texts);
    }

    /**
     * Returns the record count reported by the server for this collection.
     *
     * @throws ApiException if the server call fails
     */
    public Integer count() throws ApiException {
        return api.count(this.collectionId);
    }

    /**
     * Deletes records matching the given ids and/or filters.
     *
     * <p>{@code where} and {@code whereDocument} are copied into new maps with
     * {@link Collectors#toMap}, which rejects {@code null} values with a
     * {@link NullPointerException}.
     *
     * @param ids           ids to delete, or {@code null}
     * @param where         metadata filter, or {@code null}
     * @param whereDocument document-content filter, or {@code null}
     * @return the raw server response
     * @throws ApiException if the server call fails
     */
    public Object delete(List ids, Map where, Map whereDocument) throws ApiException {
        DeleteEmbedding req = new DeleteEmbedding();
        req.setIds(ids);
        if (where != null) {
            req.where(where.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
        }
        if (whereDocument != null) {
            req.whereDocument(whereDocument.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
        }
        return api.delete(req, this.collectionId);
    }

    /** Deletes records by id only. */
    public Object deleteWithIds(List ids) throws ApiException {
        return delete(ids, null, null);
    }

    /** Deletes records matching a metadata filter only. */
    public Object deleteWhere(Map where) throws ApiException {
        return delete(null, where, null);
    }

    /** Deletes records matching both a metadata filter and a document filter. */
    public Object deleteWhereWhereDocuments(Map where, Map whereDocument) throws ApiException {
        return delete(null, where, whereDocument);
    }

    /** Deletes records matching a document filter only. */
    public Object deleteWhereDocuments(Map whereDocument) throws ApiException {
        return delete(null, null, whereDocument);
    }

    /**
     * Sends a rename and/or new metadata for this collection, then reloads it via {@link #fetch()}.
     *
     * @param newName     new collection name, or {@code null} to keep the current name
     * @param newMetadata metadata to send, or {@code null} to send none. It is copied first, so the
     *                    caller's map is never modified. When an embedding function is configured,
     *                    an {@code "embedding_function"} entry holding its class name is added to
     *                    the copy unless one is already present.
     * @return the raw response of the update call
     * @throws ApiException if the update or the subsequent reload fails
     */
    public Object update(String newName, Map newMetadata) throws ApiException {
        UpdateCollection req = new UpdateCollection();
        if (newName != null) {
            req.setNewName(newName);
        }
        if (newMetadata != null) {
            // Copy so the caller's map is neither mutated nor required to be modifiable, and send
            // metadata even when no embedding function is configured.
            Map<String, Object> metadataToSend = new LinkedHashMap<>(newMetadata);
            if (embeddingFunction != null && !metadataToSend.containsKey("embedding_function")) {
                metadataToSend.put("embedding_function", embeddingFunction.getClass().getName());
            }
            req.setNewMetadata(metadataToSend);
        }
        Object resp = api.updateCollection(req, this.collectionId);
        if (newName != null) {
            // fetch() looks the collection up by name, so follow the rename; a null name must not
            // overwrite the current one.
            this.collectionName = newName;
        }
        // Reload so the local name, id and metadata reflect what the server now holds.
        this.fetch();
        return resp;
    }

    /**
     * Updates existing records in this collection.
     *
     * @param embeddings new embeddings, or {@code null} to compute them from {@code documents}
     *                   with the configured embedding function
     * @param metadatas  per-record metadata
     * @param documents  per-record documents
     * @param ids        ids of the records to update
     * @return the raw server response
     * @throws ChromaException       if the server call fails (wrapping the {@link ApiException})
     * @throws IllegalStateException if embeddings must be computed but no embedding function is set
     */
    public Object updateEmbeddings(List embeddings, List> metadatas, List documents, List ids) throws ChromaException {
        UpdateEmbedding req = new UpdateEmbedding();
        req.setEmbeddings(resolveEmbeddings(embeddings, documents).stream()
                .map(Embedding::asArray)
                .collect(Collectors.toList()));
        req.setDocuments(documents);
        req.setMetadatas((List) (Object) metadatas);
        req.setIds(ids);
        try {
            return api.update(req, this.collectionId);
        } catch (ApiException e) {
            throw new ChromaException(e);
        }
    }

    /**
     * Embeds {@code queryTexts} with the configured embedding function and sends them to the
     * server's nearest-neighbour endpoint.
     *
     * <p>{@code where} and {@code whereDocument} are copied with {@link Collectors#toMap}, which
     * rejects {@code null} values with a {@link NullPointerException}.
     *
     * @param queryTexts    texts to embed and query with
     * @param nResults      number of results requested per query text
     * @param where         metadata filter, or {@code null}
     * @param whereDocument document-content filter, or {@code null}
     * @param include       fields to include in the response, passed through unchanged
     * @return the server response bound onto {@link QueryResponse}
     * @throws ChromaException       if the server call fails (wrapping the {@link ApiException})
     * @throws IllegalStateException if no embedding function is configured
     */
    public QueryResponse query(List queryTexts, Integer nResults, Map where, Map whereDocument, List include) throws ChromaException {
        QueryEmbedding body = new QueryEmbedding();
        body.queryEmbeddings(embedTexts(queryTexts).stream()
                .map(Embedding::asArray)
                .collect(Collectors.toList()));
        body.nResults(nResults);
        body.include(include);
        if (where != null) {
            body.where(where.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
        }
        if (whereDocument != null) {
            body.whereDocument(whereDocument.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
        }
        try {
            // The API returns an untyped object; round-trip it through JSON to bind it.
            String json = gson.toJson(api.getNearestNeighbors(body, this.collectionId));
            return gson.fromJson(json, QueryResponse.class);
        } catch (ApiException e) {
            throw new ChromaException(e);
        }
    }

    /**
     * Result of {@link #query}. Each outer list appears to hold one entry per query text
     * (NOTE(review): confirm against the server API); fields are populated by Gson.
     */
    public static class QueryResponse {
        @SerializedName("documents")
        private List> documents;
        @SerializedName("embeddings")
        private List> embeddings;
        @SerializedName("ids")
        private List> ids;
        @SerializedName("metadatas")
        private List>> metadatas;
        @SerializedName("distances")
        private List> distances;

        public List> getDocuments() {
            return documents;
        }

        public List> getEmbeddings() {
            return embeddings;
        }

        public List> getIds() {
            return ids;
        }

        public List>> getMetadatas() {
            return metadatas;
        }

        public List> getDistances() {
            return distances;
        }

        /** Renders this response as JSON. */
        @Override
        public String toString() {
            return gson.toJson(this);
        }
    }

    /** Result of {@link #get}; fields are populated by Gson from the server response. */
    public static class GetResult {
        @SerializedName("documents")
        private List documents;
        @SerializedName("embeddings")
        private List embeddings;
        @SerializedName("ids")
        private List ids;
        @SerializedName("metadatas")
        private List> metadatas;

        public List getDocuments() {
            return documents;
        }

        public List getEmbeddings() {
            return embeddings;
        }

        public List getIds() {
            return ids;
        }

        public List> getMetadatas() {
            return metadatas;
        }

        /** Renders this result as JSON. */
        @Override
        public String toString() {
            return gson.toJson(this);
        }
    }
}