All downloads are FREE. Search and download functionality uses the official Maven repository.

prerna.reactor.frame.gaas.NLPQueryReactor Maven / Gradle / Ivy

The newest version!
package prerna.reactor.frame.gaas;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.Vector;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import prerna.algorithm.api.ITableDataFrame;
import prerna.ds.py.PandasFrame;
import prerna.reactor.AbstractReactor;
import prerna.sablecc2.om.PixelDataType;
import prerna.sablecc2.om.PixelOperationType;
import prerna.sablecc2.om.ReactorKeysEnum;
import prerna.sablecc2.om.nounmeta.NounMetadata;
import prerna.util.DIHelper;
import prerna.util.Settings;
import prerna.util.Utility;

/**
 * Converts a natural-language question into SQL with a Hugging Face
 * text2text-generation pipeline, then runs that SQL against the insight's
 * frames through pandasql and stores the result in a new pandas DataFrame.
 *
 * <p>Keys:
 * <ul>
 *   <li>{@code COMMAND} - the natural-language question (required).</li>
 *   <li>{@code MODEL} - Hugging Face model id (optional, defaults to
 *       {@code tscholak/cxmefzzi}). When the {@code HF_SEQ_2_SQL} property is set
 *       in DIHelper it overrides both.</li>
 * </ul>
 *
 * <p>The loaded pipeline is kept in a Python global named after the cleaned model
 * id, so the same model is not reloaded while that Python session lives.
 */
public class NLPQueryReactor extends AbstractReactor {

	private static final Logger logger = LogManager.getLogger(NLPQueryReactor.class);

	/** Model used when neither MODEL nor HF_SEQ_2_SQL supplies one. */
	private static final String DEFAULT_MODEL = "tscholak/cxmefzzi";

	/** DIHelper property that, when present, forces the model to use. */
	private static final String FORCED_MODEL_PROP = "HF_SEQ_2_SQL";

	/** Message prefix used whenever no usable SQL / result could be produced. */
	private static final String INVALID_QUERY_MSG = "Could not compute the result / query invalid -- ";

	public NLPQueryReactor() {
		this.keysToGet = new String[]{ReactorKeysEnum.COMMAND.getKey(), ReactorKeysEnum.MODEL.getKey()};
	}

	/**
	 * Runs the text-to-SQL pipeline for the COMMAND question.
	 *
	 * @return a CODE / CODE_EXECUTION noun holding the generated SQL, a preview of
	 *         the first 20 rows and a hint for promoting the frame on success; a
	 *         CONST_STRING noun when no SQL could be extracted or the frame was not
	 *         created; an error noun when the question is missing or the Hugging
	 *         Face cache directory is not configured
	 */
	@Override
	public NounMetadata execute() {
		organizeKeys();
		String query = keyValue.get(keysToGet[0]);
		if (query == null || query.trim().isEmpty())
		{
			return NounMetadata.getErrorNounMessage("Please provide a question to convert to SQL");
		}

		String model = resolveModel();
		// the cleaned model id doubles as the Python variable that holds the pipeline,
		// so a model is only loaded once per Python session
		String pipeVar = Utility.cleanString(model, true).replaceAll("-", "_");

		if (!pyGlobalExists(pipeVar))
		{
			String cacheFolder = DIHelper.getInstance().getProperty(Settings.HF_CACHE_DIR);
			if (cacheFolder == null)
			{
				return NounMetadata.getErrorNounMessage("Hugging Face Cache is not set. Please set it RDF Map HF_CACHE_DIR");
			}
			// model comes from user input and cacheFolder may contain backslashes (Windows),
			// so both are emitted as escaped Python literals
			this.insight.getPyTranslator().runScript(pipeVar + " = smssutil.load_hugging_face_model("
					+ toPythonStringLiteral(model) + ", 'text2text-generation', "
					+ toPythonStringLiteral(cacheFolder) + ")");
			logger.info("Loaded the model as " + pipeVar);
		}

		// prompt format: [question] | [db_id] | [table] : [column] , [column] | [table] : ...
		String pipeQuery = query + " | " + buildSchemaString();
		logger.info("executing query " + pipeQuery);
		// the question is untrusted text: escape it instead of splicing it raw into Python
		Object output = this.insight.getPyTranslator().runScript(pipeVar + "(" + toPythonStringLiteral(pipeQuery) + ")");
		String generated = concatenateOutput(output);
		logger.info("Output query is " + generated);

		String sqlDFQuery = extractSql(generated);
		if (sqlDFQuery == null)
		{
			return new NounMetadata(INVALID_QUERY_MSG + generated, PixelDataType.CONST_STRING);
		}

		// run the generated SQL through pandasql to materialize a new frame
		String frameName = Utility.getRandomString(5);
		// generated SQL routinely contains quoted literals, so it must be escaped too
		String frameMaker = frameName + " = pd.DataFrame(sqldf(" + toPythonStringLiteral(sqlDFQuery) + "))";
		logger.info("Creating frame with query..  " + sqlDFQuery + " <<>> " + frameMaker);
		this.insight.getPyTranslator().runEmptyPy("from pandasql import sqldf");
		this.insight.getPyTranslator().runScript(frameMaker);

		// if the variable was not created the generated SQL was bad
		if (!pyGlobalExists(frameName))
		{
			return new NounMetadata(INVALID_QUERY_MSG + sqlDFQuery, PixelDataType.CONST_STRING);
		}

		String frameType = "Py";
		List<NounMetadata> outputs = new Vector<>(4);
		outputs.add(new NounMetadata("Query Generated : " + sqlDFQuery + " Data : " + frameName, PixelDataType.CONST_STRING));
		outputs.add(new NounMetadata(this.insight.getPyTranslator().runSingle(frameName + ".head(20)", this.insight), PixelDataType.CONST_STRING));
		outputs.add(new NounMetadata("To start working with this frame  GenerateFrameFrom" + frameType + "Variable('" + frameName + "')", PixelDataType.CONST_STRING));
		return new NounMetadata(outputs, PixelDataType.CODE, PixelOperationType.CODE_EXECUTION);
	}

	/** Picks the model: forced property first, then a non-blank MODEL key, then the default. */
	private String resolveModel() {
		String forced = DIHelper.getInstance().getProperty(FORCED_MODEL_PROP);
		if (forced != null)
		{
			return forced;
		}
		String requested = keyValue.get(keysToGet[1]);
		if (requested != null && !requested.trim().isEmpty())
		{
			return requested;
		}
		return DEFAULT_MODEL;
	}

	/**
	 * Builds the schema part of the prompt: {@code db | frame : col , col | frame : ...}
	 * covering every frame in the insight's var store.
	 */
	private String buildSchemaString() {
		StringBuilder schema = new StringBuilder("db");
		Set<ITableDataFrame> frames = this.insight.getVarStore().getFrames();
		for (ITableDataFrame frame : frames)
		{
			logger.info("Processing frame " + frame.getName());
			// TODO: the prompt format also allows per-column content values,
			// e.g. "col ( v1 , v2 )"; for pandas frames these could come from
			// <frame>w.get_categorical_values() once that is wired in
			schema.append(" | ").append(frame.getName()).append(" : ")
					.append(String.join(" , ", frame.getColumnHeaders()));
		}
		return schema.toString();
	}

	/** Concatenates the string form of every element when the pipeline returned a list; otherwise "". */
	private static String concatenateOutput(Object output) {
		StringBuilder text = new StringBuilder();
		if (output instanceof List)
		{
			for (Object element : (List<?>) output)
			{
				String elementText = String.valueOf(element);
				logger.info(elementText);
				text.append(elementText);
			}
		}
		return text.toString();
	}

	/**
	 * Extracts the SQL from the generated text. The text is expected to carry a
	 * prefix before the first '|' (presumably the echoed db id), and everything
	 * after it is kept so SQL containing '||' survives. A single trailing '}' is
	 * dropped: presumably the closing brace of the string form of the pipeline's
	 * result map (TODO confirm).
	 *
	 * @return the trimmed SQL, or {@code null} when there is no '|' or nothing follows it
	 */
	private static String extractSql(String generated) {
		int pipeIndex = generated.indexOf('|');
		if (pipeIndex < 0)
		{
			return null;
		}
		String sql = generated.substring(pipeIndex + 1).trim();
		if (sql.endsWith("}"))
		{
			sql = sql.substring(0, sql.length() - 1).trim();
		}
		return sql.isEmpty() ? null : sql;
	}

	/** Returns true only when the Python check {@code '<varName>' in globals()} yields Boolean TRUE. */
	private boolean pyGlobalExists(String varName) {
		Object exists = this.insight.getPyTranslator().runScript(toPythonStringLiteral(varName) + " in globals()");
		return Boolean.TRUE.equals(exists);
	}

	/**
	 * Renders {@code value} as a single-quoted Python string literal. Backslashes
	 * and quotes are escaped, and control characters become {@code \xNN}, so the
	 * value cannot terminate the literal or inject code.
	 */
	private static String toPythonStringLiteral(String value) {
		StringBuilder literal = new StringBuilder(value.length() + 2).append('\'');
		for (int i = 0; i < value.length(); i++)
		{
			char c = value.charAt(i);
			if (c == '\\')
			{
				literal.append("\\\\");
			}
			else if (c == '\'')
			{
				literal.append("\\'");
			}
			else if (c < 0x20)
			{
				literal.append(String.format("\\x%02x", (int) c));
			}
			else
			{
				literal.append(c);
			}
		}
		return literal.append('\'').toString();
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy