All downloads are free. The search and download functionality uses the official Maven repository.

prerna.reactor.vector.CreateEmbeddingsFromDocumentsReactor Maven / Gradle / Ivy

The newest version!
package prerna.reactor.vector;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;

import org.apache.commons.io.FileUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.tika.Tika;
import org.apache.tika.metadata.Metadata;

import prerna.auth.utils.SecurityEngineUtils;
import prerna.engine.api.IVectorDatabaseEngine;
import prerna.engine.impl.vector.AbstractVectorDatabaseEngine;
import prerna.reactor.AbstractReactor;
import prerna.reactor.vector.VectorDatabaseParamOptionsEnum.CreateEmbeddingsParamOptions;
import prerna.sablecc2.om.GenRowStruct;
import prerna.sablecc2.om.PixelDataType;
import prerna.sablecc2.om.PixelOperationType;
import prerna.sablecc2.om.ReactorKeysEnum;
import prerna.sablecc2.om.execptions.SemossPixelException;
import prerna.sablecc2.om.nounmeta.NounMetadata;
import prerna.util.AssetUtility;
import prerna.util.Constants;
import prerna.util.Utility;

/**
 * Reactor that collects the documents named in the {@code filePaths} input (plain files and
 * zip archives, including archives nested inside archives), filters them down to the supported
 * document types and hands the resulting paths to the vector database engine through
 * {@code addDocument}.
 *
 * <p>Zip archives are unpacked into a {@code zipFileExtractFolder} directory under the resolved
 * root folder; that directory is removed again once {@link #execute()} finishes, whether or not
 * the upload succeeded.
 */
public class CreateEmbeddingsFromDocumentsReactor extends AbstractReactor {

	private static final Logger classLogger = LogManager.getLogger(CreateEmbeddingsFromDocumentsReactor.class);

	/** Name of the scratch folder (under the root folder) that zip archives are unpacked into. */
	private static final String PATH_TO_UNZIP_FILES = "zipFileExtractFolder";
	/** Noun key holding the file names/paths to embed. */
	private static final String FILE_PATHS_KEY = "filePaths";

	public CreateEmbeddingsFromDocumentsReactor() {
		this.keysToGet = new String[] {ReactorKeysEnum.ENGINE.getKey(), FILE_PATHS_KEY, 
				ReactorKeysEnum.SPACE.getKey(), ReactorKeysEnum.PARAM_VALUES_MAP.getKey()};
		this.keyRequired = new int[] {1, 1, 0, 0};
	}

	/**
	 * Validates access to the engine and its embedding models, resolves the input files and adds
	 * them to the vector database.
	 *
	 * @return a boolean {@code true} noun; when some inputs were rejected a warning listing them
	 *         is attached as an additional return
	 * @throws IllegalArgumentException if the user lacks access, no valid file was supplied, a
	 *         file does not exist, or the engine fails while adding the documents
	 * @throws SemossPixelException if the engine cannot be found
	 */
	@Override
	public NounMetadata execute() {
		organizeKeys();
		String engineId = this.keyValue.get(ReactorKeysEnum.ENGINE.getKey());
		if(!SecurityEngineUtils.userCanEditEngine(this.insight.getUser(), engineId)) {
			throw new IllegalArgumentException("Vector db " + engineId + " does not exist or user does not have access to this engine");
		}

		IVectorDatabaseEngine vectorDatabase = Utility.getVectorDatabase(engineId);
		if (vectorDatabase == null) {
			throw new SemossPixelException("Unable to find engine");
		}

		Map<String, Object> paramMap = getMap();
		if(paramMap == null) {
			paramMap = new HashMap<>();
		}

		// check user has access to any embedding models as well 
		// this actually throws an error
		// but will wrap in if statement just in case
		if(!vectorDatabase.userCanAccessEmbeddingModels(this.insight.getUser())) {
			throw new IllegalArgumentException("User does not have access to all the vector database dependent models");
		}

		// send the insight so it can be used with IModelEngine call
		paramMap.put(AbstractVectorDatabaseEngine.INSIGHT, this.insight);

		String rootFolder = getRootFolder();
		// this is coming from an insight so i assume its just the file names
		List<String> validFiles = new ArrayList<>();
		List<String> invalidFiles = new ArrayList<>();
		try {
			getFiles(rootFolder, validFiles, invalidFiles);
			if (validFiles.isEmpty()) {
				throw new IllegalArgumentException("Please provide valid input files using \"filePaths\". File types supported are pdf, word, ppt, or txt files");
			}

			for (String filePath : validFiles) {
				File file = new File(Utility.normalizePath(filePath));
				// Check if the file exists
				if (!file.exists()) {
					throw new IllegalArgumentException("File path for " + file.getName() + " does not exist within the insight or project space.");
				}
			}

			vectorDatabase.addDocument(validFiles, paramMap);
		} catch (Exception e) {
			classLogger.error(Constants.STACKTRACE, e);
			// keep the original exception as the cause so callers can still inspect it
			throw new IllegalArgumentException("The following exception occured: " + e.getMessage(), e);
		} finally {
			// always drop the scratch folder used for unpacking zip archives
			File zipFileExtractionDir = new File(rootFolder, PATH_TO_UNZIP_FILES);
			if (zipFileExtractionDir.exists()) {
				try {
					FileUtils.forceDelete(zipFileExtractionDir);
				} catch (IOException e) {
					classLogger.error(Constants.STACKTRACE, e);
				}
			}
		}

		NounMetadata noun = new NounMetadata(true, PixelDataType.BOOLEAN, PixelOperationType.OPERATION);
		if(!invalidFiles.isEmpty()) {
			// report the rejected files without the root folder prefix
			List<String> invalidFileNamesRelative = new ArrayList<>(invalidFiles.size());
			for(String invalidF : invalidFiles) {
				invalidFileNamesRelative.add(invalidF.replace(rootFolder, ""));
			}
			noun.addAdditionalReturn(NounMetadata.getWarningNounMessage("Unable to upload " + String.join(", ", invalidFileNamesRelative)));
		}
		return noun;
	}

	/**
	 * Get the parameter map, looking first at the {@code paramValues} noun in the store and then
	 * at any MAP-typed noun on the current row.
	 *
	 * @return the first map found, or {@code null} when no map was passed
	 */
	@SuppressWarnings("unchecked")
	private Map<String, Object> getMap() {
		GenRowStruct mapGrs = this.store.getNoun(ReactorKeysEnum.PARAM_VALUES_MAP.getKey());
		if(mapGrs != null && !mapGrs.isEmpty()) {
			List<NounMetadata> mapInputs = mapGrs.getNounsOfType(PixelDataType.MAP);
			if(mapInputs != null && !mapInputs.isEmpty()) {
				return (Map<String, Object>) mapInputs.get(0).getValue();
			}
		}
		List<NounMetadata> mapInputs = this.curRow.getNounsOfType(PixelDataType.MAP);
		if(mapInputs != null && !mapInputs.isEmpty()) {
			return (Map<String, Object>) mapInputs.get(0).getValue();
		}

		return null;
	}

	/**
	 * Resolve the folder that the {@code filePaths} inputs are relative to.
	 *
	 * @return the path returned by {@code AssetUtility.getAssetVersionBasePath} for this insight
	 *         and the optional {@code space} input ({@code null} space when none was passed)
	 */
	private String getRootFolder() {
		String space = null;
		GenRowStruct spaceGrs = store.getNoun(ReactorKeysEnum.SPACE.getKey());
		if (spaceGrs != null && !spaceGrs.isEmpty()) {
			space = spaceGrs.get(0).toString();
		}

		return AssetUtility.getAssetVersionBasePath(this.insight, space, false);
	}

	/**
	 * Sort every {@code filePaths} input into supported and unsupported files. Zip archives are
	 * unpacked into the scratch folder and their content is sorted the same way.
	 *
	 * @param rootFolder   folder the input paths are relative to
	 * @param validFiles   receives the paths of supported documents
	 * @param invalidFiles receives the paths of rejected files
	 * @throws IOException if a zip archive cannot be read or unpacked
	 */
	private void getFiles(String rootFolder, List<String> validFiles, List<String> invalidFiles) throws IOException {
		GenRowStruct grs = this.store.getNoun(FILE_PATHS_KEY);
		if (grs == null || grs.isEmpty()) {
			return;
		}

		int size = grs.size();
		for (int i = 0; i < size; i++) {
			String filePath = rootFolder + "/" + grs.get(i).toString();
			if (isZipFile(filePath)) {
				String zipFileLocation = filePath.replace('\\', '/');
				File zipFileExtractFolder = new File(rootFolder, PATH_TO_UNZIP_FILES);
				unzipAndFilter(zipFileLocation, zipFileExtractFolder.getAbsolutePath(), validFiles, invalidFiles);
			} else if (isSupportedFileType(filePath)) {
				validFiles.add(filePath);
			} else {
				invalidFiles.add(filePath);
			}
		}
	}

	/**
	 * Recursively go through all the zips, directories and files in a zip file and save the paths of 
	 * valid file types. Supported documents and nested archives are written below
	 * {@code destDirectory}; everything else is only recorded as invalid.
	 * 
	 * @param zipFilePath   archive to read
	 * @param destDirectory folder the archive content is unpacked into
	 * @param validFiles    receives the paths of extracted, supported documents
	 * @param invalidFiles  receives the paths of rejected entries
	 * @throws IOException if the archive cannot be read, a folder cannot be created, or an entry
	 *         would be written outside of {@code destDirectory}
	 */
	private void unzipAndFilter(String zipFilePath, String destDirectory, List<String> validFiles, List<String> invalidFiles) throws IOException {
		File destDir = new File(Utility.normalizePath(destDirectory));
		if (!destDir.exists() && !destDir.mkdirs()) {
			throw new IOException("Unable to create folder " + destDir.getName() + " to extract the zip file");
		}
		// trailing separator so that "/dest-evil" is not accepted as being inside "/dest"
		String destRoot = destDir.getCanonicalPath() + File.separator;

		try (ZipInputStream zipIn = new ZipInputStream(new FileInputStream(Utility.normalizePath(zipFilePath)))) {
			ZipEntry entry;
			while ((entry = zipIn.getNextEntry()) != null) {
				String filePath = destDirectory + "/" + entry.getName();
				File target = new File(Utility.normalizePath(filePath));

				// entry names come from the archive and may contain "../" (zip slip)
				if (!(target.getCanonicalPath() + File.separator).startsWith(destRoot)) {
					throw new IOException("Zip entry " + entry.getName() + " resolves outside of the extraction folder");
				}

				if (entry.isDirectory()) {
					target.mkdirs();
				} else {
					// the mime type fallback reads the file from disk, so entries without an
					// extension have to be written out before they can be classified
					// (rejected ones are removed with the scratch folder by the caller)
					boolean extracted = false;
					if (getExtension(filePath) == null) {
						extractFile(zipIn, filePath);
						extracted = true;
					}

					if (isZipFile(filePath)) {
						// Handle nested zip file
						if (!extracted) {
							extractFile(zipIn, filePath);
						}
						unzipAndFilter(filePath, nestedDestination(filePath), validFiles, invalidFiles);
					} else if (isSupportedFileType(filePath)) {
						if (!extracted) {
							extractFile(zipIn, filePath);
						}
						validFiles.add(filePath);
					} else {
						invalidFiles.add(filePath);
					}
				}

				zipIn.closeEntry();
			}
		}
	}

	/**
	 * Build the folder a nested archive is unpacked into: a sibling of the archive named after
	 * the archive without its extension.
	 *
	 * @param zipFilePath path of the nested archive, using "/" as separator
	 * @return the destination folder path
	 */
	private static String nestedDestination(String zipFilePath) {
		int slashIndex = zipFilePath.lastIndexOf('/'); // ZIP entries use "/" as a separator
		String parentPath = slashIndex >= 0 ? zipFilePath.substring(0, slashIndex) : "";
		String fileName = zipFilePath.substring(slashIndex + 1);

		int dotIndex = fileName.lastIndexOf('.');
		String baseName = dotIndex >= 0 ? fileName.substring(0, dotIndex) : fileName;
		if (baseName.isEmpty() || baseName.equals(fileName)) {
			// no usable base name - avoid colliding with the archive file itself
			return zipFilePath + "_extracted";
		}
		return parentPath + "/" + baseName;
	}

	/**
	 * Write the current entry of the zip stream to disk, creating missing parent folders
	 * (archives are not required to contain explicit directory entries).
	 * 
	 * @param zipIn    stream positioned at the entry to write
	 * @param filePath destination file
	 * @throws IOException if the parent folder cannot be created or the file cannot be written
	 */
	private void extractFile(ZipInputStream zipIn, String filePath) throws IOException {
		File target = new File(Utility.normalizePath(filePath));
		File parent = target.getParentFile();
		if (parent != null && !parent.isDirectory() && !parent.mkdirs()) {
			throw new IOException("Unable to create folder for zip entry " + target.getName());
		}

		try (FileOutputStream fos = new FileOutputStream(target)) {
			byte[] buffer = new byte[8192];
			int bytesRead;
			while ((bytesRead = zipIn.read(buffer)) != -1) {
				fos.write(buffer, 0, bytesRead);
			}
		}
	}

	/**
	 * Get the lower-cased extension of the last path segment.
	 *
	 * @param filePath path using "/" or "\" as separator
	 * @return the extension without the dot, or {@code null} when the file name has no dot or
	 *         ends with one
	 */
	private static String getExtension(String filePath) {
		// only look at the file name so a dot in a folder name is not mistaken for an extension
		int nameStart = Math.max(filePath.lastIndexOf('/'), filePath.lastIndexOf('\\')) + 1;
		String fileName = filePath.substring(nameStart);

		int dotIndex = fileName.lastIndexOf('.');
		if (dotIndex >= 0 && dotIndex < fileName.length() - 1) {
			// Locale.ROOT keeps the result independent of the server's default locale
			return fileName.substring(dotIndex + 1).toLowerCase(Locale.ROOT);
		}
		return null;
	}

	/**
	 * Detect the mime type of a file on disk with Tika.
	 *
	 * @param filePath file to inspect
	 * @return the detected mime type, or {@code null} if the file cannot be read
	 */
	private String detectMimeType(String filePath) {
		Tika tika = new Tika();
		File file = new File(Utility.normalizePath(filePath));
		try (FileInputStream inputstream = new FileInputStream(file)) {
			return tika.detect(inputstream, new Metadata());
		} catch (IOException e) {
			classLogger.error(Constants.ERROR_MESSAGE, e);
			return null;
		}
	}

	/**
	 * Check whether a file is one of the supported document types. The decision is made on the
	 * file extension; files without an extension are checked by detected mime type instead.
	 * 
	 * @param filePath file to check
	 * @return {@code true} for pdf, ppt/pptx, doc/docx, txt and csv files
	 */
	private boolean isSupportedFileType(String filePath) {
		String extension = getExtension(filePath);
		if (extension != null) {
			switch (extension) {
			case "pdf":
			case "pptx":
			case "ppt":
			case "doc":
			case "docx":
			case "txt":
			case "csv":
				return true;
			default:
				return false;
			}
		}

		// do a mime type check
		String mimeType = detectMimeType(filePath);
		if (mimeType == null) {
			return false;
		}
		switch (mimeType) {
		case "application/pdf":
		case "application/msword": // .doc
		case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": // .docx
		case "application/vnd.ms-powerpoint": // .ppt
		case "application/vnd.openxmlformats-officedocument.presentationml.presentation": // .pptx
		case "text/plain":
		case "text/csv":
			return true;
		default:
			return false;
		}
	}

	/**
	 * Check whether a file is a zip archive. The decision is made on the file extension; files
	 * without an extension are checked by detected mime type instead.
	 * 
	 * @param filePath file to check
	 * @return {@code true} if the file is a zip archive
	 */
	private boolean isZipFile(String filePath) {
		String extension = getExtension(filePath);
		if (extension != null) {
			return extension.equals("zip");
		}

		// do a mime type check
		return "application/zip".equalsIgnoreCase(detectMimeType(filePath));
	}

	/**
	 * Describe an input key. For the {@code paramValues} key the description lists, per vector
	 * database type, every parameter option together with its requirement status and description.
	 */
	@Override
	protected String getDescriptionForKey(String key) {
		if(key.equals(ReactorKeysEnum.PARAM_VALUES_MAP.getKey())) {
			StringBuilder finalDescription = new StringBuilder("Param Options depend on the engine implementation");

			for (CreateEmbeddingsParamOptions entry : CreateEmbeddingsParamOptions.values()) {
				finalDescription.append("\n")
				.append("\t\t\t\t\t")
				.append(entry.getVectorDbType().getVectorDatabaseName())
				.append(":");

				for (String paramKey : entry.getParamOptionsKeys()) {
					finalDescription.append("\n")
					.append("\t\t\t\t\t\t")
					.append(paramKey)
					.append("\t")
					.append("-")
					.append("\t")
					.append("(").append(entry.getRequirementStatus(paramKey)).append(")")
					.append(" ")
					.append(VectorDatabaseParamOptionsEnum.getDescriptionFromKey(paramKey));
				}
			}
			return finalDescription.toString();
		}

		return super.getDescriptionForKey(key);
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy