All downloads are free. Search and download functionality uses the official Maven repository.

prerna.test.AskToolReactor Maven / Gradle / Ivy

The newest version!
package prerna.test;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.UUID;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.fasterxml.jackson.databind.ObjectMapper;

import prerna.auth.User;
import prerna.auth.utils.SecurityEngineUtils;
import prerna.engine.api.IFunctionEngine;
import prerna.engine.api.IModelEngine;
import prerna.engine.impl.model.AbstractModelEngine;
import prerna.engine.impl.model.responses.AbstractModelEngineResponse;
import prerna.engine.impl.model.responses.AskModelEngineResponse;
import prerna.engine.impl.model.responses.AskToolModelEngineResponse;
import prerna.project.api.IProject;
import prerna.reactor.AbstractReactor;
import prerna.sablecc2.om.GenRowStruct;
import prerna.sablecc2.om.PixelDataType;
import prerna.sablecc2.om.PixelOperationType;
import prerna.sablecc2.om.ReactorKeysEnum;
import prerna.sablecc2.om.nounmeta.NounMetadata;
import prerna.util.Utility;

public class AskToolReactor extends AbstractReactor {
    private static final Pattern MARKDOWN_CODE_PATTERN = Pattern.compile(
        "```" +                          // Opening backticks
        "(?:([a-zA-Z0-9]+))?" +          // Language (optional, group 1)
        "(?:" +                          // Non-capturing group for title alternatives
            "\\s+title=\"([^\"]+)\"" +   // Either title="filename" (group 2)
            "|\\s+([^\\s\\n]+)" +        // Or direct filename (group 3)
        ")?" +                           // Title is optional
        "\\s*\\n" +                      // Whitespace and mandatory newline
        "(.*?)" +                        // Code content (group 4)
        "```",                           // Closing backticks
        Pattern.DOTALL
    );
    public AskToolReactor() {
        this.keysToGet = new String[] { ReactorKeysEnum.ENGINE.getKey(), ReactorKeysEnum.COMMAND.getKey(),
                ReactorKeysEnum.CONTEXT.getKey(), ReactorKeysEnum.PARAM_VALUES_MAP.getKey() ,"engine_tools","project_tools"};
        this.keyRequired = new int[] { 1, 1, 0, 0 , 0, 0};
    }

    @Override
    public NounMetadata execute() {
        organizeKeys();
        String engineId = this.keyValue.get(this.keysToGet[0]);
        User user = this.insight.getUser();
        if (!SecurityEngineUtils.userCanViewEngine(user, engineId)) {
            throw new IllegalArgumentException(
                    "Model " + engineId + " does not exist or user does not have access to this model");
        }

        String question = Utility.decodeURIComponent(this.keyValue.get(this.keysToGet[1]));
        String context = this.keyValue.get(this.keysToGet[2]);
        if (context != null) {
            context = Utility.decodeURIComponent(context);
        }

        Map paramMap = getMap();
        IModelEngine modelEngine = Utility.getModel(engineId);
        if (paramMap == null) {
            paramMap = new HashMap();
        }

		List engineToolIDs = getEngineToolIDs();
		List projectToolIDs = getProjectToolIDs();

        if (!engineToolIDs.isEmpty() || !projectToolIDs.isEmpty()) {
        	
            // Check if the "tools_choice" key exists in the paramMap, else add it
            if (!paramMap.containsKey("tool_choice")) {
            	paramMap.put("tool_choice", "auto");                
            }

            List > toolsList;
            
            // Check if the "tools" key exists in the paramMap
            // this is if a user has explicitly adding tools to the param map. 
            if (paramMap.containsKey("tools")) {
                // Retrieve the existing list of tools
                toolsList = (List >) paramMap.get("tools");
            } else {
                // Create a new list for tools
                toolsList = new ArrayList >();
                paramMap.put("tools", toolsList);
            }

            
            // Iterate over each engine ID and add the tool to the tools list
            for (String engineToolID : engineToolIDs) {
            	//TODO add a safety check here for function engines only
                IFunctionEngine function = Utility.getFunctionEngine(engineToolID);
                Map functionToolMap = function.buildFunctionEngineToolMap();
                toolsList.add(functionToolMap);
            }
            
            // Iterate over each project ID and add the tool to the tools list
            for (String projectToolID : projectToolIDs) {
            	//TODO add a safety check here for code projects only
                IProject project = Utility.getProject(projectToolID);
                Map projectToolMap = project.buildProjectToolMap();
                toolsList.add(projectToolMap);
            }
            
        }
		
        AskModelEngineResponse modelResponse = modelEngine.ask(question, context, this.insight, paramMap);

        Map>> output = processModelResponse(modelResponse);
          
        return new NounMetadata(output, PixelDataType.MAP, PixelOperationType.OPERATION);
    }
    
    private Map>> processModelResponse(AskModelEngineResponse modelResponse){
        Map>> output = new HashMap>>();
        output.put("response", new ArrayList>());
        if(modelResponse.getMessageType().equalsIgnoreCase(AskModelEngineResponse.TOOL)) {  	
            // the response is for a tool call
            // we need to call the actual tool now. 
            AskToolModelEngineResponse toolResponse = (AskToolModelEngineResponse) modelResponse;

            // {"function_id":"123-3345-567","map":{"lat":"123","lon":"321"}}
            String toolArguments = toolResponse.getToolCallArgumentsAsString();

            ObjectMapper mapper = new ObjectMapper();
            Map functionParams = new HashMap();
            try {
                functionParams = mapper.readValue(toolArguments, Map.class);
            } catch (Exception e) {
                // Handle parsing error
                functionParams = null;
            }

            Map outputObject = new HashMap();
            String toolName;
            String toolType;
            
            if(toolResponse.getResponse().get("name").equals("project_engine")){
                IProject project = Utility.getProject((String) functionParams.get("id"));
                toolName = project.getProjectName();
                toolType = "PROJECT";
            } else {
                IFunctionEngine function = Utility.getFunctionEngine((String) functionParams.get("id"));
                toolName = function.getEngineName();
                toolType = "FUNCTION";
            }

            // object to store params needed to call the tool
            List> toolCallInfoData = new ArrayList>();
            for(Entry functionParam : ((Map)functionParams.get("map")).entrySet()){
                HashMap paramInfo = new HashMap();
                paramInfo.put("name", functionParam.getKey());
                paramInfo.put("type", functionParam.getValue().getClass().getSimpleName());
                paramInfo.put("value", functionParam.getValue());
                toolCallInfoData.add(paramInfo);
            }

            outputObject.put("type", toolType);
            outputObject.put("name", toolName);
            outputObject.put("id", (String) functionParams.get("id"));
            outputObject.put("parameters", toolCallInfoData);

            output.get("response").add(outputObject);

            //remove the execution of the function for now. will add back later with a boolean passed in
//            Object functionReturn = function.execute((Map )functionParams.get("map"));
//            String functionReturnString = null;
//
//            try {
//                functionReturnString = mapper.writeValueAsString(functionReturn);
//            } catch (JsonProcessingException e) {
//                // Handle the exception, maybe log it or return a default value
//                e.printStackTrace();
//                functionReturnString = "{}";
//            }
//
//            toolExecutionMap.put("content", functionReturnString);         
//            paramMap.put("toolExecution", toolExecutionMap);
//            AskModelEngineResponse toolExecutionResponse = modelEngine.ask("", null, this.insight, paramMap);
//            output = toolExecutionResponse.toMap();
        } else {
            // 	this is a standard response - process it for code blocks.
        	
            // Process the response to extract code blocks and replace with UUID references
            ProcessedResponse processedResponse = processMarkdownCodeBlocks(modelResponse.getStringResponse());

            // Add code blocks to output if any exist
            if (!processedResponse.getCodeBlocks().isEmpty()) {
                String[] splitResponse = processedResponse.getModifiedResponse().split(".*<\\/CODEBLOCK>");

                for(int i = 0; i < splitResponse.length; i++){
                    Map outputObject = new HashMap();
                    outputObject.put("type", "CONTENT");
                    outputObject.put("content", splitResponse[i]);
                    output.get("response").add(outputObject);
                    if(i < processedResponse.getCodeBlocks().values().toArray().length){
                        CodeBlock codeBlock = (CodeBlock) processedResponse.getCodeBlocks().values().toArray()[i];
                        HashMap paramInfo = new HashMap();
                        paramInfo.put("type", "CODE");
                        paramInfo.put("language", codeBlock.getLanguage());
                        paramInfo.put("name", codeBlock.getTitle());
                        paramInfo.put("content", codeBlock.getCode());
                        output.get("response").add(paramInfo);
                    }
                }

                Map outputObject = new HashMap();
                outputObject.put("originalResponse", modelResponse.getStringResponse());
                output.get("response").add(outputObject);
            } else {
                Map outputObject = new HashMap();
                outputObject.put("type", "CONTENT");
                outputObject.put("content", modelResponse.getStringResponse());
                output.get("response").add(outputObject);
            }
        }
        return output;
    }

 // Method to parse markdown code blocks
    private ProcessedResponse processMarkdownCodeBlocks(String response) {
        Map codeBlocks = new HashMap<>();
        Matcher matcher = MARKDOWN_CODE_PATTERN.matcher(response);
        StringBuffer modifiedResponse = new StringBuffer();

        while (matcher.find()) {
            String language = matcher.group(1) != null ? matcher.group(1).trim() : "";
            // Check both title formats and use the first non-null one
            String title = matcher.group(2) != null ? matcher.group(2).trim() : 
                        matcher.group(3) != null ? matcher.group(3).trim() : "";
            String code = matcher.group(4).trim();
            
            String uuid = UUID.randomUUID().toString();
            codeBlocks.put(uuid, new CodeBlock(language, code, title));
            
            matcher.appendReplacement(modifiedResponse, 
                Matcher.quoteReplacement("" + uuid + ""));
        }
        matcher.appendTail(modifiedResponse);

        return new ProcessedResponse(modifiedResponse.toString(), codeBlocks);
    }

	/**
	 * 
	 * @return list of engines 
	 */
	public List getEngineToolIDs() {
		List inputStrings = new ArrayList<>();

		// see if added as key
		GenRowStruct grs = this.store.getNoun(this.keysToGet[4]);
		if (grs != null && !grs.isEmpty()) {
			int size = grs.size();
			for (int i = 0; i < size; i++) {
				inputStrings.add(grs.get(i).toString());
			}
			return inputStrings;
		}

		// no key is added, grab all inputs
		int size = this.curRow.size();
		for (int i = 0; i < size; i++) {
			inputStrings.add(this.curRow.get(i).toString());
		}
		
		return inputStrings;
	}
	
	public List getProjectToolIDs() {
		List inputStrings = new ArrayList<>();

		// see if added as key
		GenRowStruct grs = this.store.getNoun(this.keysToGet[5]);
		if (grs != null && !grs.isEmpty()) {
			int size = grs.size();
			for (int i = 0; i < size; i++) {
				inputStrings.add(grs.get(i).toString());
			}
			return inputStrings;
		}

		// no key is added, grab all inputs
		int size = this.curRow.size();
		for (int i = 0; i < size; i++) {
			inputStrings.add(this.curRow.get(i).toString());
		}
		
		return inputStrings;
	}

    private Map getMap() {
        GenRowStruct mapGrs = this.store.getNoun(keysToGet[3]);
        if (mapGrs != null && !mapGrs.isEmpty()) {
            List mapInputs = mapGrs.getNounsOfType(PixelDataType.MAP);
            if (mapInputs != null && !mapInputs.isEmpty()) {
                return (Map) mapInputs.get(0).getValue();
            }
        }
        List mapInputs = this.curRow.getNounsOfType(PixelDataType.MAP);
        if (mapInputs != null && !mapInputs.isEmpty()) {
            return (Map) mapInputs.get(0).getValue();
        }
        return null;
    }
    

    // Helper class to represent the processed response
    private static class ProcessedResponse {
        private final String modifiedResponse;
        private final Map codeBlocks;

        public ProcessedResponse(String modifiedResponse, Map codeBlocks) {
            this.modifiedResponse = modifiedResponse;
            this.codeBlocks = codeBlocks;
        }

        public String getModifiedResponse() {
            return modifiedResponse;
        }

        public Map getCodeBlocks() {
            return codeBlocks;
        }
    }

    // Class to represent a code block
    private static class CodeBlock {
        private final String language;
        private final String code;
        private final String title;

        public CodeBlock(String language, String code, String title) {
            this.language = language;
            this.code = code;
            this.title = title;
        }

        public String getLanguage() {
            return language;
        }

        public String getCode() {
            return code;
        }

        public String getTitle() {
            return title;
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy