
prerna.engine.impl.vector.PGVectorDatabaseEngine Maven / Gradle / Ivy
The newest version!
package prerna.engine.impl.vector;
import java.io.File;
import java.io.IOException;
import java.nio.file.Paths;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.text.SimpleDateFormat;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.UUID;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.pgvector.PGvector;
import prerna.auth.User;
import prerna.auth.utils.SecurityEngineUtils;
import prerna.cluster.util.ClusterUtil;
import prerna.cluster.util.CopyFilesToEngineRunner;
import prerna.cluster.util.DeleteFilesFromEngineRunner;
import prerna.ds.py.PyTranslator;
import prerna.ds.py.PyUtils;
import prerna.engine.api.ICustomEmbeddingsFunctionEngine;
import prerna.engine.api.IEngine;
import prerna.engine.api.IFunctionEngine;
import prerna.engine.api.IModelEngine;
import prerna.engine.api.IVectorDatabaseEngine;
import prerna.engine.api.VectorDatabaseTypeEnum;
import prerna.engine.impl.SmssUtilities;
import prerna.engine.impl.model.responses.EmbeddingsModelEngineResponse;
import prerna.engine.impl.model.workers.ModelEngineInferenceLogsWorker;
import prerna.engine.impl.rdbms.RDBMSNativeEngine;
import prerna.engine.impl.vector.metadata.VectorDatabaseMetadataCSVRow;
import prerna.engine.impl.vector.metadata.VectorDatabaseMetadataCSVTable;
import prerna.engine.impl.vector.metadata.VectorDatabaseMetadataCSVWriter;
import prerna.om.ClientProcessWrapper;
import prerna.om.Insight;
import prerna.om.InsightStore;
import prerna.query.querystruct.SelectQueryStruct;
import prerna.query.querystruct.filters.GenRowFilters;
import prerna.query.querystruct.filters.IQueryFilter;
import prerna.query.querystruct.selectors.QueryColumnSelector;
import prerna.query.querystruct.selectors.QueryOpaqueSelector;
import prerna.reactor.vector.VectorDatabaseParamOptionsEnum;
import prerna.util.ConnectionUtils;
import prerna.util.Constants;
import prerna.util.EngineUtility;
import prerna.util.QueryExecutionUtility;
import prerna.util.Settings;
import prerna.util.Utility;
import prerna.util.sql.PGVectorQueryUtil;
// Vector database engine backed by PostgreSQL + pgvector. Extends the native
// RDBMS engine for connection handling and implements the vector-db contract
// (embedding insert/search/remove).
// NOTE(review): this copy of the source lost its generic type parameters
// during extraction (e.g. the raw List/Map fields below) - restore against the
// canonical source before compiling.
public class PGVectorDatabaseEngine extends RDBMSNativeEngine implements IVectorDatabaseEngine {
private static final Logger classLogger = LogManager.getLogger(PGVectorDatabaseEngine.class);
private static final String DIR_SEPARATOR = "/";
private static final String FILE_SEPARATOR = java.nio.file.FileSystems.getDefault().getSeparator();
// SMSS keys for the main embeddings table and its metadata companion table
public static final String PGVECTOR_TABLE_NAME = "PGVECTOR_TABLE_NAME";
public static final String PGVECTOR_METADATA_TABLE_NAME = "PGVECTOR_METADATA_TABLE_NAME";
// chunking defaults; overridden from SMSS in open()
private int contentLength = 512;
private int contentOverlap = 0;
private String defaultChunkUnit;
// protected String defaultExtractionMethod;
// optional custom document processing hook (function engine id)
protected boolean customDocumentProcessor = false;
protected String customDocumentProcessorFunctionID = null;
// model engine ids resolved lazily in verifyModelProps()
private String embedderEngineId = null;
private String keywordGeneratorEngineId = null;
private String distanceMethod = null;
// resolved table names for this engine instance
private String vectorTableName = null;
private String vectorTableMetadataName = null;
private File schemaFolder;
// our paradigm for how we store files
private String defaultIndexClass;
private List indexClasses;
// python server
private PyTranslator pyt = null;
private File pyDirectoryBasePath;
private ClientProcessWrapper cpw = null;
// set true once verifyModelProps() has run successfully
private boolean modelPropsLoaded = false;
// string substitute vars
private Map vars = new HashMap<>();
private PGVectorQueryUtil pgVectorQueryUtil = new PGVectorQueryUtil();
// maintain details in the log database
protected boolean keepInputOutput = false;
// NOTE(review): field name has a typo ("Enbaled"); protected visibility means
// renaming could break subclasses, so left as-is
protected boolean inferenceLogsEnbaled = Utility.isModelInferenceLogsEnabled();
/**
 * Opens the engine: establishes the pgvector-aware connection, creates the
 * embeddings and metadata tables if absent, and loads chunking, distance and
 * folder settings from the SMSS properties.
 *
 * @param smssProp engine SMSS properties
 * @throws NullPointerException if {@code PGVECTOR_TABLE_NAME} is not defined
 * @throws SQLException if table initialization fails
 * @throws Exception propagated from the parent RDBMS open
 */
@Override
public void open(Properties smssProp) throws Exception {
    super.open(smssProp);
    this.distanceMethod = smssProp.getProperty(Constants.DISTANCE_METHOD);
    this.vectorTableName = smssProp.getProperty(PGVECTOR_TABLE_NAME);
    if (this.vectorTableName == null || (this.vectorTableName = this.vectorTableName.trim()).isEmpty()) {
        throw new NullPointerException("Must define the vector db table name");
    }
    this.vectorTableMetadataName = smssProp.getProperty(PGVECTOR_METADATA_TABLE_NAME);
    if (this.vectorTableMetadataName == null || (this.vectorTableMetadataName = this.vectorTableMetadataName.trim()).isEmpty()) {
        // default the metadata table name off the main table name
        this.vectorTableMetadataName = this.vectorTableName + "_METADATA";
    }
    Connection conn = null;
    try {
        conn = getConnection();
        // register the pgvector type with the driver before any vector statements run
        PGvector.addVectorType(conn);
        initSQL(this.vectorTableName, this.vectorTableMetadataName);
    } catch (SQLException e) {
        classLogger.error(Constants.STACKTRACE, e);
        throw e;
    } finally {
        ConnectionUtils.closeAllConnectionsIfPooling(this, conn, null, null);
    }
    if (this.smssProp.containsKey(Constants.CONTENT_LENGTH)) {
        this.contentLength = Integer.parseInt(this.smssProp.getProperty(Constants.CONTENT_LENGTH));
    }
    if (this.smssProp.containsKey(Constants.CONTENT_OVERLAP)) {
        this.contentOverlap = Integer.parseInt(this.smssProp.getProperty(Constants.CONTENT_OVERLAP));
    }
    this.keepInputOutput = Boolean.parseBoolean(this.smssProp.getProperty(Constants.KEEP_INPUT_OUTPUT));
    this.defaultChunkUnit = "tokens";
    if (this.smssProp.containsKey(Constants.DEFAULT_CHUNK_UNIT)) {
        this.defaultChunkUnit = this.smssProp.getProperty(Constants.DEFAULT_CHUNK_UNIT).toLowerCase().trim();
        if (!this.defaultChunkUnit.equals("tokens") && !this.defaultChunkUnit.equals("characters")) {
            throw new IllegalArgumentException("DEFAULT_CHUNK_UNIT should be either 'tokens' or 'characters'");
        }
    }
    // this.defaultExtractionMethod = this.smssProp.getProperty(Constants.EXTRACTION_METHOD, "None");
    // NOTE(review): this overwrites the value read at the top of the method,
    // making the earlier assignment redundant; kept for behavioral parity
    this.distanceMethod = this.smssProp.getProperty(Constants.DISTANCE_METHOD, "Cosine Similarity");
    this.defaultIndexClass = "default";
    if (this.smssProp.containsKey(Constants.INDEX_CLASSES)) {
        this.defaultIndexClass = this.smssProp.getProperty(Constants.INDEX_CLASSES);
    }
    // smss properties for custom document processing
    if (this.smssProp.containsKey(Constants.CUSTOM_DOCUMENT_PROCESSOR)) {
        this.customDocumentProcessor = Boolean.parseBoolean(this.smssProp.getProperty(Constants.CUSTOM_DOCUMENT_PROCESSOR));
    }
    if (this.smssProp.containsKey(Constants.CUSTOM_DOCUMENT_PROCESSOR_FUNCTION_ID)) {
        this.customDocumentProcessorFunctionID = this.smssProp.getProperty(Constants.CUSTOM_DOCUMENT_PROCESSOR_FUNCTION_ID);
    }
    // highest directory (first layer inside vector db base folder)
    String engineDir = EngineUtility.getSpecificEngineBaseFolder(IEngine.CATALOG_TYPE.VECTOR, this.engineId, this.engineName);
    this.pyDirectoryBasePath = new File(Utility.normalizePath(engineDir + DIR_SEPARATOR + "py" + DIR_SEPARATOR));
    // second layer - holds all the different "tables" so we can quickly grab the sub folders
    this.schemaFolder = new File(engineDir, "schema");
    if (!this.schemaFolder.exists()) {
        this.schemaFolder.mkdirs();
    }
    // third layer - all the separate tables, classes, or searchers that can be added to this db
    this.indexClasses = new ArrayList<>();
    // File.listFiles() can return null on an I/O error even after mkdirs();
    // guard against the NPE the original would throw here
    File[] schemaChildren = this.schemaFolder.listFiles();
    if (schemaChildren != null) {
        for (File file : schemaChildren) {
            if (file.isDirectory() && !file.getName().equals("temp")) {
                this.indexClasses.add(file.getName());
            }
        }
    }
}
/**
 * Creates the embeddings table and its companion metadata table (no-ops when
 * they already exist, per execCreateStatement's best-effort semantics) and
 * registers both tables in the engine OWL.
 *
 * @param table         name of the main embeddings table
 * @param metadataTable name of the metadata table
 * @throws SQLException if statement execution fails
 */
private void initSQL(String table, String metadataTable) throws SQLException {
    // build both DDL statements up front, then execute each in turn
    String embeddingsDdl = pgVectorQueryUtil.createEmbeddingsTable(table);
    String metadataDdl = pgVectorQueryUtil.createEmbeddingsMetadataTable(metadataTable);
    execCreateStatement(embeddingsDdl);
    execCreateStatement(metadataDdl);
    pgVectorQueryUtil.createOWL(this, table, metadataTable);
}
/**
 * Executes a single DDL create statement against the engine connection.
 * <p>
 * Failures are logged and intentionally swallowed - presumably so startup
 * proceeds when the table already exists (TODO confirm against the generated
 * DDL; an IF NOT EXISTS clause would make this explicit).
 *
 * @param createQuery the DDL statement to execute
 * @throws SQLException declared for signature compatibility; create failures
 *                      are currently logged rather than rethrown
 */
private void execCreateStatement(String createQuery) throws SQLException {
    Connection conn = null;
    Statement stmt = null;
    try {
        conn = getConnection();
        stmt = conn.createStatement();
        classLogger.info("Executing create table for "
                + SmssUtilities.getUniqueName(this.engineName, this.engineId) + " = " + createQuery);
        stmt.execute(createQuery);
    } catch (SQLException e) {
        classLogger.warn("Unable to create the table " + createQuery);
        classLogger.error(Constants.STACKTRACE, e);
    } finally {
        if (this.dataSource != null) {
            // pooled: return both the statement and the connection to the pool
            ConnectionUtils.closeAllConnections(conn, stmt);
        } else {
            // direct connection is engine-owned; only the statement is closed
            ConnectionUtils.closeAllConnections(null, stmt);
        }
    }
}
/**
 * Resolves and validates the embedder (and optional keyword) model engines
 * from the SMSS properties, then snapshots every SMSS property into the
 * string-substitution map. Sets {@code modelPropsLoaded} on success.
 * <p>
 * NOTE: mutates {@code this.smssProp} (adds MODEL, MODEL_TYPE, MAX_TOKENS,
 * EMBEDDER_ENGINE_ID, KEYWORD_ENGINE_ID) - the order of the steps below
 * matters because the final loop copies the mutated properties into vars.
 *
 * @throws IllegalArgumentException if no embedder engine id is defined, or
 *         the embedder smss is missing the MODEL key
 * @throws NullPointerException if the embedder engine id cannot be resolved
 */
protected void verifyModelProps() {
// This could get moved depending on other vector db needs
// This is to get the Model Name and Max Token for an encoder -- we need this to verify chunks aren't getting truncated
this.embedderEngineId = this.smssProp.getProperty(Constants.EMBEDDER_ENGINE_ID);
if (this.embedderEngineId == null || (this.embedderEngineId=this.embedderEngineId.trim()).isEmpty()) {
// check legacy key....
this.embedderEngineId = this.smssProp.getProperty("ENCODER_ID");
if (this.embedderEngineId == null || (this.embedderEngineId=this.embedderEngineId.trim()).isEmpty()) {
throw new IllegalArgumentException("Must define the embedder engine id for this vector database using " + Constants.EMBEDDER_ENGINE_ID);
}
// normalize the legacy value under the current key
this.smssProp.put(Constants.EMBEDDER_ENGINE_ID, embedderEngineId);
}
IModelEngine modelEngine = Utility.getModel(embedderEngineId);
if (modelEngine == null) {
throw new NullPointerException("Could not find the defined embedder engine id for this vector database with value = " + this.embedderEngineId);
}
Properties modelProperties = modelEngine.getSmssProp();
if (modelProperties.isEmpty() || !modelProperties.containsKey(Constants.MODEL)) {
throw new IllegalArgumentException("Embedder engine exists but does not contain key " + Constants.MODEL);
}
// copy the model identifiers locally so the substitution vars can see them
this.smssProp.put(Constants.MODEL, modelProperties.getProperty(Constants.MODEL));
this.smssProp.put(IModelEngine.MODEL_TYPE, modelProperties.getProperty(IModelEngine.MODEL_TYPE));
if (!modelProperties.containsKey(Constants.MAX_TOKENS)) {
// "None" is presumably consumed by the python layer as "no limit" -- TODO confirm
this.smssProp.put(Constants.MAX_TOKENS, "None");
} else {
this.smssProp.put(Constants.MAX_TOKENS, modelProperties.getProperty(Constants.MAX_TOKENS));
}
// model engine responsible for creating keywords
this.keywordGeneratorEngineId = this.smssProp.getProperty(Constants.KEYWORD_ENGINE_ID);
if (this.keywordGeneratorEngineId != null && !(this.keywordGeneratorEngineId=this.keywordGeneratorEngineId.trim()).isEmpty()) {
// pull the model smss if needed
Utility.getModel(this.keywordGeneratorEngineId);
this.smssProp.put(Constants.KEYWORD_ENGINE_ID, this.keywordGeneratorEngineId);
} else {
// add it to the smss prop so the string substitution does not fail
this.smssProp.put(Constants.KEYWORD_ENGINE_ID, "");
}
// snapshot every (now fully populated) smss property for templating
for (Object smssKey : this.smssProp.keySet()) {
String key = smssKey.toString();
this.vars.put(key, this.smssProp.getProperty(key));
}
this.modelPropsLoaded = true;
}
/**
 * Parses each vector CSV file and pushes its rows through the table-level
 * {@link #addEmbeddings(VectorDatabaseCSVTable, Insight, Map)} overload.
 * <p>
 * NOTE(review): the extracted source lost its generics; {@code List<String>}
 * is restored here because the for-each below binds each element to String.
 *
 * @param vectorCsvFiles paths to CSV files in the vector CSV format
 * @param insight        insight context used to run the embeddings model
 * @param parameters     pass-through options for the table-level add
 * @throws Exception if a file cannot be parsed or the insert fails
 */
@Override
public void addEmbeddings(List<String> vectorCsvFiles, Insight insight, Map parameters) throws Exception {
    for (String vectorCsvFile : vectorCsvFiles) {
        VectorDatabaseCSVTable vectorCsvTable = VectorDatabaseCSVTable.initCSVTable(new File(vectorCsvFile));
        addEmbeddings(vectorCsvTable, insight, parameters);
    }
}
/**
 * Single-file convenience overload: parses the given CSV path and delegates
 * to the table-level add.
 *
 * @param vectorCsvFile path to a CSV file in the vector CSV format
 * @param insight       insight context used to run the embeddings model
 * @param parameters    pass-through options for the table-level add
 * @throws Exception if the file cannot be parsed or the insert fails
 */
@Override
public void addEmbeddings(String vectorCsvFile, Insight insight, Map parameters) throws Exception {
    File csvFile = new File(vectorCsvFile);
    addEmbeddings(VectorDatabaseCSVTable.initCSVTable(csvFile), insight, parameters);
}
/**
 * Parses each vector CSV {@link File} and pushes its rows through the
 * table-level {@link #addEmbeddings(VectorDatabaseCSVTable, Insight, Map)}.
 * <p>
 * NOTE(review): the extracted source lost its generics; {@code List<File>}
 * is restored here because the for-each below binds each element to File.
 *
 * @param vectorCsvFiles CSV files in the vector CSV format
 * @param insight        insight context used to run the embeddings model
 * @param parameters     pass-through options for the table-level add
 * @throws Exception if a file cannot be parsed or the insert fails
 */
@Override
public void addEmbeddingFiles(List<File> vectorCsvFiles, Insight insight, Map parameters) throws Exception {
    for (File vectorCsvFile : vectorCsvFiles) {
        VectorDatabaseCSVTable vectorCsvTable = VectorDatabaseCSVTable.initCSVTable(vectorCsvFile);
        addEmbeddings(vectorCsvTable, insight, parameters);
    }
}
/**
 * Single-file convenience overload taking a {@link File}: parses the CSV and
 * delegates to the table-level add.
 *
 * @param vectorCsvFile CSV file in the vector CSV format
 * @param insight       insight context used to run the embeddings model
 * @param parameters    pass-through options for the table-level add
 * @throws Exception if the file cannot be parsed or the insert fails
 */
@Override
public void addEmbeddingFile(File vectorCsvFile, Insight insight, Map parameters) throws Exception {
    addEmbeddings(VectorDatabaseCSVTable.initCSVTable(vectorCsvFile), insight, parameters);
}
// Core add path: generates embeddings for every row of the parsed CSV table
// via the embedder model engine, then batch-inserts the rows into the main
// vector table.
// NOTE(review): this region of the extracted source is damaged - line-level
// content was lost around the executeBatch loop (see the garbled "for" line
// below, which fuses the batch-result check with the start of what appears to
// be a separate metadata-handling method). Recover from the canonical source
// before compiling; code is reproduced byte-for-byte here.
@Override
public void addEmbeddings(VectorDatabaseCSVTable vectorCsvTable, Insight insight, Map parameters) throws Exception {
if (insight == null) {
throw new IllegalArgumentException("Insight must be provided to run Model Engine Encoder");
}
// lazily resolve the embedder model settings on first use
if (!modelPropsLoaded) {
verifyModelProps();
}
// if we were able to extract files, begin embeddings process
IModelEngine embeddingsEngine = Utility.getModel(this.embedderEngineId);
// send all the strings to embed in one shot
try {
vectorCsvTable.generateAndAssignEmbeddings(embeddingsEngine, insight);
} catch (Exception e) {
classLogger.error(Constants.STACKTRACE, e);
throw new IllegalArgumentException("Error occurred creating the embeddings for the generated chunks. Detailed error message = " + e.getMessage());
}
String psString = "INSERT INTO "
+ this.vectorTableName
+ " (EMBEDDING, SOURCE, MODALITY, DIVIDER, PART, TOKENS, CONTENT) "
+ "VALUES (?,?,?,?,?,?,?)";
Connection conn = null;
PreparedStatement ps = null;
try {
conn = this.getConnection();
ps = conn.prepareStatement(psString);
// if (parameters.containsKey(VectorDatabaseParamOptionsEnum.KEYWORD_SEARCH_PARAM.getKey())) {
// IModelEngine keywordEngine = Utility.getModel(this.keywordGeneratorEngineId);
// dataForTable.setKeywordEngine(keywordEngine);
// }
// flush the insert batch every 1000 rows to bound memory on large CSVs
final int batchSize = 1000;
int count = 0;
for (VectorDatabaseCSVRow row: vectorCsvTable.getRows()) {
int index = 1;
ps.setObject(index++, new PGvector(row.getEmbeddings()));
ps.setString(index++, row.getSource());
ps.setString(index++, row.getModality());
ps.setString(index++, row.getDivider());
ps.setString(index++, row.getPart());
ps.setInt(index++, row.getTokens());
ps.setString(index++, row.getContent());
ps.addBatch();
// batch commit based on size
if (++count % batchSize == 0) {
classLogger.info("Executing embeddings batch .... row num = " + count);
int[] results = ps.executeBatch();
// NOTE(review): extraction damage begins here - the original batch-result
// loop, catch/finally, and a following method signature were fused into
// the single line below
for(int j=0; j> metadata = (Map>) parameters.get(AbstractVectorDatabaseEngine.METADATA);
if(!metadata.isEmpty()) {
// write the optional per-source metadata to a temp CSV and load it
String tempMetadataFile = insight.getInsightFolder()+"/metadata"+Utility.getRandomString(6)+".csv";
VectorDatabaseMetadataCSVWriter writer = new VectorDatabaseMetadataCSVWriter(tempMetadataFile);
writer.bulkWriteRow(metadata);
try {
addMetadata(VectorDatabaseMetadataCSVTable.initCSVTable(new File(tempMetadataFile)));
} catch (SQLException | IOException e) {
classLogger.error(Constants.STACKTRACE, e);
throw e;
}
}
}
}
@Override
public void addEmbedding(List extends Number> embedding, String source, String modality, String divider,
String part, int tokens, String content, Map additionalMetadata) throws SQLException {
// just do the insertion
String psString = "INSERT INTO "
+ this.vectorTableName
+ " (EMBEDDING, SOURCE, MODALITY, DIVIDER, PART, TOKENS, CONTENT) "
+ "VALUES (?,?,?,?,?,?,?)";
Connection conn = null;
PreparedStatement ps = null;
try {
conn = this.getConnection();
ps = conn.prepareStatement(psString);
int index = 1;
ps.setObject(index++, new PGvector(embedding));
ps.setString(index++, source);
ps.setString(index++, modality);
ps.setString(index++, divider);
ps.setString(index++, part);
ps.setInt(index++, tokens);
ps.setString(index++, content);
int result = ps.executeUpdate();
if(result == PreparedStatement.EXECUTE_FAILED) {
throw new SQLException("Error inserting embeddings data");
}
if (!conn.getAutoCommit()) {
conn.commit();
}
} catch (SQLException e) {
classLogger.error(Constants.STACKTRACE, e);
throw e;
} finally {
ConnectionUtils.closeAllConnectionsIfPooling(this, conn, null, null);
}
}
// Removes documents from the engine: deletes the physical file under the
// index-class documents folder and batch-deletes the matching SOURCE rows
// from both the vector table and the metadata table.
// NOTE(review): the raw List parameter below is iterated as String - the
// generic parameter (List<String>) was stripped during extraction. The tail
// of this method is also damaged (see the garbled "for" line at the end,
// fused with the next method's signature); recover from the canonical source.
@Override
public void removeDocument(List fileNames, Map parameters) throws IOException {
String indexClass = this.defaultIndexClass;
if (parameters.containsKey("indexClass")) {
indexClass = (String) parameters.get("indexClass");
}
// expand csv inputs into the individual SOURCE values they contributed
List sourceNames = new ArrayList<>();
for(String document : fileNames) {
String documentName = FilenameUtils.getName(document);
File f = new File(document);
if(f.exists() && f.getName().endsWith(".csv")) {
sourceNames.addAll(VectorDatabaseCSVTable.pullSourceColumn(f));
} else {
sourceNames.add(documentName);
}
}
final String DOCUMENT_FOLDER = this.schemaFolder.getAbsolutePath() + DIR_SEPARATOR + indexClass + DIR_SEPARATOR + AbstractVectorDatabaseEngine.DOCUMENTS_FOLDER_NAME;
List filesToRemoveFromCloud = new ArrayList();
String deleteQuery = "DELETE FROM "+this.vectorTableName+" WHERE SOURCE=?";
String deleteMetaQuery = "DELETE FROM "+this.vectorTableMetadataName+" WHERE SOURCE=?";
Connection conn = null;
PreparedStatement ps = null;
PreparedStatement metaPs = null;
int[] results = null;
try {
conn = this.getConnection();
ps = conn.prepareStatement(deleteQuery);
metaPs = conn.prepareStatement(deleteMetaQuery);
for (String document : sourceNames) {
String documentName = Paths.get(document).getFileName().toString();
// remove the physical documents
File documentFile = new File(DOCUMENT_FOLDER, documentName);
if (documentFile.exists()) {
FileUtils.forceDelete(documentFile);
filesToRemoveFromCloud.add(documentFile.getAbsolutePath());
}
// remove the results from the db
int parameterIndex = 1;
ps.setString(parameterIndex++, documentName);
ps.addBatch();
parameterIndex = 1;
metaPs.setString(parameterIndex++, documentName);
metaPs.addBatch();
}
results = ps.executeBatch();
// since metadata is optional
// its fine if no rows updated
metaPs.executeBatch();
// NOTE(review): extraction damage - the batch-result loop, catch/finally,
// and the nearestNeighborCall signature were fused into the line below
for(int j=0; j> nearestNeighborCall(Insight insight, String searchStatement, Number limit, Map parameters) {
if (insight == null) {
throw new IllegalArgumentException("Insight must be provided to run Model Engine Encoder");
}
if (!this.modelPropsLoaded) {
verifyModelProps();
}
List filters = null;
List metaFilters = null;
if (parameters.containsKey(AbstractVectorDatabaseEngine.FILTERS_KEY)) {
filters = PGVectorQueryFitlerTranslationHelper.convertFilters(
(List) parameters.get(AbstractVectorDatabaseEngine.FILTERS_KEY), this.vectorTableName
);
}
if (parameters.containsKey(AbstractVectorDatabaseEngine.METADATA_FILTERS_KEY)) {
metaFilters = PGVectorQueryMetaFitlerTranslationHelper.convertFilters(
(List) parameters.get(AbstractVectorDatabaseEngine.METADATA_FILTERS_KEY), this.vectorTableMetadataName
);
}
if (parameters.containsKey(VectorDatabaseParamOptionsEnum.COLUMNS_TO_RETURN.getKey())) {}
if (parameters.containsKey(VectorDatabaseParamOptionsEnum.RETURN_THRESHOLD.getKey())) {}
if (parameters.containsKey(VectorDatabaseParamOptionsEnum.ASCENDING.getKey())) {}
IModelEngine engine = Utility.getModel(this.embedderEngineId);
EmbeddingsModelEngineResponse embeddingsResponse = engine.embeddings(Arrays.asList(new String[] {searchStatement}), insight, null);
final String tablePrefix = this.vectorTableName+"__";
// final String metaTablePrefix = this.vectorTableMetadataName+"__";
SelectQueryStruct qs = new SelectQueryStruct();
qs.addSelector(new QueryColumnSelector(tablePrefix+VectorDatabaseCSVTable.SOURCE, VectorDatabaseCSVTable.SOURCE));
qs.addSelector(new QueryColumnSelector(tablePrefix+VectorDatabaseCSVTable.MODALITY, VectorDatabaseCSVTable.MODALITY));
qs.addSelector(new QueryColumnSelector(tablePrefix+VectorDatabaseCSVTable.DIVIDER, VectorDatabaseCSVTable.DIVIDER));
qs.addSelector(new QueryColumnSelector(tablePrefix+VectorDatabaseCSVTable.PART, VectorDatabaseCSVTable.PART));
qs.addSelector(new QueryColumnSelector(tablePrefix+VectorDatabaseCSVTable.TOKENS, VectorDatabaseCSVTable.TOKENS));
qs.addSelector(new QueryColumnSelector(tablePrefix+VectorDatabaseCSVTable.CONTENT, VectorDatabaseCSVTable.CONTENT));
// Determine the distanceMethod to use for the query
// Store the result in the "Score" field,
if ("Cosine Similarity".equalsIgnoreCase(distanceMethod)) {
// '<=>' cosine similarity operator
// cosine distance is between -1 and 1
// Using 1 - cosine distance converts the distance metric into a similarity metric.
// 1 = identical
// 0 = orthogonal
// -1 = opposite
// so need to show results as desc
qs.addSelector(new QueryOpaqueSelector("1 - (EMBEDDING <=> '" + embeddingsResponse.getResponse().get(0) + "')", "Score"));
// This allows us to sort results by similarity in descending order
// (from most similar to least similar).
qs.addOrderBy("Score", "DESC");
} else {
// '<->' Euclidean (L2) distance operator
// The POWER function is used to square the distance to avoid the computational cost of square roots
// This also ensures all distance values are non-negative, which is important for optimization
qs.addSelector(new QueryOpaqueSelector(
"POWER((EMBEDDING <-> '" + embeddingsResponse.getResponse().get(0) + "'),2)", "Score"));
qs.addOrderBy("Score", "ASC");
}
if(filters != null && !filters.isEmpty()) {
qs.addExplicitFilter(new GenRowFilters(filters), true);
}
if(metaFilters != null && !metaFilters.isEmpty()) {
// also need the join
qs.addRelation(this.vectorTableName, this.vectorTableMetadataName, "inner.join");
qs.addExplicitFilter(new GenRowFilters(metaFilters), true);
}
if(limit != null) {
qs.setLimit(limit.longValue());
}
List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy