
package prerna.reactor.frame.gaas;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import prerna.algorithm.api.DataFrameTypeEnum;
import prerna.algorithm.api.ITableDataFrame;
import prerna.algorithm.api.SemossDataType;
import prerna.ds.nativeframe.NativeFrame;
import prerna.ds.py.PandasFrame;
import prerna.ds.r.RDataTable;
import prerna.ds.rdbms.AbstractRdbmsFrame;
import prerna.engine.api.IModelEngine;
import prerna.engine.api.IRawSelectWrapper;
import prerna.query.parsers.GenExpressionWrapper;
import prerna.query.parsers.SqlParser2;
import prerna.query.querystruct.GenExpression;
import prerna.query.querystruct.HardSelectQueryStruct;
import prerna.query.querystruct.SelectQueryStruct;
import prerna.reactor.frame.AbstractFrameReactor;
import prerna.reactor.frame.r.util.AbstractRJavaTranslator;
import prerna.sablecc2.om.PixelDataType;
import prerna.sablecc2.om.PixelOperationType;
import prerna.sablecc2.om.ReactorKeysEnum;
import prerna.sablecc2.om.nounmeta.NounMetadata;
import prerna.util.Constants;
import prerna.util.DIHelper;
import prerna.util.Utility;
public class NLPQuery3Reactor extends AbstractFrameReactor {
// takes a natural-language question, resolves the model engine,
// converts the text to SQL through the LLM pipeline,
// and plugs the resulting frame back into the insight
//
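// illustrative pixel invocation (the pixel name and values here are hypothetical and may vary by deployment):
//   NLPQuery3 ( command = [ "show all movies with a budget over 100" ] , frame = [ FRAME ] , json = [ "true" ] ) ;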
private static final Logger classLogger = LogManager.getLogger(NLPQuery3Reactor.class);
public NLPQuery3Reactor() {
this.keysToGet = new String[] { ReactorKeysEnum.COMMAND.getKey(), "json", ReactorKeysEnum.TOKEN_COUNT.getKey(),
ReactorKeysEnum.FRAME.getKey(), "allFrames", "dialect", ReactorKeysEnum.ENGINE.getKey() };
}
@Override
public NounMetadata execute() {
organizeKeys();
String query = keyValue.get(keysToGet[0]);
boolean json = true;
if (keyValue.containsKey(keysToGet[1])) {
json = Boolean.parseBoolean(keyValue.get(keysToGet[1]));
}
int maxTokens = 150;
if (keyValue.containsKey(keysToGet[2])) {
maxTokens = Integer.parseInt(keyValue.get(keysToGet[2]));
}
List<ITableDataFrame> theseFrames = new ArrayList<>();
if (Boolean.parseBoolean(this.keyValue.get(this.keysToGet[4]))) {
theseFrames.addAll(this.getAllFrames());
if (theseFrames.isEmpty()) {
return NounMetadata.getErrorNounMessage("No frames found");
}
} else {
ITableDataFrame thisFrame = getFrameDefaultLast();
if (thisFrame == null) {
return NounMetadata.getErrorNounMessage("No frame found for " + keyValue.get(keysToGet[3]));
}
theseFrames.add(thisFrame);
}
String dialect = this.keyValue.get(this.keysToGet[5]);
if (dialect == null || (dialect = dialect.trim()).isEmpty()) {
dialect = "SQLite3";
}
IModelEngine engine = null;
if (keyValue.containsKey(keysToGet[6])) {
String engineId = this.keyValue.get(this.keysToGet[6]);
engine = (IModelEngine) Utility.getEngine(engineId);
}
if (engine == null) {
String engineId = DIHelper.getInstance().getProperty(Constants.SQL_MOOSE_MODEL);
engine = (IModelEngine) Utility.getEngine(engineId);
}
if(engine == null) {
throw new IllegalArgumentException("Model engine ID must be passed in or added as a property");
}
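// at this point the model engine has been resolved either from the ENGINE key or, failing that,
// from the Constants.SQL_MOOSE_MODEL property configured for the deployment (read via DIHelper)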
List<NounMetadata> retListForFrames = new ArrayList<>();
for (ITableDataFrame thisFrame : theseFrames) {
// StringBuffer finalDbString = new StringBuffer();
StringBuffer finalDbString2 = new StringBuffer();
// StringBuffer finalQuery = new StringBuffer();
// finalDbString.append("Given Database Schema: ");
finalDbString2.append("You are tasked with generating sql to best answer a user's question given a table schema and the question. Below is both schema and the question, respond with the correct sql and ensure that the output starts and ends with ``` markdown. If user specifies any space delimeted value related to a column from dataset then make sure to replace the space with an underscore character.\n\nTABLE SCHEMA: ");
Map<String, SemossDataType> columnTypes = thisFrame.getMetaData().getHeaderToTypeMap();
// finalDbString.append("CREATE TABLE ").append(thisFrame.getName()).append("(");
finalDbString2.append("CREATE TABLE ").append(thisFrame.getName()).append("(");
Iterator<String> columns = columnTypes.keySet().iterator();
while (columns.hasNext()) {
String thisColumn = columns.next();
SemossDataType colType = columnTypes.get(thisColumn);
thisColumn = thisColumn.replace(thisFrame.getOriginalName() + "__", "");
String colTypeString = SemossDataType.convertDataTypeToString(colType);
if (colType == SemossDataType.DOUBLE || colType == SemossDataType.INT)
colTypeString = "NUMBER";
if (colType == SemossDataType.STRING)
colTypeString = "TEXT";
// finalDbString.append(thisColumn).append(" ").append(colTypeString).append(",");
finalDbString2.append(thisColumn).append(" ").append(colTypeString).append(",");
}
// finalDbString.append(")");
// finalDbString.append(". Provide an SQL to list ").append(query);
// finalDbString.append(". Be Concise. Provide as markdown. Output should start and end with ``` markdown.");
//
finalDbString2.append(")\n\n");
finalDbString2.append("USER QUESTION: ").append(query);
finalDbString2.append("\n\nRespond with the correct sql and ensure that the output starts and ends with ``` markdown.");
// classLogger.info(finalDbString + "");
classLogger.info("prompt2: "+finalDbString2 + "");
Object output = null;
Map<String, Object> params = new HashMap<>();
params.put("temperature", 0.3);
Map<String, Object> modelOutput = engine.ask(finalDbString2 + "", null, this.insight, params).toMap();
String response = modelOutput.get("response")+"";
classLogger.info("Response: "+response);
// if it comes in with finalDBString take it out
response = response.replace(finalDbString2, "");
String markdown = "```";
int start = response.indexOf(markdown);
if (start >= 0)
response = response.substring(start + markdown.length());
// get the select also
start = response.indexOf("SELECT");
if (start >= 0)
response = response.substring(start);
// remove the end quotes
int end = response.indexOf("```");
if (end >= 0)
response = response.substring(0, end);
end = response.indexOf(";");
if (end >= 0)
response = response.substring(0, end);
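// illustrative walk-through of the trimming above (hypothetical model output):
//   raw response  : ```sql\nSELECT Title, MovieBudget FROM MOVIES WHERE MovieBudget > 100;\n```
//   after trimming: SELECT Title, MovieBudget FROM MOVIES WHERE MovieBudget > 100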
classLogger.info(response);
output = response;
// }
// get the string
// make a frame
// load the frame into insight
classLogger.info("SQL query is " + output);
// Create a new SQL Data Frame
String sqlDFQuery = output.toString().trim();
// remove the new line
sqlDFQuery = sqlDFQuery.replace("\n", " ");
sqlDFQuery = sqlDFQuery.replaceAll("[\\t\\n\\r]+"," ");
classLogger.info("sql df query: "+sqlDFQuery);
// execute sqlDF to create a frame
// need to check if the query is right and then feed this into sqldf
// need to parse this
// a. see if the table names match with the frame names if not change it
// b. See the constants and change the value based on the appropriate value the
// column has - you can circumvent this by giving value in quotes
String frameName = Utility.getRandomString(5);
Map<String, Object> outputMap = new HashMap<>();
boolean sameColumns = isSameColumns(sqlDFQuery, thisFrame);
outputMap.put("COLUMN_CHANGE", sameColumns + "");
if (thisFrame instanceof PandasFrame) {
sqlDFQuery = sqlDFQuery.replace("\"", "\\\"");
// do we need a way to check the library is installed?
PandasFrame pFrame = (PandasFrame) thisFrame;
String sqliteName = pFrame.getSQLite();
// pd.read_sql("select * from diab1 where age > 60", conn)
String frameMaker = frameName + " = pd.read_sql(\"" + sqlDFQuery + "\", " + sqliteName + ")";
classLogger.info("Creating frame with query.. " + sqlDFQuery + " <<>> " + frameMaker);
insight.getPyTranslator().runEmptyPy(frameMaker);
String sampleOut = insight.getPyTranslator().runSingle(frameName + ".head(20)", this.insight);
System.err.println(sampleOut);
// send information
// check to see if the variable was created
// if not this is a bad query
if (sampleOut != null && sampleOut.length() > 0) {
if (json) {
outputMap.put(ReactorKeysEnum.FRAME_TYPE.getKey(), DataFrameTypeEnum.PYTHON.getTypeAsString());
outputMap.put("Query", sqlDFQuery);
outputMap.put(ReactorKeysEnum.FRAME.getKey(), frameName);
outputMap.put("SAMPLE", sampleOut);
outputMap.put("COMMAND", "GenerateFrameFromPyVariable('" + frameName + "')");
retListForFrames.add(new NounMetadata(outputMap, PixelDataType.MAP));
} else {
StringBuffer outputString = new StringBuffer("Query Generated : " + sqlDFQuery);
outputString.append("\nData : " + frameName);
outputString.append("\n");
outputString.append(sampleOut);
outputString.append("\n");
retListForFrames.add(new NounMetadata(outputString.toString(), PixelDataType.CONST_STRING));
}
} else {
if (json) {
outputMap.put("Query", sqlDFQuery);
outputMap.put("SAMPLE", "Could not compute data, query is not correct.");
retListForFrames.add(new NounMetadata(outputMap, PixelDataType.MAP));
} else {
StringBuffer outputString = new StringBuffer("Query Generated : " + sqlDFQuery);
outputString.append("\n");
outputString.append("Query did not yield any results... ");
retListForFrames.add(new NounMetadata(outputString.toString(), PixelDataType.CONST_STRING));
}
try {
this.insight.getPyTranslator().runScript("del " + frameName + " , sqldf");
} catch (Exception ignored) {
}
}
} else if (thisFrame instanceof RDataTable) {
sqlDFQuery = sqlDFQuery.replace("\"", "\\\"");
AbstractRJavaTranslator rt = insight.getRJavaTranslator(this.getClass().getName());
rt.checkPackages(new String[] { "sqldf" });
String frameMaker = frameName + " <- sqldf(\"" + sqlDFQuery + "\")";
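// e.g. (illustrative, with a hypothetical random frame name): abCde <- sqldf("select Title from MOVIES where MovieBudget > 100")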
classLogger.info("Creating frame with query.. " + sqlDFQuery + " <<>> " + frameMaker);
rt.runRAndReturnOutput("library(sqldf)");
rt.runR(frameMaker); // load the sql df
boolean frameCreated = rt.runRAndReturnOutput("exists('" + frameName + "')").toUpperCase()
.contains("TRUE");
if (frameCreated) {
String sampleOut = rt.runRAndReturnOutput("head(" + frameName + ", 20)");
if (json) {
outputMap.put(ReactorKeysEnum.FRAME_TYPE.getKey(), DataFrameTypeEnum.R.getTypeAsString());
outputMap.put("Query", sqlDFQuery);
outputMap.put(ReactorKeysEnum.FRAME.getKey(), frameName);
outputMap.put("SAMPLE", sampleOut);
outputMap.put("COMMAND", "GenerateFrameFromRVariable('" + frameName + "')");
retListForFrames.add(new NounMetadata(outputMap, PixelDataType.MAP));
} else {
StringBuffer outputString = new StringBuffer("Query Generated : " + sqlDFQuery);
// now we just need to tell the user here is the frame
outputString.append("\nData : " + frameName);
outputString.append("\n");
outputString.append(sampleOut);
outputString.append("\n");
outputString.append(
"To start working with this frame GenerateFrameFromRVariable('" + frameName + "')");
retListForFrames.add(new NounMetadata(outputString.toString(), PixelDataType.CONST_STRING));
}
} else {
if (json) {
outputMap.put("Query", sqlDFQuery);
outputMap.put("SAMPLE", "Could not compute data, query is not correct.");
retListForFrames.add(new NounMetadata(outputMap, PixelDataType.MAP));
} else {
StringBuffer outputString = new StringBuffer("Query Generated : " + sqlDFQuery);
outputString.append("\n");
outputString.append("Query did not yield any results... ");
retListForFrames.add(new NounMetadata(outputString.toString(), PixelDataType.CONST_STRING));
}
}
} else if (thisFrame instanceof NativeFrame) {
// we do a query from a subquery
SelectQueryStruct allDataQs = thisFrame.getMetaData().getFlatTableQs(true);
String baseQuery = ((NativeFrame) thisFrame).getEngineQuery(allDataQs);
String newQuery = sqlDFQuery.replace(thisFrame.getName(),
"(" + baseQuery + ") as " + thisFrame.getName());
HardSelectQueryStruct hqs = new HardSelectQueryStruct();
hqs.setQuery(newQuery);
int counter = 0;
List<List<Object>> sampleOut = new ArrayList<>();
try {
IRawSelectWrapper it = thisFrame.query(hqs);
while (it.hasNext() && counter < 10) {
sampleOut.add(Arrays.asList(it.next().getValues()));
counter++;
}
if (json) {
outputMap.put(ReactorKeysEnum.FRAME_TYPE.getKey(), DataFrameTypeEnum.NATIVE.getTypeAsString());
outputMap.put("Query", newQuery);
outputMap.put(ReactorKeysEnum.FRAME.getKey(), frameName);
outputMap.put("SAMPLE", sampleOut.toString());
retListForFrames.add(new NounMetadata(outputMap, PixelDataType.MAP));
} else {
StringBuffer outputString = new StringBuffer("Query Generated : " + newQuery);
outputString.append("\nData : " + frameName);
outputString.append("\n");
outputString.append(sampleOut);
retListForFrames.add(new NounMetadata(outputString.toString(), PixelDataType.CONST_STRING));
}
} catch (Exception e) {
outputMap.put("Query", newQuery);
outputMap.put("SAMPLE", "Could not compute data, query is not correct.");
if (json) {
retListForFrames.add(new NounMetadata(outputMap, PixelDataType.MAP));
} else {
StringBuffer outputString = new StringBuffer("Query Generated : " + newQuery);
outputString.append("\n");
outputString.append("Query did not yield any results... ");
retListForFrames.add(new NounMetadata(outputString.toString(), PixelDataType.CONST_STRING));
}
}
} else if (thisFrame instanceof AbstractRdbmsFrame) {
HardSelectQueryStruct hqs = new HardSelectQueryStruct();
hqs.setQuery(sqlDFQuery);
int counter = 0;
List<List<Object>> sampleOut = new ArrayList<>();
try {
IRawSelectWrapper it = thisFrame.query(hqs);
while (it.hasNext() && counter < 10) {
sampleOut.add(Arrays.asList(it.next().getValues()));
counter++;
}
if (json) {
outputMap.put(ReactorKeysEnum.FRAME_TYPE.getKey(), DataFrameTypeEnum.GRID.getTypeAsString());
outputMap.put("Query", sqlDFQuery);
outputMap.put(ReactorKeysEnum.FRAME.getKey(), frameName);
outputMap.put("SAMPLE", sampleOut.toString());
retListForFrames.add(new NounMetadata(outputMap, PixelDataType.MAP));
} else {
StringBuffer outputString = new StringBuffer("Query Generated : " + sqlDFQuery);
outputString.append("\nData : " + frameName);
outputString.append("\n");
outputString.append(sampleOut);
retListForFrames.add(new NounMetadata(outputString.toString(), PixelDataType.CONST_STRING));
}
} catch (Exception e) {
outputMap.put("Query", sqlDFQuery);
outputMap.put("SAMPLE", "Could not compute data, query is not correct.");
if (json) {
retListForFrames.add(new NounMetadata(outputMap, PixelDataType.MAP));
} else {
StringBuffer outputString = new StringBuffer("Query Generated : " + sqlDFQuery);
outputString.append("\n");
outputString.append("Query did not yield any results... ");
retListForFrames.add(new NounMetadata(outputString.toString(), PixelDataType.CONST_STRING));
}
}
} else {
retListForFrames.add(getError(
"NLPQuery3 currently supports Python, R, grid, and native frames. Please convert your frame to Python or R and try again."));
}
}
return new NounMetadata(retListForFrames, PixelDataType.VECTOR, PixelOperationType.VECTOR);
}
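// returns true when every frame column still appears in the generated SQL (matched by
// selector alias or expression, or via a plain "select *"); used above only to set the
// COLUMN_CHANGE flag in the output map
// e.g. (illustrative): frame columns {Title, MovieBudget} vs "SELECT Title FROM MOVIES" -> false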
private boolean isSameColumns(String sqlDFQuery, ITableDataFrame thisFrame) {
boolean sameColumns = true;
try {
SqlParser2 p2 = new SqlParser2();
GenExpressionWrapper wrapper = p2.processQuery(sqlDFQuery);
String[] columnHeaders = thisFrame.getColumnHeaders();
boolean allColumns = false;
List<GenExpression> selects = wrapper.root.nselectors;
if (selects.size() == 1) {
// possibly select *
GenExpression allSelect = selects.get(0);
allColumns = allSelect.getLeftExpr().equalsIgnoreCase("*");
// we are good
}
if (!allColumns) {
// going to run a dual for loop here: for every frame column, look for a matching selector
for (int selectorIndex = 0; selectorIndex < columnHeaders.length && sameColumns; selectorIndex++) {
String thisColumn = columnHeaders[selectorIndex];
boolean foundThisColumn = false;
for (int newColumnIndex = 0; newColumnIndex < selects.size(); newColumnIndex++) {
GenExpression thisSelector = selects.get(newColumnIndex);
String alias = thisSelector.getLeftAlias();
if (alias == null)
alias = thisSelector.getLeftExpr();
if (thisColumn.equalsIgnoreCase(alias))
foundThisColumn = true;
}
sameColumns = sameColumns && foundThisColumn;
}
}
} catch (Exception e) {
// if the generated SQL cannot be parsed, assume the columns changed
classLogger.info(e.getMessage());
sameColumns = false;
}
return sameColumns;
}
}