package prerna.engine.impl.vector;
import java.io.File;
import java.io.IOException;
import java.nio.file.Paths;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.TreeSet;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.apache.http.HttpHeaders;
import org.apache.http.entity.ContentType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.reflect.TypeToken;
import prerna.cluster.util.ClusterUtil;
import prerna.cluster.util.DeleteFilesFromEngineRunner;
import prerna.engine.api.IModelEngine;
import prerna.engine.api.VectorDatabaseTypeEnum;
import prerna.om.Insight;
import prerna.security.HttpHelperUtility;
import prerna.util.Constants;
import prerna.util.Utility;
public class PineConeVectorDatabaseEngine extends AbstractVectorDatabaseEngine {
private static final Logger classLogger = LoggerFactory.getLogger(PineConeVectorDatabaseEngine.class);
private final String NAMESPACE = "NAMESPACE";
private final String API_UPSERT = "/vectors/upsert";
private final String API_DELETE = "/vectors/delete";
private final String API_QUERY = "/query";
private final String API_KY = "Api-Key";
private final String LIST_QUERY = "/vectors/list?namespace=";
private final String FETCH_QUERY = "/vectors/fetch?namespace=";
private final String HASH = "#";
private final String PREFIX = "&prefix=";
private final String PAGINATION_TOKEN = "&paginationToken=";
private String hostname = null;
private String apiKey = null;
private String defaultNamespace = null;
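/**
 * The engine is configured through its SMSS properties: the Pinecone api key
 * (Constants.API_KEY), the index host including the scheme (Constants.HOSTNAME,
 * e.g. the illustrative https://my-index-abc123.svc.aped-0000-a1b2.pinecone.io),
 * and the default namespace used for all vector operations (NAMESPACE).
 */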
@Override
public void open(Properties smssProp) throws Exception {
super.open(smssProp);
this.apiKey = smssProp.getProperty(Constants.API_KEY);
if (this.apiKey == null || (this.apiKey = this.apiKey.trim()).isEmpty()) {
throw new IllegalArgumentException("Must define the api key");
}
this.hostname = smssProp.getProperty(Constants.HOSTNAME);
this.defaultNamespace = this.smssProp.getProperty(NAMESPACE);
}
@Override
protected String getDefaultDistanceMethod() {
// the distance method is stored on the index itself
// since we don't create the index from this engine, this value doesn't matter
return "cosine";
}
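/**
 * Embeds every row of the CSV table and upserts the results in a single call to
 * Pinecone's /vectors/upsert endpoint. A sketch of the request body this method
 * builds, with illustrative values:
 *
 * {
 *   "namespace": "my-namespace",
 *   "vectors": [
 *     {
 *       "id": "my_document.pdf#0",
 *       "values": [0.012, -0.034, ...],
 *       "metadata": { ...the SOURCE/MODALITY/DIVIDER/PART/TOKENS/CONTENT columns... }
 *     }
 *   ]
 * }
 */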
@Override
public void addEmbeddings(VectorDatabaseCSVTable vectorCsvTable, Insight insight, Map<String, Object> parameters) throws Exception {
if (!modelPropsLoaded) {
verifyModelProps();
}
if (insight == null) {
throw new IllegalArgumentException("Insight must be provided to run Model Engine Encoder");
}
IModelEngine embeddingsEngine = Utility.getModel(this.embedderEngineId);
// send all the strings to embed in one shot
try {
vectorCsvTable.generateAndAssignEmbeddings(embeddingsEngine, insight);
} catch (Exception e) {
classLogger.error(Constants.STACKTRACE, e);
throw new IllegalArgumentException("Error occurred creating the embeddings for the generated chunks. Detailed error message = " + e.getMessage());
}
// Sample URL:
// https://docs-quickstart-index3-fiarr5p.svc.aped-4627-b74a.pinecone.io/vectors/upsert
String url = this.hostname + API_UPSERT;
Map<String, String> headersMap = new HashMap<>();
headersMap.put(API_KY, this.apiKey);
headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
JsonArray vectors = new JsonArray();
// loop through and make the giant json
int fileCounter = 0;
String previousFileName = null;
for (int rowIndex = 0; rowIndex < vectorCsvTable.getRows().size(); rowIndex++) {
VectorDatabaseCSVRow row = vectorCsvTable.getRows().get(rowIndex);
JsonObject metadataJson = new JsonObject();
metadataJson.addProperty(VectorDatabaseCSVTable.SOURCE, row.getSource());
metadataJson.addProperty(VectorDatabaseCSVTable.MODALITY, row.getModality());
metadataJson.addProperty(VectorDatabaseCSVTable.DIVIDER, row.getDivider());
metadataJson.addProperty(VectorDatabaseCSVTable.PART, row.getPart());
metadataJson.addProperty(VectorDatabaseCSVTable.TOKENS, row.getTokens());
metadataJson.addProperty(VectorDatabaseCSVTable.CONTENT, row.getContent());
List<Double> vector = getEmbeddingsDouble(row.getContent(), insight);
// reset the chunk counter whenever we move on to a new source file
if (!row.getSource().equals(previousFileName)) {
fileCounter = 0;
previousFileName = row.getSource();
}
JsonObject thisChunkJson = new JsonObject();
// use the same separator (#) that removeDocument uses when listing ids by prefix
thisChunkJson.addProperty("id", row.getSource().replaceAll(" ", "_") + HASH + fileCounter++);
JsonArray thisEmbeddingVector = new JsonArray();
for(Double d : vector) {
thisEmbeddingVector.add(d);
}
thisChunkJson.add("values", thisEmbeddingVector);
thisChunkJson.add("metadata", metadataJson);
vectors.add(thisChunkJson);
}
JsonObject vectorsMap = new JsonObject();
vectorsMap.addProperty("namespace", this.defaultNamespace);
vectorsMap.add("vectors", vectors);
HttpHelperUtility.postRequestStringBody(url, headersMap, vectorsMap.toString(), ContentType.APPLICATION_JSON, null, null, null);
}
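/**
 * Removes documents from the index by listing the vector ids whose prefix matches
 * the sanitized source name and deleting them by id, paging through the list
 * endpoint (which returns at most 100 ids per call) until no pagination token is
 * returned. Physical copies of the documents are also deleted from the engine's
 * documents folder, and removed from cloud storage when running in cluster mode.
 */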
@Override
public void removeDocument(List<String> fileNames, Map<String, Object> parameters) throws IOException {
String indexClass = this.defaultIndexClass;
if (parameters.containsKey("indexClass")) {
indexClass = (String) parameters.get("indexClass");
}
List<String> sourceNames = new ArrayList<>();
for(String document : fileNames) {
String documentName = FilenameUtils.getName(document);
File f = new File(document);
if(f.exists() && f.getName().endsWith(".csv")) {
sourceNames.addAll(VectorDatabaseCSVTable.pullSourceColumn(f));
} else {
sourceNames.add(documentName);
}
}
Gson gson = new GsonBuilder().create();
List<String> filesToRemoveFromCloud = new ArrayList<>();
Map<String, String> headersMap = new HashMap<>();
headersMap.put(API_KY, this.apiKey);
headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
// need to get the source names and then delete based on those names
for (int fileIndex = 0; fileIndex < sourceNames.size(); fileIndex++) {
String fileName = sourceNames.get(fileIndex);
boolean firstExecution = true;
String paginationToken = null;
while(firstExecution || paginationToken!=null) {
// keep the prefix on every page so continuation pages stay scoped to this file's ids
String listVectorsUrl = this.hostname + LIST_QUERY + this.defaultNamespace + PREFIX + fileName.replaceAll(" ", "_") + HASH;
if(paginationToken!=null) {
listVectorsUrl+=PAGINATION_TOKEN+paginationToken;
}
String idListResponse = HttpHelperUtility.getRequest(listVectorsUrl, headersMap, null, null, null);
Map<String, Object> responseMap = gson.fromJson(idListResponse, new TypeToken<Map<String, Object>>() {}.getType());
List<Map<String, String>> vectors = (List<Map<String, String>>) responseMap.get("vectors");
executeDelete(this.hostname + API_DELETE, headersMap, vectors);
// we can only pull 100 at a time
// we need to check if there is pagination to keep going
Map<String, String> paginationMap = (Map<String, String>) responseMap.get("pagination");
if(paginationMap != null && !paginationMap.isEmpty()) {
paginationToken = paginationMap.get("next");
} else {
paginationToken = null;
}
firstExecution=false;
}
String documentName = Paths.get(fileName).getFileName().toString();
// remove the physical documents
File documentFile = new File(this.schemaFolder.getAbsolutePath() + DIR_SEPARATOR + indexClass + DIR_SEPARATOR + "documents", documentName);
try {
if (documentFile.exists()) {
FileUtils.forceDelete(documentFile);
filesToRemoveFromCloud.add(documentFile.getAbsolutePath());
}
} catch (IOException e) {
classLogger.error(Constants.STACKTRACE, e);
}
}
if (ClusterUtil.IS_CLUSTER) {
Thread deleteFilesFromCloudThread = new Thread(new DeleteFilesFromEngineRunner(engineId, this.getCatalogType(), filesToRemoveFromCloud.stream().toArray(String[]::new)));
deleteFilesFromCloudThread.start();
}
}
/**
 * Deletes the given vectors from the default namespace by id
 * @param url the full /vectors/delete endpoint
 * @param headersMap the request headers, including the api key
 * @param listReturnVector the listed vector entries (each holding an "id") to delete
 */
private void executeDelete(String url, Map<String, String> headersMap, List<Map<String, String>> listReturnVector) {
if(listReturnVector == null || listReturnVector.isEmpty()) {
return;
}
JsonArray idsJsonArray = new JsonArray();
for (Map<String, String> v : listReturnVector) {
idsJsonArray.add(v.get("id"));
}
JsonObject deleteJson = new JsonObject();
deleteJson.add("ids", idsJsonArray);
deleteJson.addProperty("namespace", this.defaultNamespace);
HttpHelperUtility.postRequestStringBody(url, headersMap, deleteJson.toString(), ContentType.APPLICATION_JSON, null, null, null);
}
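/**
 * Runs a similarity search against Pinecone's /query endpoint. A sketch of the
 * request body this method builds, with illustrative values:
 *
 * {
 *   "namespace": "my-namespace",
 *   "topK": 3,
 *   "includeMetadata": true,
 *   "includeValues": true,
 *   "vector": [0.012, -0.034, ...]
 * }
 *
 * Each entry of the response's "matches" array is flattened into a result map
 * holding the id, the score, and the stored metadata fields.
 */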
@Override
public List<Map<String, Object>> nearestNeighborCall(Insight insight, String searchStatement, Number limit, Map<String, Object> parameters) {
if (insight == null) {
throw new IllegalArgumentException("Insight must be provided to run Model Engine Encoder");
}
if (limit == null) {
limit = 3;
}
if (!modelPropsLoaded) {
verifyModelProps();
}
String url = this.hostname + API_QUERY;
JsonObject queryJson = new JsonObject();
List<Double> vector = getEmbeddingsDouble(searchStatement, insight);
JsonArray embeddingsJsonArr = new JsonArray();
for (int i = 0; i < vector.size(); i++) {
embeddingsJsonArr.add(vector.get(i));
}
queryJson.addProperty("topK", limit);
queryJson.addProperty("includeMetadata", true);
queryJson.addProperty("includeValues", true);
queryJson.addProperty("namespace", this.defaultNamespace);
queryJson.add("vector", embeddingsJsonArr);
Map<String, String> headersMap = new HashMap<>();
headersMap.put(API_KY, this.apiKey);
headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
String nearestNeighborResponse = HttpHelperUtility.postRequestStringBody(url, headersMap, queryJson.toString(), ContentType.APPLICATION_JSON, null, null, null);
Gson gson = new Gson();
Map<String, Object> responseMap = gson.fromJson(nearestNeighborResponse, new TypeToken<Map<String, Object>>() {}.getType());
List<Map<String, Object>> matches = (List<Map<String, Object>>) responseMap.get("matches");
List<Map<String, Object>> retOut = new ArrayList<>();
for (int i = 0; i < matches.size(); i++) {
Map<String, Object> thisMatch = matches.get(i);
Map<String, Object> metadataMap = (Map<String, Object>) thisMatch.get("metadata");
Map<String, Object> resultMap = new HashMap<>();
resultMap.put("Id", thisMatch.get("id"));
resultMap.put("Score", thisMatch.get("score"));
resultMap.put(VectorDatabaseCSVTable.SOURCE, metadataMap.get(VectorDatabaseCSVTable.SOURCE));
resultMap.put(VectorDatabaseCSVTable.CONTENT, metadataMap.get(VectorDatabaseCSVTable.CONTENT));
resultMap.put(VectorDatabaseCSVTable.DIVIDER, metadataMap.get(VectorDatabaseCSVTable.DIVIDER));
resultMap.put(VectorDatabaseCSVTable.MODALITY, metadataMap.get(VectorDatabaseCSVTable.MODALITY));
resultMap.put(VectorDatabaseCSVTable.PART, metadataMap.get(VectorDatabaseCSVTable.PART));
resultMap.put(VectorDatabaseCSVTable.TOKENS, metadataMap.get(VectorDatabaseCSVTable.TOKENS));
retOut.add(resultMap);
}
return retOut;
}
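/**
 * Lists the unique source documents stored in the namespace by paging through all
 * vector ids, fetching their metadata in batches, and collecting the distinct
 * source values. For documents still present on disk, the file size and last
 * modified timestamp are included in the returned maps.
 */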
@Override
public List<Map<String, Object>> listDocuments(Map<String, Object> parameters) {
/**
 * Pinecone's API is not useful here
 * https://docs.pinecone.io/reference/api/2025-01/data-plane/list
 *
 * there is no way to fetch results without first listing the ids,
 * and after listing the ids the fetch can't limit the return, so
 * we also get back the vectors
 */
String indexClass = this.defaultIndexClass;
if (parameters.containsKey(INDEX_CLASS)) {
indexClass = (String) parameters.get(INDEX_CLASS);
}
Gson gson = new GsonBuilder().create();
Map<String, String> headersMap = new HashMap<>();
headersMap.put(API_KY, this.apiKey);
headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
Set<String> sources = new TreeSet<>();
boolean firstExecution = true;
String paginationToken = null;
WHILE_LOOP : while(firstExecution || paginationToken!=null) {
String listVectorsUrl = this.hostname + LIST_QUERY + this.defaultNamespace;
if(paginationToken!=null) {
listVectorsUrl+=PAGINATION_TOKEN+paginationToken;
}
String idListResponse = HttpHelperUtility.getRequest(listVectorsUrl, headersMap, null, null, null);
Map<String, Object> responseMap = gson.fromJson(idListResponse, new TypeToken<Map<String, Object>>() {}.getType());
List<Map<String, String>> vectors = (List<Map<String, String>>) responseMap.get("vectors");
if(vectors == null || vectors.isEmpty()) {
break WHILE_LOOP;
}
StringBuilder ids = new StringBuilder();
for(int i = 0 ; i < vectors.size(); i++) {
Map<String, String> idMap = vectors.get(i);
ids.append("&ids=").append(idMap.get("id"));
if( (i+1) % 20 == 0) {
sources.addAll(fetchUniqueSourceValues(ids));
ids = new StringBuilder();
}
}
if(ids.length() != 0) {
sources.addAll(fetchUniqueSourceValues(ids));
ids = new StringBuilder();
}
// we can only pull 100 at a time
// we need to check if there is pagination to keep going
Map<String, String> paginationMap = (Map<String, String>) responseMap.get("pagination");
if(paginationMap != null && !paginationMap.isEmpty()) {
paginationToken = paginationMap.get("next");
} else {
paginationToken = null;
}
firstExecution=false;
}
List<Map<String, Object>> fileList = new ArrayList<>();
File documentsDir = new File(this.schemaFolder.getAbsolutePath() + DIR_SEPARATOR + indexClass + DIR_SEPARATOR + AbstractVectorDatabaseEngine.DOCUMENTS_FOLDER_NAME);
if(documentsDir.exists() && documentsDir.isDirectory()) {
for(String fileName : sources) {
Map<String, Object> fileInfo = new HashMap<>();
fileInfo.put("fileName", fileName);
File thisF = new File(documentsDir, fileName);
if(thisF.exists() && thisF.isFile()) {
long fileSizeInBytes = thisF.length();
double fileSizeInMB = (double) fileSizeInBytes / (1024 * 1024);
SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
String lastModified = dateFormat.format(new Date(thisF.lastModified()));
// add file size and last modified into the map
fileInfo.put("fileSize", fileSizeInMB);
fileInfo.put("lastModified", lastModified);
}
fileList.add(fileInfo);
}
}
return fileList;
}
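/**
 * Pages through every vector id in the namespace and fetches the metadata for each
 * record, batching the fetch calls 20 ids at a time since the ids are passed as
 * query parameters on the fetch URL.
 */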
@Override
public List<Map<String, Object>> listAllRecords(Map<String, Object> parameters) {
/**
 * Pinecone's API is not useful here
 * https://docs.pinecone.io/reference/api/2025-01/data-plane/list
 *
 * there is no way to fetch results without first listing the ids,
 * and after listing the ids the fetch can't limit the return, so
 * we also get back the vectors
 */
Gson gson = new GsonBuilder().create();
Map<String, String> headersMap = new HashMap<>();
headersMap.put(API_KY, this.apiKey);
headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
List<Map<String, Object>> allRecords = new ArrayList<>();
boolean firstExecution = true;
String paginationToken = null;
WHILE_LOOP : while(firstExecution || paginationToken!=null) {
String listVectorsUrl = this.hostname + LIST_QUERY + this.defaultNamespace;
if(paginationToken!=null) {
listVectorsUrl+=PAGINATION_TOKEN+paginationToken;
}
String idListResponse = HttpHelperUtility.getRequest(listVectorsUrl, headersMap, null, null, null);
Map<String, Object> responseMap = gson.fromJson(idListResponse, new TypeToken<Map<String, Object>>() {}.getType());
List<Map<String, String>> vectors = (List<Map<String, String>>) responseMap.get("vectors");
if(vectors == null || vectors.isEmpty()) {
break WHILE_LOOP;
}
StringBuilder ids = new StringBuilder();
for(int i = 0 ; i < vectors.size(); i++) {
Map<String, String> idMap = vectors.get(i);
ids.append("&ids=").append(idMap.get("id"));
if( (i+1) % 20 == 0) {
allRecords.addAll(fetchAllValues(ids));
ids = new StringBuilder();
}
}
if(ids.length() != 0) {
allRecords.addAll(fetchAllValues(ids));
ids = new StringBuilder();
}
// we can only pull 100 at a time
// we need to check if there is pagination to keep going
Map<String, String> paginationMap = (Map<String, String>) responseMap.get("pagination");
if(paginationMap != null && !paginationMap.isEmpty()) {
paginationToken = paginationMap.get("next");
} else {
paginationToken = null;
}
firstExecution=false;
}
return allRecords;
}
/**
 * Fetches the full metadata for the given vector ids
 * @param ids a query-string fragment of the form &ids=...&ids=...
 * @return one map per fetched record containing its stored metadata fields
 */
private List<Map<String, Object>> fetchAllValues(StringBuilder ids) {
List<Map<String, Object>> records = new ArrayList<>();
Map<String, String> headersMap = new HashMap<>();
headersMap.put(API_KY, this.apiKey);
headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
String fetchUrl = this.hostname + FETCH_QUERY + this.defaultNamespace + ids.toString();
String fetchResponse = HttpHelperUtility.getRequest(fetchUrl, headersMap, null, null, null);
JsonObject pineconeHorribleJson = JsonParser.parseString(fetchResponse).getAsJsonObject();
JsonObject vectorsJson = pineconeHorribleJson.get("vectors").getAsJsonObject();
for(String uid : vectorsJson.keySet()) {
JsonObject record = vectorsJson.get(uid).getAsJsonObject();
JsonObject metadata = record.get("metadata").getAsJsonObject();
Map<String, Object> recordMap = new HashMap<>();
recordMap.put(VectorDatabaseCSVTable.SOURCE, metadata.get(VectorDatabaseCSVTable.SOURCE).getAsString());
recordMap.put(VectorDatabaseCSVTable.MODALITY, metadata.get(VectorDatabaseCSVTable.MODALITY).getAsString());
recordMap.put(VectorDatabaseCSVTable.DIVIDER, metadata.get(VectorDatabaseCSVTable.DIVIDER).getAsString());
recordMap.put(VectorDatabaseCSVTable.PART, metadata.get(VectorDatabaseCSVTable.PART).getAsString());
recordMap.put(VectorDatabaseCSVTable.TOKENS, metadata.get(VectorDatabaseCSVTable.TOKENS).getAsInt());
recordMap.put(VectorDatabaseCSVTable.CONTENT, metadata.get(VectorDatabaseCSVTable.CONTENT).getAsString());
records.add(recordMap);
}
return records;
}
/**
 * Fetches the metadata for the given vector ids and collects the distinct source values
 * @param ids a query-string fragment of the form &ids=...&ids=...
 * @return the unique source metadata values across the fetched vectors
 */
private Set<String> fetchUniqueSourceValues(StringBuilder ids) {
Set uniqueSources = new HashSet<>();
Map<String, String> headersMap = new HashMap<>();
headersMap.put(API_KY, this.apiKey);
headersMap.put(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
String fetchUrl = this.hostname + FETCH_QUERY + this.defaultNamespace + ids.toString();
String fetchResponse = HttpHelperUtility.getRequest(fetchUrl, headersMap, null, null, null);
JsonObject pineconeHorribleJson = JsonParser.parseString(fetchResponse).getAsJsonObject();
JsonObject vectorsJson = pineconeHorribleJson.get("vectors").getAsJsonObject();
for(String uid : vectorsJson.keySet()) {
JsonObject record = vectorsJson.get(uid).getAsJsonObject();
JsonObject metadata = record.get("metadata").getAsJsonObject();
uniqueSources.add(metadata.get(VectorDatabaseCSVTable.SOURCE).getAsString());
}
return uniqueSources;
}
@Override
public VectorDatabaseTypeEnum getVectorDatabaseType() {
return VectorDatabaseTypeEnum.PINECONE;
}
}