
prerna.engine.impl.model.AbstractModelEngine Maven / Gradle / Ivy
The newest version!
package prerna.engine.impl.model;
import java.io.File;
import java.io.IOException;
import java.time.ZonedDateTime;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.UUID;
import org.apache.commons.io.FileUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import prerna.engine.api.IEngine;
import prerna.engine.api.IModelEngine;
import prerna.engine.impl.SmssUtilities;
import prerna.engine.impl.model.responses.AskModelEngineResponse;
import prerna.engine.impl.model.responses.EmbeddingsModelEngineResponse;
import prerna.engine.impl.model.responses.InstructModelEngineResponse;
import prerna.engine.impl.model.workers.ModelEngineInferenceLogsWorker;
import prerna.io.connector.secrets.ISecrets;
import prerna.io.connector.secrets.SecretsFactory;
import prerna.om.Insight;
import prerna.util.Constants;
import prerna.util.EngineUtility;
import prerna.util.UploadUtilities;
import prerna.util.Utility;
public abstract class AbstractModelEngine implements IModelEngine {
// Logger for this class.
private static final Logger classLogger = LogManager.getLogger(AbstractModelEngine.class);
// String keys for provider credentials. They are not referenced in the visible part of this class;
// presumably implementations look them up in the smss properties - confirm against usages.
public static final String OPEN_AI_KEY = "OPEN_AI_KEY";
public static final String AWS_SECRET_KEY = "AWS_SECRET_KEY";
public static final String AWS_ACCESS_KEY = "AWS_ACCESS_KEY";
public static final String GCP_SERVICE_ACCOUNT_KEY = "GCP_SERVICE_ACCOUNT_KEY";
// String keys for message fields. They are not referenced in the visible part of this class;
// presumably the field names of chat-style message payloads - verify against usages.
public static final String MESSAGE_CONTENT = "content";
public static final String ROLE = "role";
public static final String TOOL_CALLS = "tool_calls";
public static final String TYPE = "type";
public static final String ID = "id";
public static final String FUNCTION = "function";
public static final String ARGUMENTS = "arguments";
public static final String NAME = "name";
// param keys
// Key in the parameters map given to ask(): ask() pulls this entry out of the parameters
// and passes its value to askCall as the fullPrompt argument.
public static final String FULL_PROMPT = "full_prompt";
// Value of the Constants.ENGINE property; assigned in open(Properties).
protected String engineId = null;
// Value of the Constants.ENGINE_ALIAS property; assigned in open(Properties).
protected String engineName = null;
// Properties of this engine; open(Properties) reads them right after calling setSmssProp
// (setter not visible here - presumably it assigns this field).
protected Properties smssProp = null;
// Path of the smss file; open(String) hands it to setSmssFilePath
// (setter not visible here - presumably it assigns this field).
protected String smssFilePath = null;
// Parsed in open(Properties) from Constants.KEEP_CONVERSATION_HISTORY;
// replaced by the Constants.KEEP_CONTEXT value when that key is present.
protected boolean keepConversationHistory = false;
// Parsed in open(Properties) from Constants.KEEP_INPUT_OUTPUT;
// replaced by the Constants.KEEP_CONTEXT value when that key is present.
protected boolean keepInputOutput = false;
// Result of Utility.isModelInferenceLogsEnabled() taken when the instance is created;
// when true, ask() starts a ModelEngineInferenceLogsWorker thread for each call.
// NOTE(review): the name is misspelled ("Enbaled"); it is left as is because the field is
// protected and may be referenced by subclasses.
protected boolean inferenceLogsEnbaled = Utility.isModelInferenceLogsEnabled();
/**
 * Opens the engine from an smss file.
 * The path is handed to setSmssFilePath first, then the file is loaded through
 * Utility.loadProperties and the result is passed on to {@link #open(Properties)}.
 *
 * @param smssFilePath	path of the smss file to load
 * @throws Exception	if loading the properties or opening the engine fails
 */
@Override
public void open(String smssFilePath) throws Exception {
	this.setSmssFilePath(smssFilePath);
	final Properties loadedProps = Utility.loadProperties(smssFilePath);
	this.open(loadedProps);
}
/**
 * Opens the engine from an already loaded set of smss properties.
 * <ol>
 * <li>passes the properties to setSmssProp and reads the engine id / engine name from them</li>
 * <li>when a secrets connector is configured, merges the engine secrets into the properties</li>
 * <li>resolves the keepConversationHistory / keepInputOutput flags</li>
 * </ol>
 *
 * @param smssProp		the smss properties of this engine
 * @throws Exception	declared for implementations and helpers that may fail while opening
 */
@Override
public void open(Properties smssProp) throws Exception {
	setSmssProp(smssProp);
	this.engineId = this.smssProp.getProperty(Constants.ENGINE);
	this.engineName = this.smssProp.getProperty(Constants.ENGINE_ALIAS);

	// secrets, when a connector exists and returns any, are layered on top of the smss values
	final ISecrets secretStore = SecretsFactory.getSecretConnector();
	if (secretStore != null) {
		final Map engineSecrets = secretStore.getEngineSecrets(getCatalogType(), this.engineId, this.engineName);
		final boolean secretsFound = engineSecrets != null && !engineSecrets.isEmpty();
		if (secretsFound) {
			classLogger.info("Successfully pulled secrets for " + SmssUtilities.getUniqueName(this.engineName, this.engineId));
			this.smssProp.putAll(engineSecrets);
		} else {
			classLogger.info("No secrets found for " + SmssUtilities.getUniqueName(this.engineName, this.engineId));
		}
	}

	// each flag has its own property, but KEEP_CONTEXT - when present - wins for both of them
	boolean keepHistory = Boolean.parseBoolean(this.smssProp.getProperty(Constants.KEEP_CONVERSATION_HISTORY));
	boolean keepIo = Boolean.parseBoolean(this.smssProp.getProperty(Constants.KEEP_INPUT_OUTPUT));
	if (this.smssProp.containsKey(Constants.KEEP_CONTEXT)) {
		final boolean keepContext = Boolean.parseBoolean(this.smssProp.getProperty(Constants.KEEP_CONTEXT));
		keepHistory = keepContext;
		keepIo = keepContext;
	}
	this.keepConversationHistory = keepHistory;
	this.keepInputOutput = keepIo;
}
/**
 * Performs the actual model call for {@link #ask(String, String, Insight, Map)}.
 * This is an abstract method for the implementation class such that tracking occurs:
 * {@code ask} wraps this call, stamps a message id and room id on the returned response,
 * optionally records inference logs and updates the user's usage.
 *
 * @param question			the question passed to {@code ask}
 * @param fullPrompt		the value stored under {@link #FULL_PROMPT} in the parameters given to {@code ask},
 * 							or {@code null} when there was no such entry
 * @param context			the context passed to {@code ask}
 * @param insight			the insight passed to {@code ask}
 * @param hyperParameters	the parameters given to {@code ask} without the {@link #FULL_PROMPT} entry;
 * 							never {@code null} when called from {@code ask}
 * @return					the model response; {@code ask} dereferences it right away, so returning
 * 							{@code null} leads to a NullPointerException there
 */
protected abstract AskModelEngineResponse askCall(String question, Object fullPrompt, String context, Insight insight, Map hyperParameters);
/**
 * Asks the model a question and performs the bookkeeping around the call.
 * <ol>
 * <li>looks up the usage restriction of the insight's user for this engine</li>
 * <li>takes the {@link #FULL_PROMPT} entry out of the parameters and delegates to
 * {@link #askCall(String, Object, String, Insight, Map)}</li>
 * <li>stamps a random message id and the insight id (as room id) on the response</li>
 * <li>when inference logs are enabled, records the exchange on a separate thread</li>
 * <li>updates the user's current usage with this request</li>
 * </ol>
 * The caller's {@code parameters} map is never modified: the call works on a copy, so the same map
 * can be reused for further calls and unmodifiable maps are accepted.
 *
 * @param question		the question for the model
 * @param context		the context for the question
 * @param insight		the insight the call belongs to; must not be {@code null} since its user and id are read
 * @param parameters	hyper parameters for the call, may be {@code null}; an optional {@link #FULL_PROMPT}
 * 						entry is handed to {@code askCall} separately instead of as a hyper parameter
 * @return				the response of {@code askCall} with message id and room id set
 */
@Override
public AskModelEngineResponse ask(String question, String context, Insight insight, Map parameters) {
	/*
	 * We will check if there are any restrictions for the user's current token usage
	 * There might be a value set on the user-engine permission which takes priority
	 * or if there is none
	 * there might be a value set on the user for all their model engine usage
	 */
	// do we have any usage restriction on the user
	Map userRestrictionMap = ModelUsageRestrictionUtility.getModelUsageRestriction(insight.getUser(), this.engineId);

	// work on a private, order-preserving copy: removing FULL_PROMPT from the caller's own map would
	// silently drop the prompt if the caller reuses the map and would fail on unmodifiable maps
	Map callParameters = (parameters == null) ? new LinkedHashMap() : new LinkedHashMap(parameters);
	Object fullPrompt = callParameters.remove(FULL_PROMPT);

	ZonedDateTime inputTime = ZonedDateTime.now();
	AskModelEngineResponse askModelResponse = askCall(question, fullPrompt, context, insight, callParameters);
	ZonedDateTime outputTime = ZonedDateTime.now();

	askModelResponse.setMessageId(UUID.randomUUID().toString());
	askModelResponse.setRoomId(insight.getInsightId());

	if (inferenceLogsEnbaled) {
		// the log record is written on its own thread so the caller does not wait for it
		Thread inferenceRecorder = new Thread(new ModelEngineInferenceLogsWorker (
				/*messageId*/askModelResponse.getMessageId(),
				/*messageMethod*/"ask",
				/*engine*/this,
				/*insight*/insight,
				/*context*/context,
				/*prompt*/question,
				/*fullPrompt*/fullPrompt,
				/*promptTokens*/askModelResponse.getNumberOfTokensInPrompt(),
				/*inputTime*/inputTime,
				/*response*/askModelResponse.getStringResponse(),
				/*responseTokens*/askModelResponse.getNumberOfTokensInResponse(),
				/*outputTime*/outputTime
				));
		inferenceRecorder.start();
	}

	// update current usage based on this new request
	ModelUsageRestrictionUtility.updateRestrictionMapCurrentUsage(userRestrictionMap, askModelResponse, inputTime, outputTime);
	return askModelResponse;
}
/**
* This is an abstract method for the implementation class such that tracking occurs
*
* @param task
* @param context
* @param insight
* @param hyperParameters
* @return
*/
protected abstract InstructModelEngineResponse instructCall(String task, String context, List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy