Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
prerna.engine.impl.vector.ChromaVectorDatabaseEngine Maven / Gradle / Ivy
package prerna.engine.impl.vector;
import java.io.File;
import java.io.IOException;
import java.nio.file.Paths;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.apache.http.HttpHeaders;
import org.apache.http.entity.ContentType;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
import prerna.cluster.util.ClusterUtil;
import prerna.cluster.util.DeleteFilesFromEngineRunner;
import prerna.engine.api.IModelEngine;
import prerna.engine.api.VectorDatabaseTypeEnum;
import prerna.om.Insight;
import prerna.sablecc2.om.execptions.SemossPixelException;
import prerna.security.HttpHelperUtility;
import prerna.util.Constants;
import prerna.util.Utility;
public class ChromaVectorDatabaseEngine extends AbstractVectorDatabaseEngine {
private static final Logger classLogger = LogManager.getLogger(ChromaVectorDatabaseEngine.class);
public static final String CHROMA_CLASSNAME = "CHROMA_COLLECTION_NAME";
public static final String COLLECTION_ID = "COLLECTION_ID";
private final String API_TOKEN_KEY = "X-Chroma-Token";
private final String API_ADD = "/add";
private final String API_DELETE = "/delete";
private final String API_QUERY = "/query";
private String url = null;
private String apiKey = null;
private String className = null;
private String collectionID = null;
@Override
public void open(Properties smssProp) throws Exception {
super.open(smssProp);
this.url = smssProp.getProperty(Constants.HOSTNAME);
if (!this.url.endsWith("/")) {
this.url += "/";
}
this.apiKey = smssProp.getProperty(Constants.API_KEY);
this.className = smssProp.getProperty(CHROMA_CLASSNAME);
// create or fetch collection Id from the Chroma DB
this.collectionID = createCollection(this.className);
}
/**
*
* @param collectionName
*/
private String createCollection(String collectionName) {
// check to see if the collection is available
// if available, get the ID
// if not create a collection and get the ID
collectionName = collectionName.replaceAll(" ", "_");
Gson gson = new GsonBuilder().setPrettyPrinting().create();
Map headersMap = new HashMap<>();
if (this.apiKey != null && !this.apiKey.isEmpty()) {
headersMap.put(API_TOKEN_KEY, this.apiKey);
headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
} else {
headersMap = null;
}
String nearestNeigborResponse = null;
try {
nearestNeigborResponse = HttpHelperUtility.getRequest(this.url, headersMap, null, null, null);
} catch(Exception e) {
classLogger.error("Unable to create connection");
throw new SemossPixelException("Unable to create connection");
}
List> responseListMap = gson.fromJson(nearestNeigborResponse, new TypeToken>>() {}.getType());
for (Map responseMap : responseListMap) {
if (responseMap.get("name") != null && responseMap.get("name").toString().equals(collectionName)) {
return (String) responseMap.get("id");
}
}
// if the collection Name doesn't exist, create it and return the ID
nearestNeigborResponse = null;
Map collectionNameToCreate = new HashMap<>();
collectionNameToCreate.put("name", collectionName);
String body = gson.toJson(collectionNameToCreate);
nearestNeigborResponse = HttpHelperUtility.postRequestStringBody(this.url, headersMap, body, ContentType.APPLICATION_JSON, null, null, null);
Map responseMap = gson.fromJson(nearestNeigborResponse, new TypeToken>() {}.getType());
return (String) responseMap.get("id");
}
@Override
protected String getDefaultDistanceMethod() {
return "cosine";
}
@Override
public void addEmbeddings(VectorDatabaseCSVTable vectorCsvTable, Insight insight, Map parameters) throws Exception {
if (!modelPropsLoaded) {
verifyModelProps();
}
if (insight == null) {
throw new IllegalArgumentException("Insight must be provided to run Model Engine Encoder");
}
// if we were able to extract files, begin embeddings process
IModelEngine embeddingsEngine = Utility.getModel(this.embedderEngineId);
// send all the strings to embed in one shot
try {
vectorCsvTable.generateAndAssignEmbeddings(embeddingsEngine, insight);
} catch (Exception e) {
classLogger.error(Constants.STACKTRACE, e);
throw new IllegalArgumentException("Error occurred creating the embeddings for the generated chunks. Detailed error message = " + e.getMessage());
}
Map vectors = new HashMap<>();
List ids = new ArrayList<>();
List embeddings = new ArrayList<>();
List> metadatas = new ArrayList<>();
for (int rowIndex = 0; rowIndex < vectorCsvTable.rows.size(); rowIndex++) {
VectorDatabaseCSVRow row = vectorCsvTable.getRows().get(rowIndex);
Map properties = new HashMap<>();
properties.put("Source", row.getSource());
properties.put("Modality", row.getModality());
properties.put("Divider", row.getDivider());
properties.put("Part", row.getPart());
properties.put("Tokens", row.getTokens());
properties.put("Content", row.getContent());
// Float[] vectorEmbeddings = getEmbeddings(row.getContent(), insight);
List extends Number> embedding = row.getEmbeddings();
Float[] vectorEmbeddings = new Float[embedding.size()];
for (int vecIndex = 0; vecIndex < vectorEmbeddings.length; vecIndex++) {
vectorEmbeddings[vecIndex] = embedding.get(vecIndex).floatValue();
}
String currentRowID = row.getSource() + "-" + rowIndex;
ids.add(currentRowID);
embeddings.add(vectorEmbeddings);
metadatas.add(properties);
}
vectors.put("ids", ids);
vectors.put("embeddings", embeddings);
vectors.put("metadatas", metadatas);
String body = new Gson().toJson(vectors);
Map headersMap = new HashMap<>();
if (this.apiKey != null && !this.apiKey.isEmpty()) {
headersMap.put(API_TOKEN_KEY, this.apiKey);
headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
} else {
headersMap = null;
}
String response = HttpHelperUtility.postRequestStringBody(this.url + this.collectionID + API_ADD,
headersMap, body, ContentType.APPLICATION_JSON, null, null, null);
//TODO: let us add validation by looking at the response
}
@Override
public void removeDocument(List fileNames, Map parameters) throws IOException {
String indexClass = this.defaultIndexClass;
if (parameters.containsKey("indexClass")) {
indexClass = (String) parameters.get("indexClass");
}
List sourceNames = new ArrayList<>();
for(String document : fileNames) {
String documentName = FilenameUtils.getName(document);
File f = new File(document);
if(f.exists() && f.getName().endsWith(".csv")) {
sourceNames.addAll(VectorDatabaseCSVTable.pullSourceColumn(f));
} else {
sourceNames.add(documentName);
}
}
List filesToRemoveFromCloud = new ArrayList();
// need to get the source names and then delete it based on the names
for (int fileIndex = 0; fileIndex < sourceNames.size(); fileIndex++) {
String fileName = fileNames.get(fileIndex);
// Delete document in ChromaDB using their ID, but to get the ID we need to find
// the ID of a document first. Check the delete API call params
// http://localhost:5000/api/v1/collections/{}/delete
Map fileNamesForDelete = new HashMap<>();
Map sourceProperty = new HashMap<>();
// replace spaces with _ since thats how
// readCSV creates Source Property.
sourceProperty.put("Source", fileName.replaceAll(" ", "_"));
fileNamesForDelete.put("where", sourceProperty);
String body = new Gson().toJson(fileNamesForDelete);
Map headersMap = new HashMap<>();
if (this.apiKey != null && !this.apiKey.isEmpty()) {
headersMap.put(API_TOKEN_KEY, this.apiKey);
headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
} else {
headersMap = null;
}
String response = HttpHelperUtility.postRequestStringBody(this.url + this.collectionID + API_DELETE,
headersMap, body, ContentType.APPLICATION_JSON, null, null, null);
//TODO: let us add validation by looking at the response
String documentName = Paths.get(fileName).getFileName().toString();
// remove the physical documents
File documentFile = new File(
this.schemaFolder.getAbsolutePath() + DIR_SEPARATOR + indexClass + DIR_SEPARATOR + "documents",
documentName);
try {
if (documentFile.exists()) {
FileUtils.forceDelete(documentFile);
filesToRemoveFromCloud.add(documentFile.getAbsolutePath());
}
} catch (IOException e) {
classLogger.error(Constants.STACKTRACE, e);
}
}
if (ClusterUtil.IS_CLUSTER) {
Thread deleteFilesFromCloudThread = new Thread(new DeleteFilesFromEngineRunner(engineId,
this.getCatalogType(), filesToRemoveFromCloud.stream().toArray(String[]::new)));
deleteFilesFromCloudThread.start();
}
}
@Override
public List> nearestNeighborCall(Insight insight, String searchStatement, Number limit, Map parameters) {
if (insight == null) {
throw new IllegalArgumentException("Insight must be provided to run Model Engine Encoder");
}
if (!modelPropsLoaded) {
verifyModelProps();
}
if (limit == null) {
limit = 3;
}
Gson gson = new Gson();
List vector = getEmbeddingsDouble(searchStatement, insight);
Map query = new HashMap<>();
List> queryEmbeddings = new ArrayList<>();
// this is done to put a list of embeddings inside another list otherwise the
// API throws error.
queryEmbeddings.add(vector);
// List> metadatas = new ArrayList<>(); add metadata filter
query.put("query_texts", searchStatement);
query.put("n_results", limit);
query.put("query_embeddings", queryEmbeddings);
String body = gson.toJson(query);
Map headersMap = new HashMap<>();
if (this.apiKey != null && !this.apiKey.isEmpty()) {
headersMap.put(API_TOKEN_KEY, this.apiKey);
headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
} else {
headersMap = null;
}
String nearestNeigborResponse = HttpHelperUtility.postRequestStringBody(this.url + this.collectionID + API_QUERY,
headersMap, body, ContentType.APPLICATION_JSON, null, null, null);
Map responseMap = gson.fromJson(nearestNeigborResponse, new TypeToken>() {}.getType());
// Retrieve the metadatas list response
List> resultMap = (List>) responseMap.get("metadatas");
return (List>) resultMap.get(0);
}
@Override
public List> listDocuments(Map parameters) {
//TODO: needs to grab 'Source' from the database
//TODO: needs to grab 'Source' from the database
//TODO: needs to grab 'Source' from the database
//TODO: needs to grab 'Source' from the database
//TODO: needs to grab 'Source' from the database
//TODO: needs to grab 'Source' from the database
String indexClass = this.defaultIndexClass;
if (parameters.containsKey("indexClass")) {
indexClass = (String) parameters.get("indexClass");
}
File documentsDir = new File(this.schemaFolder.getAbsolutePath() + DIR_SEPARATOR + indexClass + DIR_SEPARATOR + DOCUMENTS_FOLDER_NAME);
List> fileList = new ArrayList<>();
File[] files = documentsDir.listFiles();
if (files != null) {
for (File file : files) {
String fileName = file.getName();
long fileSizeInBytes = file.length();
double fileSizeInMB = (double) fileSizeInBytes / (1024);
SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
String lastModified = dateFormat.format(new Date(file.lastModified()));
Map fileInfo = new HashMap<>();
fileInfo.put("fileName", fileName);
fileInfo.put("fileSize", fileSizeInMB);
fileInfo.put("lastModified", lastModified);
fileList.add(fileInfo);
}
}
return fileList;
}
@Override
public List> listAllRecords(Map parameters) {
throw new IllegalArgumentException("This method has not been implemented yet");
}
@Override
public VectorDatabaseTypeEnum getVectorDatabaseType() {
return VectorDatabaseTypeEnum.CHROMA;
}
}