All downloads are free. Search and download functionality uses the official Maven repository.

prerna.engine.impl.vector.ChromaVectorDatabaseEngine Maven / Gradle / Ivy

The newest version!
package prerna.engine.impl.vector;

import java.io.File;
import java.io.IOException;
import java.nio.file.Paths;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.apache.http.HttpHeaders;
import org.apache.http.entity.ContentType;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;

import prerna.cluster.util.ClusterUtil;
import prerna.cluster.util.DeleteFilesFromEngineRunner;
import prerna.engine.api.IModelEngine;
import prerna.engine.api.VectorDatabaseTypeEnum;
import prerna.om.Insight;
import prerna.sablecc2.om.execptions.SemossPixelException;
import prerna.security.HttpHelperUtility;
import prerna.util.Constants;
import prerna.util.Utility;

public class ChromaVectorDatabaseEngine extends AbstractVectorDatabaseEngine {

	private static final Logger classLogger = LogManager.getLogger(ChromaVectorDatabaseEngine.class);
	
	public static final String CHROMA_CLASSNAME = "CHROMA_COLLECTION_NAME";
	public static final String COLLECTION_ID = "COLLECTION_ID";

	private final String API_TOKEN_KEY = "X-Chroma-Token";
	
	private final String API_ADD = "/add";
	private final String API_DELETE = "/delete";
	private final String API_QUERY = "/query";
	
	private String url = null;
	private String apiKey = null;
	private String className = null;
	private String collectionID = null;

	@Override
	public void open(Properties smssProp) throws Exception {
		super.open(smssProp);

		this.url = smssProp.getProperty(Constants.HOSTNAME);
		if (!this.url.endsWith("/")) {
			this.url += "/";
		}
		this.apiKey = smssProp.getProperty(Constants.API_KEY);
		this.className = smssProp.getProperty(CHROMA_CLASSNAME);

		// create or fetch collection Id from the Chroma DB
		this.collectionID = createCollection(this.className);
	}

	/**
	 * 
	 * @param collectionName
	 */
	private String createCollection(String collectionName) {
		// check to see if the collection is available
		// if available, get the ID
		// if not create a collection and get the ID
		collectionName = collectionName.replaceAll(" ", "_");
		Gson gson = new GsonBuilder().setPrettyPrinting().create();
		Map headersMap = new HashMap<>();
		if (this.apiKey != null && !this.apiKey.isEmpty()) {
			headersMap.put(API_TOKEN_KEY, this.apiKey);
			headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
		} else {
			headersMap = null;
		}
		
		String nearestNeigborResponse = null;
		try {
			nearestNeigborResponse = HttpHelperUtility.getRequest(this.url, headersMap, null, null, null);
		} catch(Exception e) {
			classLogger.error("Unable to create connection");
			throw new SemossPixelException("Unable to create connection");
		}
		
		List> responseListMap = gson.fromJson(nearestNeigborResponse, new TypeToken>>() {}.getType());
		for (Map responseMap : responseListMap) {
			if (responseMap.get("name") != null && responseMap.get("name").toString().equals(collectionName)) {
				return (String) responseMap.get("id");
			}
		}

		// if the collection Name doesn't exist, create it and return the ID
		nearestNeigborResponse = null;
		Map collectionNameToCreate = new HashMap<>();
		collectionNameToCreate.put("name", collectionName);
		String body = gson.toJson(collectionNameToCreate);
		nearestNeigborResponse = HttpHelperUtility.postRequestStringBody(this.url, headersMap, body, ContentType.APPLICATION_JSON, null, null, null);
		Map responseMap = gson.fromJson(nearestNeigborResponse, new TypeToken>() {}.getType());
		
		return (String) responseMap.get("id");
	}
	
	@Override
	protected String getDefaultDistanceMethod() {
		return "cosine";
	}
	
	@Override
	public void addEmbeddings(VectorDatabaseCSVTable vectorCsvTable, Insight insight, Map parameters) throws Exception {
		if (!modelPropsLoaded) {
			verifyModelProps();
		}

		if (insight == null) {
			throw new IllegalArgumentException("Insight must be provided to run Model Engine Encoder");
		}

		// if we were able to extract files, begin embeddings process
		IModelEngine embeddingsEngine = Utility.getModel(this.embedderEngineId);
		// send all the strings to embed in one shot
		try {
			vectorCsvTable.generateAndAssignEmbeddings(embeddingsEngine, insight);
		} catch (Exception e) {
			classLogger.error(Constants.STACKTRACE, e);
			throw new IllegalArgumentException("Error occurred creating the embeddings for the generated chunks. Detailed error message = " + e.getMessage());
		}
		
		Map vectors = new HashMap<>();
		List ids = new ArrayList<>();
		List embeddings = new ArrayList<>();
		List> metadatas = new ArrayList<>();

		for (int rowIndex = 0; rowIndex < vectorCsvTable.rows.size(); rowIndex++) {
			VectorDatabaseCSVRow row = vectorCsvTable.getRows().get(rowIndex);
			Map properties = new HashMap<>();
			properties.put("Source", row.getSource());
			properties.put("Modality", row.getModality());
			properties.put("Divider", row.getDivider());
			properties.put("Part", row.getPart());
			properties.put("Tokens", row.getTokens());
			properties.put("Content", row.getContent());

			// Float[] vectorEmbeddings = getEmbeddings(row.getContent(), insight);
			List embedding = row.getEmbeddings();
			Float[] vectorEmbeddings = new Float[embedding.size()];
			for (int vecIndex = 0; vecIndex < vectorEmbeddings.length; vecIndex++) {
				vectorEmbeddings[vecIndex] = embedding.get(vecIndex).floatValue();
			}

			String currentRowID = row.getSource() + "-" + rowIndex;
			ids.add(currentRowID);
			embeddings.add(vectorEmbeddings);
			metadatas.add(properties);
		}

		vectors.put("ids", ids);
		vectors.put("embeddings", embeddings);
		vectors.put("metadatas", metadatas);

		String body = new Gson().toJson(vectors);

		Map headersMap = new HashMap<>();
		if (this.apiKey != null && !this.apiKey.isEmpty()) {
			headersMap.put(API_TOKEN_KEY, this.apiKey);
			headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
		} else {
			headersMap = null;
		}

		String response = HttpHelperUtility.postRequestStringBody(this.url + this.collectionID + API_ADD, 
				headersMap, body, ContentType.APPLICATION_JSON, null, null, null);
		
		//TODO: let us add validation by looking at the response
	}

	@Override
	public void removeDocument(List fileNames, Map parameters) throws IOException {
		String indexClass = this.defaultIndexClass;
		if (parameters.containsKey("indexClass")) {
			indexClass = (String) parameters.get("indexClass");
		}

		List sourceNames = new ArrayList<>();
    	for(String document : fileNames) {
			String documentName = FilenameUtils.getName(document);
			File f = new File(document);
			if(f.exists() && f.getName().endsWith(".csv")) {
				sourceNames.addAll(VectorDatabaseCSVTable.pullSourceColumn(f));
			} else {
				sourceNames.add(documentName);
			}
    	}
		
		List filesToRemoveFromCloud = new ArrayList();

		// need to get the source names and then delete it based on the names
		for (int fileIndex = 0; fileIndex < sourceNames.size(); fileIndex++) {
			String fileName = fileNames.get(fileIndex);

			// Delete document in ChromaDB using their ID, but to get the ID we need to find
			// the ID of a document first. Check the delete API call params
			// http://localhost:5000/api/v1/collections/{}/delete

			Map fileNamesForDelete = new HashMap<>();
			Map sourceProperty = new HashMap<>();

			// replace spaces with _ since thats how
			// readCSV creates Source Property.
			sourceProperty.put("Source", fileName.replaceAll(" ", "_")); 
																			
			fileNamesForDelete.put("where", sourceProperty);

			String body = new Gson().toJson(fileNamesForDelete);

			Map headersMap = new HashMap<>();
			if (this.apiKey != null && !this.apiKey.isEmpty()) {
				headersMap.put(API_TOKEN_KEY, this.apiKey);
				headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
			} else {
				headersMap = null;
			}

			String response = HttpHelperUtility.postRequestStringBody(this.url + this.collectionID + API_DELETE,
					headersMap, body, ContentType.APPLICATION_JSON, null, null, null);

			//TODO: let us add validation by looking at the response			
			
			String documentName = Paths.get(fileName).getFileName().toString();
			// remove the physical documents
			File documentFile = new File(
					this.schemaFolder.getAbsolutePath() + DIR_SEPARATOR + indexClass + DIR_SEPARATOR + "documents",
					documentName);
			try {
				if (documentFile.exists()) {
					FileUtils.forceDelete(documentFile);
					filesToRemoveFromCloud.add(documentFile.getAbsolutePath());
				}
			} catch (IOException e) {
				classLogger.error(Constants.STACKTRACE, e);
			}

		}

		if (ClusterUtil.IS_CLUSTER) {
			Thread deleteFilesFromCloudThread = new Thread(new DeleteFilesFromEngineRunner(engineId,
					this.getCatalogType(), filesToRemoveFromCloud.stream().toArray(String[]::new)));
			deleteFilesFromCloudThread.start();
		}
	}

	@Override
	public List> nearestNeighborCall(Insight insight, String searchStatement, Number limit, Map  parameters) {
		if (insight == null) {
			throw new IllegalArgumentException("Insight must be provided to run Model Engine Encoder");
		}
		if (!modelPropsLoaded) {
			verifyModelProps();
		}
		if (limit == null) {
			limit = 3;
		}
		
		Gson gson = new Gson();

		List vector = getEmbeddingsDouble(searchStatement, insight);
		Map query = new HashMap<>();
		List> queryEmbeddings = new ArrayList<>();
		// this is done to put a list of embeddings inside another list otherwise the
		// API throws error.
		queryEmbeddings.add(vector); 
										
		// List> metadatas = new ArrayList<>(); add metadata filter
		query.put("query_texts", searchStatement);
		query.put("n_results", limit);
		query.put("query_embeddings", queryEmbeddings);
		String body = gson.toJson(query);

		Map headersMap = new HashMap<>();
		if (this.apiKey != null && !this.apiKey.isEmpty()) {
			headersMap.put(API_TOKEN_KEY, this.apiKey);
			headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
		} else {
			headersMap = null;
		}
		
		String nearestNeigborResponse = HttpHelperUtility.postRequestStringBody(this.url + this.collectionID + API_QUERY,
				headersMap, body, ContentType.APPLICATION_JSON, null, null, null);

		Map responseMap = gson.fromJson(nearestNeigborResponse, new TypeToken>() {}.getType());
		
		// Retrieve the metadatas list response
		List> resultMap = (List>) responseMap.get("metadatas");
		return (List>) resultMap.get(0);
	}
	
	@Override
	public List> listDocuments(Map parameters) {
		//TODO: needs to grab 'Source' from the database
		//TODO: needs to grab 'Source' from the database
		//TODO: needs to grab 'Source' from the database
		//TODO: needs to grab 'Source' from the database
		//TODO: needs to grab 'Source' from the database
		//TODO: needs to grab 'Source' from the database
		
		String indexClass = this.defaultIndexClass;
		if (parameters.containsKey("indexClass")) {
			indexClass = (String) parameters.get("indexClass");
		}

		File documentsDir = new File(this.schemaFolder.getAbsolutePath() + DIR_SEPARATOR + indexClass + DIR_SEPARATOR + DOCUMENTS_FOLDER_NAME);

		List> fileList = new ArrayList<>();

		File[] files = documentsDir.listFiles();
		if (files != null) {
			for (File file : files) {
				String fileName = file.getName();
				long fileSizeInBytes = file.length();
				double fileSizeInMB = (double) fileSizeInBytes / (1024);
				SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
				String lastModified = dateFormat.format(new Date(file.lastModified()));

				Map fileInfo = new HashMap<>();
				fileInfo.put("fileName", fileName);
				fileInfo.put("fileSize", fileSizeInMB);
				fileInfo.put("lastModified", lastModified);
				fileList.add(fileInfo);
			}
		} 

		return fileList;
	}
	
	@Override
	public List> listAllRecords(Map parameters) {
		throw new IllegalArgumentException("This method has not been implemented yet");
	}
	
	@Override
	public VectorDatabaseTypeEnum getVectorDatabaseType() {
		return VectorDatabaseTypeEnum.CHROMA;
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy