All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.semantickernel.aiservices.openai.chatcompletion.XMLPromptParser Maven / Gradle / Ivy

// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.aiservices.openai.chatcompletion;

import com.azure.ai.openai.models.ChatRequestAssistantMessage;
import com.azure.ai.openai.models.ChatRequestFunctionMessage;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestToolMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.FunctionDefinition;
import com.azure.core.util.BinaryData;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.microsoft.semantickernel.exceptions.SKException;
import com.microsoft.semantickernel.orchestration.ToolCallBehavior;
import com.microsoft.semantickernel.services.chatcompletion.AuthorRole;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.UUID;
import javax.xml.namespace.QName;
import javax.xml.stream.XMLEventReader;
import javax.xml.stream.XMLInputFactory;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.events.Attribute;
import javax.xml.stream.events.StartElement;
import javax.xml.stream.events.XMLEvent;
import org.apache.commons.text.StringEscapeUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class XMLPromptParser {

    private static final Logger LOGGER = LoggerFactory.getLogger(XMLPromptParser.class);

    /**
     * Mapper used to convert the generated parameter schema into a {@link JsonNode}. An
     * ObjectMapper is safe to share once configured, so one instance is reused instead of
     * allocating a new one per function definition.
     */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /**
     * Parses a rendered prompt into chat request messages and function definitions.
     * <p>
     * The prompt is first parsed as-is. If that fails or yields no {@code <message>} element,
     * it is parsed again wrapped in a synthetic {@code <prompt>} root element. The wrapper makes
     * prompts that consist of several top-level {@code <message>} elements (or text around
     * them) well-formed XML. If neither attempt yields a message, the whole raw prompt becomes a
     * single user message with a random name.
     *
     * @param rawPrompt the rendered prompt text
     * @return the parsed prompt; the function definition list is {@code null} in the fallback
     *     case
     */
    public static ParsedPrompt parse(String rawPrompt) {
        List<String> candidates = Arrays.asList(
            rawPrompt,
            "<prompt>" + rawPrompt + "</prompt>");

        for (String candidate : candidates) {
            try {
                List<ChatRequestMessage> parsedMessages = getChatRequestMessages(candidate);
                if (!parsedMessages.isEmpty()) {
                    // Function definitions are only needed once this candidate produced messages.
                    List<FunctionDefinition> parsedFunctions = getFunctionDefinitions(candidate);
                    return new ParsedPrompt(parsedMessages, parsedFunctions);
                }
            } catch (SKException e) {
                // Best effort: try the next candidate, then fall back to a plain user message.
                LOGGER.debug("Prompt could not be parsed as XML chat messages", e);
            }
        }

        ChatRequestUserMessage message = new ChatRequestUserMessage(rawPrompt);

        if (message.getName() == null) {
            message.setName(UUID.randomUUID().toString());
        }

        return new ParsedPrompt(Collections.singletonList(message), null);
    }

    /**
     * Extracts every {@code <message role="...">text</message>} element from the prompt, in
     * document order.
     *
     * @param prompt XML text to parse
     * @return the parsed messages; empty if the prompt contains no {@code message} element
     * @throws SKException if the prompt is not well-formed XML, a message contains child
     *     elements, or a role is not a known {@link AuthorRole}
     */
    private static List<ChatRequestMessage> getChatRequestMessages(String prompt) {

        // TODO: XML parsing should be done as a chain of XMLEvent handlers.
        // If one handler does not recognize the element, it should pass it to the next handler.
        // In this way, we can avoid parsing the whole prompt twice and easily extend the parsing logic.
        List<ChatRequestMessage> messages = new ArrayList<>();
        try (InputStream is = new ByteArrayInputStream(prompt.getBytes(StandardCharsets.UTF_8))) {
            XMLEventReader reader = createEventReader(is);
            try {
                while (reader.hasNext()) {
                    XMLEvent event = reader.nextEvent();
                    if (event.isStartElement() && "message".equals(getElementName(event))) {
                        String role = getAttributeValue(event, "role");
                        // getElementText() throws XMLStreamException on nested (non-text) content.
                        String content = reader.getElementText();
                        messages.add(getChatRequestMessage(role, content));
                    }
                }
            } finally {
                closeQuietly(reader);
            }
        } catch (IOException | XMLStreamException | IllegalArgumentException e) {
            throw new SKException("Failed to parse messages", e);
        }
        return messages;
    }

    /**
     * Extracts function definitions from {@code <function>} elements in the prompt.
     * <p>
     * Parsing is best effort. Malformed XML is logged, and the definitions collected up to
     * that point are returned.
     *
     * @param prompt XML text to parse
     * @return the function definitions found, in document order
     * @throws SKException if a {@code </function>} end tag is reached without an open function
     */
    private static List<FunctionDefinition> getFunctionDefinitions(String prompt) {
        // TODO: XML parsing should be done as a chain of XMLEvent handlers. See previous remark.
        // Expected shape:
        //   <function pluginName="..." name="..." description="...">
        //     <parameter name="..." type="..." description="..." isRequired="true"/>
        //     ...
        //   </function>
        List<FunctionDefinition> functionDefinitions = new ArrayList<>();
        try (InputStream is = new ByteArrayInputStream(prompt.getBytes(StandardCharsets.UTF_8))) {
            XMLEventReader reader = createEventReader(is);
            try {
                FunctionDefinition functionDefinition = null;
                // LinkedHashMap keeps parameters in declaration order in the generated schema.
                Map<String, Map<String, String>> parameters = new LinkedHashMap<>();
                List<String> requiredParameters = new ArrayList<>();
                while (reader.hasNext()) {
                    XMLEvent event = reader.nextEvent();
                    if (event.isStartElement()) {
                        String elementName = getElementName(event);
                        if (elementName.equals("function")) {
                            assert functionDefinition == null;
                            assert parameters.isEmpty();
                            assert requiredParameters.isEmpty();
                            String pluginName = getAttributeValue(event, "pluginName");
                            String name = getAttributeValue(event, "name");
                            String description = getAttributeValue(event, "description");
                            // name has to match '^[a-zA-Z0-9_-]{1,64}$'
                            functionDefinition = new FunctionDefinition(
                                ToolCallBehavior.formFullFunctionName(pluginName, name))
                                .setDescription(description);
                        } else if (elementName.equals("parameter")) {
                            String name = getAttributeValue(event, "name");
                            String description = getAttributeValue(event, "description");
                            // Every parameter is advertised as a JSON-schema "string"; the
                            // element's "type" attribute is currently not used.
                            Map<String, String> property = new LinkedHashMap<>();
                            property.put("type", "string");
                            property.put("description", description);
                            parameters.put(name, property);

                            String isRequired = getAttributeValue(event, "isRequired");
                            // JSON schema requires unique entries in "required".
                            if (Boolean.parseBoolean(isRequired)
                                && !requiredParameters.contains(name)) {
                                requiredParameters.add(name);
                            }
                        }
                    } else if (event.isEndElement()
                        && "function".equals(getElementName(event))) {
                        if (functionDefinition == null) {
                            throw new SKException("Failed to parse function definition");
                        }
                        if (!parameters.isEmpty()) {
                            functionDefinition.setParameters(
                                toParametersSchema(parameters, requiredParameters));
                        }
                        functionDefinitions.add(functionDefinition);
                        functionDefinition = null;
                        parameters.clear();
                        requiredParameters.clear();
                    }
                }
            } finally {
                closeQuietly(reader);
            }
        } catch (IOException | XMLStreamException | IllegalArgumentException e) {
            LOGGER.error("Error parsing prompt", e);
        }
        return functionDefinitions;
    }

    /**
     * Builds the JSON-schema object describing a function's parameters, for example:
     * <pre>
     * {"type": "object",
     *  "properties": {"location": {"type": "string", "description": "The city"}},
     *  "required": ["location"]}
     * </pre>
     * The {@code required} member is omitted when no parameter is required. The schema is
     * assembled as maps and converted by Jackson. Names and descriptions containing quotes,
     * backslashes or control characters are therefore escaped correctly instead of producing
     * invalid JSON.
     *
     * @param parameters parameter name to property schema, in declaration order
     * @param requiredParameters names of the required parameters
     * @return the schema as {@link BinaryData}
     */
    private static BinaryData toParametersSchema(
        Map<String, Map<String, String>> parameters,
        List<String> requiredParameters) {
        Map<String, Object> schema = new LinkedHashMap<>();
        schema.put("type", "object");
        schema.put("properties", parameters);
        if (!requiredParameters.isEmpty()) {
            schema.put("required", requiredParameters);
        }
        JsonNode jsonNode = OBJECT_MAPPER.valueToTree(schema);
        return BinaryData.fromObject(jsonNode);
    }

    /**
     * Creates an event reader over {@code is} with DTD processing and external entity
     * resolution disabled. Prompt text can therefore neither read files or URLs through
     * external entities (XXE) nor declare expanding entities.
     *
     * @param is the XML input
     * @return a new event reader
     * @throws XMLStreamException if the reader cannot be created
     */
    private static XMLEventReader createEventReader(InputStream is) throws XMLStreamException {
        XMLInputFactory factory = XMLInputFactory.newInstance();
        factory.setProperty(XMLInputFactory.SUPPORT_DTD, Boolean.FALSE);
        factory.setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, Boolean.FALSE);
        return factory.createXMLEventReader(is);
    }

    /**
     * Closes the reader. A failure while closing is logged rather than thrown, so it cannot
     * mask an exception raised while parsing.
     */
    private static void closeQuietly(XMLEventReader reader) {
        try {
            reader.close();
        } catch (XMLStreamException e) {
            LOGGER.debug("Failed to close XML event reader", e);
        }
    }

    /**
     * Returns the local (namespace-free) name of a start or end element event, or {@code ""}
     * for any other event type.
     */
    private static String getElementName(XMLEvent xmlEvent) {
        if (xmlEvent.isStartElement()) {
            return xmlEvent.asStartElement().getName().getLocalPart();
        } else if (xmlEvent.isEndElement()) {
            return xmlEvent.asEndElement().getName().getLocalPart();
        }
        // TODO: programmer's error - log at debug
        return "";
    }

    /**
     * Returns the value of the named attribute on a start element event. Returns {@code ""}
     * if the attribute is absent or the event is not a start element; never {@code null}.
     */
    private static String getAttributeValue(XMLEvent xmlEvent, String attributeName) {
        if (xmlEvent.isStartElement()) {
            StartElement element = xmlEvent.asStartElement();
            Attribute attribute = element.getAttributeByName(QName.valueOf(attributeName));
            return attribute != null ? attribute.getValue() : "";
        }
        // TODO: programmer's error - log at debug
        return "";
    }

    /**
     * Creates a chat request message for the given role, matched case-insensitively against
     * {@link AuthorRole}.
     *
     * @param role the {@code role} attribute value of a {@code <message>} element
     * @param content the message text
     * @return the message produced by {@code OpenAIChatCompletion.getChatRequestMessage}
     * @throws SKException if {@code role} is not a known {@link AuthorRole}
     */
    private static ChatRequestMessage getChatRequestMessage(
        String role,
        String content) {
        AuthorRole authorRole;
        try {
            authorRole = AuthorRole.valueOf(role.toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            LOGGER.debug("Unknown author role: {}", role);
            throw new SKException("Unknown author role: " + role, e);
        }
        return OpenAIChatCompletion.getChatRequestMessage(authorRole, content);
    }

    /**
     * Returns a copy of {@code message} whose text content has XML entities unescaped via
     * {@code StringEscapeUtils.unescapeXml}. The other fields are copied as follows:
     * <ul>
     *   <li>user and system messages: name</li>
     *   <li>assistant messages: tool calls, function call and name</li>
     *   <li>function messages: name</li>
     *   <li>tool messages: tool call id</li>
     * </ul>
     *
     * @param message the message to copy
     * @return the unescaped copy
     * @throws SKException if the message type is not one of the five handled types
     */
    public static ChatRequestMessage unescapeRequest(ChatRequestMessage message) {
        if (message instanceof ChatRequestUserMessage) {
            ChatRequestUserMessage chatRequestMessage = (ChatRequestUserMessage) message;
            String content = StringEscapeUtils.unescapeXml(
                chatRequestMessage.getContent().toString());

            return new ChatRequestUserMessage(content)
                .setName(chatRequestMessage.getName());
        } else if (message instanceof ChatRequestSystemMessage) {
            ChatRequestSystemMessage chatRequestMessage = (ChatRequestSystemMessage) message;
            String content = StringEscapeUtils.unescapeXml(chatRequestMessage.getContent());

            return new ChatRequestSystemMessage(content)
                .setName(chatRequestMessage.getName());
        } else if (message instanceof ChatRequestAssistantMessage) {
            ChatRequestAssistantMessage chatRequestMessage = (ChatRequestAssistantMessage) message;
            String content = StringEscapeUtils.unescapeXml(chatRequestMessage.getContent());

            return new ChatRequestAssistantMessage(content)
                .setToolCalls(chatRequestMessage.getToolCalls())
                .setFunctionCall(chatRequestMessage.getFunctionCall())
                .setName(chatRequestMessage.getName());
        } else if (message instanceof ChatRequestFunctionMessage) {
            ChatRequestFunctionMessage chatRequestMessage = (ChatRequestFunctionMessage) message;
            String content = StringEscapeUtils.unescapeXml(chatRequestMessage.getContent());

            return new ChatRequestFunctionMessage(
                chatRequestMessage.getName(),
                content);
        } else if (message instanceof ChatRequestToolMessage) {
            ChatRequestToolMessage chatRequestMessage = (ChatRequestToolMessage) message;
            String content = StringEscapeUtils.unescapeXml(chatRequestMessage.getContent());

            return new ChatRequestToolMessage(
                content,
                chatRequestMessage.getToolCallId());
        }

        throw new SKException("Unknown message type: " + message.getClass().getSimpleName());
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy