
prerna.engine.impl.model.AbstractPythonModelEngine Maven / Gradle / Ivy
The newest version!
package prerna.engine.impl.model;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import prerna.ds.py.PyTranslator;
import prerna.ds.py.PyUtils;
import prerna.engine.impl.SmssUtilities;
import prerna.engine.impl.model.inferencetracking.ModelInferenceLogsUtils;
import prerna.engine.impl.model.responses.AskModelEngineResponse;
import prerna.engine.impl.model.responses.AskToolModelEngineResponse;
import prerna.engine.impl.model.responses.EmbeddingsModelEngineResponse;
import prerna.engine.impl.model.responses.InstructModelEngineResponse;
import prerna.engine.impl.model.workers.ModelEngineInferenceLogsWorker;
import prerna.om.ClientProcessWrapper;
import prerna.om.Insight;
import prerna.tcp.PayloadStruct;
import prerna.util.Constants;
import prerna.util.Settings;
import prerna.util.Utility;
/**
* This class is responsible for creating an {@code IModelEngine} that is directly linked to
* a python process. The corresponding python class should handle all method implementations. This java class is
* simply a mechanism to forward calls to the python process.
*/
public abstract class AbstractPythonModelEngine extends AbstractModelEngine {
private static final Logger classLogger = LogManager.getLogger(AbstractPythonModelEngine.class);
// python server
// prefix assigned by the ClientProcessWrapper (see setPrefix); sent to the python server as a
// "prefix" command and appended to every ask call as prefix='...'
protected String prefix = null;
// NOTE(review): not referenced in the visible portion of this class; confirm where it is used
protected String workingDirectory;
// when null, startServer calls createCacheFolder() before launching the process
protected String workingDirectoryBasePath = null;
// translator bound to the current socket client; recreated on every startServer call and used
// to evaluate the generated python calls
protected PyTranslator pyt = null;
// its absolute path is handed to the python process as the server directory in startServer
protected File cacheFolder;
// wrapper around the python process and its socket client; only assigned once startServer has
// fully initialized a new process
private ClientProcessWrapper cpw = null;
// name of the python variable that holds the model object; read from the smss VAR_NAME entry or
// generated as "v_" + a random string in open(Properties)
protected String varName = null;
// string substitute vars
// every smss key -> value, populated in open(Properties); presumably consumed by fillVars when
// substituting the init commands
protected Map vars = new HashMap<>();
// per-insight conversation turns keyed by insight id; askCall only appends to an insight's list
// when one already exists for that id
private Map>> chatHistory = new Hashtable<>();
/**
 * Opens this engine from an smss file on disk. The file path is recorded first, then the file's
 * properties are loaded and handed to {@link #open(Properties)}.
 *
 * @param smssFilePath path to the smss file describing this engine
 * @throws Exception propagated from {@link #open(Properties)}
 */
@Override
public void open(String smssFilePath) throws Exception {
    setSmssFilePath(smssFilePath);
    Properties loadedProps = Utility.loadProperties(smssFilePath);
    this.open(loadedProps);
}
/**
 * Opens this engine from already-loaded smss properties.
 *
 * <p>After the superclass has processed the properties, the python variable name is resolved. If
 * the smss has a {@code VAR_NAME} entry it is reused. Otherwise a random {@code "v_"}-prefixed name
 * is generated and written back into the smss properties. Finally, every smss entry is copied into
 * {@code vars} so it can be used for string substitution.
 *
 * @param smssProp the smss properties for this engine
 * @throws Exception propagated from the superclass {@code open}
 */
@Override
public void open(Properties smssProp) throws Exception {
    super.open(smssProp);

    if (this.smssProp.containsKey(Settings.VAR_NAME)) {
        this.varName = this.smssProp.getProperty(Settings.VAR_NAME);
    } else {
        this.varName = "v_" + Utility.getRandomString(6);
        this.smssProp.put(Settings.VAR_NAME, this.varName);
    }

    // expose every smss entry as a substitution variable
    for (Object rawKey : this.smssProp.keySet()) {
        String propName = rawKey.toString();
        this.vars.put(propName, this.smssProp.getProperty(propName));
    }
}
/**
 * Starts the python process linked to this model engine and initializes the model object in it.
 *
 * <p>The method returns immediately if the current socket client is already connected. Otherwise
 * it proceeds in order:
 * <ol>
 * <li>The cache folder is created if needed.</li>
 * <li>Any previous process wrapper is shut down.</li>
 * <li>A new process and socket client are created.</li>
 * <li>The smss {@code INIT_MODEL_ENGINE} commands are split, passed through {@code fillVars}, and
 * executed.</li>
 * <li>The client prefix is sent to the server.</li>
 * </ol>
 * The new wrapper is stored on this engine only after all of that succeeds. If initialization fails
 * after the process was created, that process is shut down again.
 *
 * @param port the port to use when creating the server/client connection. When negative, a valid
 *             {@code FORCE_PORT} smss value is used instead, which also enables debug mode.
 *             Otherwise the negative value is passed through to {@link ClientProcessWrapper}.
 * @throws IllegalArgumentException if the python process/client cannot be created or reconnected
 * @throws IllegalStateException if the smss does not define the {@code INIT_MODEL_ENGINE} commands
 */
protected synchronized void startServer(int port) {
    // nothing to do if we already have a live connection
    if (this.cpw != null && this.cpw.getSocketClient() != null && this.cpw.getSocketClient().isConnected()) {
        return;
    }
    if (this.workingDirectoryBasePath == null) {
        this.createCacheFolder();
    }
    ClientProcessWrapper cpwToInit = new ClientProcessWrapper();
    // shut down the previous (no longer connected) process before starting a new one
    if (this.cpw != null) {
        this.cpw.shutdown(false);
    }
    String timeout = "30";
    if (this.smssProp.containsKey(Constants.IDLE_TIMEOUT)) {
        timeout = this.smssProp.getProperty(Constants.IDLE_TIMEOUT);
    }
    if (cpwToInit.getSocketClient() == null) {
        boolean debug = false;
        // pull the relevant values from the smss
        String forcePort = this.smssProp.getProperty(Settings.FORCE_PORT);
        String customClassPath = this.smssProp.getProperty("TCP_WORKER_CP");
        String loggerLevel = this.smssProp.getProperty(Settings.LOGGER_LEVEL, "INFO");
        String venvEngineId = this.smssProp.getProperty(Constants.VIRTUAL_ENV_ENGINE, null);
        String venvPath = venvEngineId != null ? Utility.getVenvEngine(venvEngineId).pathToExecutable() : null;
        if (port < 0 && forcePort != null && !(forcePort = forcePort.trim()).isEmpty()) {
            // the caller did not pick a port: honor FORCE_PORT and run in debug mode
            try {
                port = Integer.parseInt(forcePort);
                debug = true;
            } catch (NumberFormatException e) {
                // ignore the bad value and keep the caller's port
                classLogger.warn("Model " + this.getEngineName() + " has an invalid FORCE_PORT value");
            }
        }
        String serverDirectory = this.cacheFolder.getAbsolutePath();
        boolean nativePyServer = true; // it has to be -- don't change this unless you can send engine calls from python
        try {
            cpwToInit.createProcessAndClient(nativePyServer, null, port, venvPath, serverDirectory, customClassPath, debug, timeout, loggerLevel);
        } catch (Exception e) {
            classLogger.error(Constants.STACKTRACE, e);
            // keep the cause so callers can see why the process could not be created
            throw new IllegalArgumentException("Unable to connect to server for python model engine.", e);
        }
    } else if (!cpwToInit.getSocketClient().isConnected()) {
        // only reachable if a freshly constructed wrapper already holds a (disconnected) client
        cpwToInit.shutdown(false);
        try {
            cpwToInit.reconnect();
        } catch (Exception e) {
            classLogger.error(Constants.STACKTRACE, e);
            throw new IllegalArgumentException("Failed to start TCP Server for Python Model Engine = " + this.getEngineName(), e);
        }
    }
    // create the py translator bound to the new socket client
    pyt = new PyTranslator();
    pyt.setSocketClient(cpwToInit.getSocketClient());
    String engineUniqueName = SmssUtilities.getUniqueName(this.engineName, this.engineId);
    try {
        String initCommands = this.smssProp.getProperty(Constants.INIT_MODEL_ENGINE);
        if (initCommands == null) {
            // fail with a clear message instead of a NullPointerException on split()
            throw new IllegalStateException("Python model engine " + engineUniqueName
                    + " does not define " + Constants.INIT_MODEL_ENGINE + " in its smss");
        }
        // break the commands on the python command separator and substitute the smss vars
        String[] commands = initCommands.split(PyUtils.PY_COMMAND_SEPARATOR);
        for (int commandIndex = 0; commandIndex < commands.length; commandIndex++) {
            commands[commandIndex] = fillVars(commands[commandIndex]);
        }
        // log before running so the commands are visible even when initialization fails
        classLogger.info("Initializing " + engineUniqueName
                + " python process with commands >>> " + String.join("\n", commands));
        pyt.runEmptyPy(commands);
        // run a prefix command
        setPrefix(cpwToInit);
        // only publish the wrapper once it is fully initialized
        this.cpw = cpwToInit;
    } catch (Exception e) {
        classLogger.error(Constants.STACKTRACE, e);
        classLogger.warn("Able to start the python process for the python model engine "
                + engineUniqueName
                + " but the start script failed.");
        // do not leave a half-initialized python process running
        cpwToInit.shutdown(false);
        throw e;
    }
}
/**
 * Ensures there is a live connection to the python process. If no process wrapper, no socket
 * client, or a disconnected client is found, the server is (re)started via {@code startServer(-1)}.
 */
protected void checkSocketStatus() {
    boolean connected = this.cpw != null
            && this.cpw.getSocketClient() != null
            && this.cpw.getSocketClient().isConnected();
    if (!connected) {
        this.startServer(-1);
    }
}
/**
 * Reads the prefix assigned to the given process wrapper and stores it on this engine. The prefix
 * is then sent to the python server as a {@code "prefix"} command over the wrapper's socket client.
 *
 * @param cpwToInit the process wrapper whose prefix should be registered
 */
private void setPrefix(ClientProcessWrapper cpwToInit) {
    String assignedPrefix = cpwToInit.getPrefix();
    this.prefix = assignedPrefix;

    PayloadStruct prefixCommand = new PayloadStruct();
    prefixCommand.operation = PayloadStruct.OPERATION.CMD;
    prefixCommand.payload = new String[] {"prefix", assignedPrefix};

    cpwToInit.getSocketClient().executeCommand(prefixCommand);
}
/**
 * Forwards an "ask" request to the python model object named {@code varName} and wraps the result
 * in an {@link AskModelEngineResponse}.
 *
 * <p>When {@code fullPrompt} is non-null, it is sent as the full-prompt keyword argument and
 * {@code question}/{@code context} are not sent. Otherwise the question, and the context if
 * present, are sent as triple-quoted python strings. In that case three more things happen:
 * <ul>
 * <li>A {@code toolExecution} entry in {@code parameters} is appended to the insight's chat history
 * (when a list exists for it).</li>
 * <li>That entry is then removed from {@code parameters}.</li>
 * <li>The conversation history, if any, is attached.</li>
 * </ul>
 * Every remaining entry of {@code parameters} is forwarded as a keyword argument.
 *
 * <p>If this engine keeps conversation history, the user turn and the assistant turn are appended
 * to the insight's chat history afterwards. The assistant turn is either the text response or the
 * tool call.
 *
 * @param question   the user question; used when {@code fullPrompt} is {@code null}
 * @param fullPrompt a complete prompt object; takes precedence over question/context when non-null
 * @param context    optional context sent alongside the question
 * @param insight    the calling insight; supplies the user/insight ids and is passed to the python eval
 * @param parameters extra keyword arguments for the python call; may be {@code null}. Note that a
 *                   {@code toolExecution} entry is removed from this map in the question path.
 * @return the parsed model response
 * @throws IllegalArgumentException if the python output cannot be converted into a response
 */
@Override
public AskModelEngineResponse askCall(String question, Object fullPrompt, String context, Insight insight, Map<String, Object> parameters) {
    checkSocketStatus();
    boolean keepConvoHistory = this.keepsConversationHistory();
    final String TRIPLE_QUOTE = "\"\"\"";

    StringBuilder callMaker = new StringBuilder(varName + ".ask(");
    if (fullPrompt != null) {
        callMaker.append(FULL_PROMPT)
            .append("=")
            .append(PyUtils.determineStringType(fullPrompt));
    } else {
        // NOTE(review): the sanitized question (not the raw input) is what is later stored in the
        // chat history; kept as in the original behavior
        question = escapeForPyTripleQuotes(question);
        callMaker.append("question=")
            .append(TRIPLE_QUOTE)
            .append(question)
            .append(TRIPLE_QUOTE);
        if (context != null) {
            callMaker.append(",")
                .append("context=")
                .append(TRIPLE_QUOTE)
                .append(escapeForPyTripleQuotes(context))
                .append(TRIPLE_QUOTE);
        }
        // parameters may be null (see the check below), so guard before touching it here
        if (parameters != null && parameters.containsKey("toolExecution")) {
            @SuppressWarnings("unchecked")
            Map<String, Object> toolExecutionMap = (Map<String, Object>) parameters.get("toolExecution");
            if (chatHistory.containsKey(insight.getInsightId())) {
                chatHistory.get(insight.getInsightId()).add(toolExecutionMap);
            }
            // remove it so it is not forwarded to python as a keyword argument below
            parameters.remove("toolExecution");
        }
        String history = getConversationHistory(insight.getUserId(), insight.getInsightId(), keepConvoHistory);
        if (history != null) {
            // could still be null if it is the first question in the convo
            callMaker.append(",")
                .append("history=")
                .append(history);
        }
    }
    if (parameters != null) {
        for (Map.Entry<String, Object> param : parameters.entrySet()) {
            callMaker.append(",")
                .append(param.getKey())
                .append("=")
                .append(PyUtils.determineStringType(param.getValue()));
        }
    }
    if (this.prefix != null) {
        callMaker.append(", prefix='")
            .append(this.prefix)
            .append("'");
    }
    callMaker.append(")");

    String pyCall = callMaker.toString();
    classLogger.debug("Running >>>" + pyCall);
    Object output = pyt.runSmssWrapperEval(pyCall, insight);

    AskModelEngineResponse response;
    try {
        response = AskModelEngineResponse.fromObject(output);
    } catch (Exception e) {
        classLogger.warn("Could not create response object from output = " + output);
        classLogger.error(Constants.STACKTRACE, e);
        // keep the original exception as the cause so its stack trace is not lost
        throw new IllegalArgumentException(e.getMessage(), e);
    }

    if (keepConvoHistory) {
        appendAskExchangeToHistory(insight, question, response);
    }
    return response;
}

/**
 * Prepares {@code text} for placement between python triple double quotes.
 *
 * <p>A leading or trailing {@code "} is padded with a space so it cannot fuse with the delimiter.
 * Every embedded triple double quote is backslash-escaped.
 *
 * <p>NOTE(review): backslashes are not escaped, so a trailing backslash or sequences such as a
 * literal backslash-n are still interpreted by python. Confirm whether callers pre-escape.
 */
private static String escapeForPyTripleQuotes(String text) {
    String padded = text;
    if (padded.startsWith("\"")) {
        padded = " " + padded;
    }
    if (padded.endsWith("\"")) {
        padded = padded + " ";
    }
    return padded.replace("\"\"\"", "\\\"\\\"\\\"");
}

/**
 * Appends the user turn and the assistant turn of one ask call to the insight's chat history.
 *
 * <p>The assistant turn records the tool call when the response is a tool response. Otherwise it
 * records the text response. Nothing is stored unless a history list already exists for the
 * insight id.
 */
private void appendAskExchangeToHistory(Insight insight, String question, AskModelEngineResponse response) {
    Map<String, Object> inputMap = new HashMap<>();
    inputMap.put(ROLE, "user");
    inputMap.put(MESSAGE_CONTENT, question);

    Map<String, Object> outputMap = new HashMap<>();
    outputMap.put(ROLE, "assistant");
    // compare from the constant side so a null message type is treated as a regular response
    if (AskModelEngineResponse.TOOL.equalsIgnoreCase(response.getMessageType())) {
        AskToolModelEngineResponse toolResponse = (AskToolModelEngineResponse) response;
        Map<String, Object> functionMap = new HashMap<>();
        functionMap.put(ARGUMENTS, toolResponse.getToolCallArgumentsAsString());
        functionMap.put(NAME, toolResponse.getToolCallName());

        Map<String, Object> toolCall = new HashMap<>();
        toolCall.put(TYPE, "function");
        toolCall.put(ID, toolResponse.getToolCallId());
        toolCall.put(FUNCTION, functionMap);

        outputMap.put(TOOL_CALLS, Arrays.asList(toolCall));
        outputMap.put(MESSAGE_CONTENT, ""); // Empty content for tool
    } else {
        // Regular response
        outputMap.put(MESSAGE_CONTENT, response.getStringResponse());
    }

    if (chatHistory.containsKey(insight.getInsightId())) {
        chatHistory.get(insight.getInsightId()).add(inputMap);
        chatHistory.get(insight.getInsightId()).add(outputMap);
    }
}
@Override
public InstructModelEngineResponse instructCall(String task, String context, List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy