All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.github.microcks.util.ai.OpenAICopilot Maven / Gradle / Ivy

The newest version!
/*
 * Copyright The Microcks Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.github.microcks.util.ai;

import io.github.microcks.domain.Exchange;
import io.github.microcks.domain.Operation;
import io.github.microcks.domain.Resource;
import io.github.microcks.domain.Service;
import io.github.microcks.domain.ServiceType;
import io.github.microcks.util.DispatchStyles;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
import org.springframework.web.client.RestTemplate;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * This is an implementation of {@code AICopilot} using OpenAI API.
 * @author laurent
 */
public class OpenAICopilot implements AICopilot {

   /** A simple logger for diagnostic messages. */
   private static final Logger log = LoggerFactory.getLogger(OpenAICopilot.class);


   /** Configuration parameter holding the OpenAI API key. */
   public static final String API_KEY_CONFIG = "api-key";

   /** Configuration parameter holding the OpenAI API URL. */
   public static final String API_URL_CONFIG = "api-url";

   /** Configuration parameters holding the timeout in seconds for API calls. */
   public static final String TIMEOUT_KEY_CONFIG = "timeout";

   /** Configuration parameter holding the name of model to use. */
   public static final String MODEL_KEY_CONFIG = "model";

   /** Configuration parameter holding the maximum number of tokens to use. */
   public static final String MAX_TOKENS_KEY_CONFIG = "maxTokens";

   /** The mandatory configuration keys required by this implementation. */
   protected static final String[] MANDATORY_CONFIG_KEYS = { API_KEY_CONFIG };


   /** Default online URL for OpenAI API. */
   private static final String OPENAI_BASE_URL = "https://api.openai.com/";

   /** Path of the chat completions endpoint, appended to the (slash-stripped) base URL. */
   private static final String CHAT_COMPLETIONS_PATH = "/v1/chat/completions";

   /** Default timeout (in seconds) applied to connection and read when none is configured. */
   private static final int DEFAULT_TIMEOUT_SECONDS = 20;

   /** Default model used when none is configured. */
   private static final String DEFAULT_MODEL = "gpt-3.5-turbo";

   /** Default maximum number of tokens used when none is configured. */
   private static final int DEFAULT_MAX_TOKENS = 2000;

   /** Delimiter separating the instructions from the specification content in prompts. */
   private static final String SECTION_DELIMITER = "\n###\n";

   private final RestTemplate restTemplate;

   /** Base URL of the API, without any trailing slash. */
   private final String apiUrl;

   private final String apiKey;

   private final Duration timeout;

   private final String model;

   private final int maxTokens;


   /**
    * Build a new OpenAICopilot with its configuration.
    * <p>
    * Invalid or non-positive {@code timeout} / {@code maxTokens} values are logged and replaced by the defaults.
    * Missing or blank {@code model} / {@code api-url} values fall back to the defaults as well.
    * @param configuration The configuration for connecting to OpenAI services.
    */
   public OpenAICopilot(Map<String, String> configuration) {
      timeout = Duration.ofSeconds(
            readPositiveInt(configuration, TIMEOUT_KEY_CONFIG, "Timeout", DEFAULT_TIMEOUT_SECONDS));
      maxTokens = readPositiveInt(configuration, MAX_TOKENS_KEY_CONFIG, "MaxTokens", DEFAULT_MAX_TOKENS);
      model = readString(configuration, MODEL_KEY_CONFIG, DEFAULT_MODEL);
      // Strip trailing slashes so that appending CHAT_COMPLETIONS_PATH never produces '//'.
      apiUrl = stripTrailingSlashes(readString(configuration, API_URL_CONFIG, OPENAI_BASE_URL));

      // Finally retrieve the OpenAI Api key (never logged: it is a secret).
      apiKey = configuration.get(API_KEY_CONFIG);
      if (apiKey == null || apiKey.trim().isEmpty()) {
         log.warn("No OpenAI API key was provided, calls to OpenAI API will likely be rejected.");
      }

      // Initialize a Rest template for interacting with OpenAI API.
      // We need to register a custom Jackson converter to handle serialization of name and function_call of messages.
      restTemplate = new RestTemplateBuilder().setConnectTimeout(timeout).setReadTimeout(timeout)
            .additionalMessageConverters(mappingJacksonHttpMessageConverter()).build();
   }

   /**
    * Get mandatory configuration parameters.
    * @return A copy of the mandatory configuration keys required by this implementation
    */
   public static final String[] getMandatoryConfigKeys() {
      // Return a copy so callers cannot mutate the shared constant array.
      return MANDATORY_CONFIG_KEYS.clone();
   }

   @Override
   public List<? extends Exchange> suggestSampleExchanges(Service service, Operation operation, Resource contract,
         int number) throws Exception {
      String prompt = preparePrompt(service, operation, contract, number);
      if (prompt == null) {
         // Do not waste an API call on an empty prompt for unsupported service types.
         log.warn("Service type {} is not supported for samples suggestion. Returning no sample.", service.getType());
         return new ArrayList<>();
      }

      log.debug("Asking OpenAI to suggest samples for this prompt: {}", prompt);

      // The prompt is an instruction to the model, hence a 'user' message. An 'assistant' message
      // would represent a previous answer produced by the model itself.
      final List<ChatMessage> messages = new ArrayList<>();
      messages.add(new ChatMessage(ChatMessageRole.USER.value(), prompt));

      ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder().model(model).messages(messages).n(1)
            .maxTokens(maxTokens).logitBias(new HashMap<>()).build();

      // Build a full HttpEntity as we need to specify authentication headers.
      HttpEntity<ChatCompletionRequest> request = new HttpEntity<>(chatCompletionRequest,
            createAuthenticationHeaders());
      ChatCompletionResult completionResult = restTemplate
            .exchange(apiUrl + CHAT_COMPLETIONS_PATH, HttpMethod.POST, request, ChatCompletionResult.class).getBody();

      String content = extractFirstChoiceContent(completionResult);
      if (content == null) {
         log.warn("OpenAI did not return any usable completion choice. Returning no sample.");
         return new ArrayList<>();
      }
      log.debug("Got this raw output from OpenAI: {}", content);

      if (service.getType() == ServiceType.EVENT) {
         return AICopilotHelper.parseUnidirectionalEventTemplateOutput(content);
      }
      return AICopilotHelper.parseRequestResponseTemplateOutput(service, operation, content);
   }

   /**
    * Build the prompt matching the service type.
    * @return The prompt, or {@code null} if the service type is not supported.
    */
   private static String preparePrompt(Service service, Operation operation, Resource contract, int number)
         throws Exception {
      ServiceType type = service.getType();
      if (type == ServiceType.REST) {
         return preparePromptForOpenAPI(operation, contract, number);
      } else if (type == ServiceType.GRAPHQL) {
         return preparePromptForGraphQL(operation, contract, number);
      } else if (type == ServiceType.EVENT) {
         return preparePromptForAsyncAPI(operation, contract, number);
      } else if (type == ServiceType.GRPC) {
         return preparePromptForGrpc(service, operation, contract, number);
      }
      return null;
   }

   /**
    * Safely extract the content of the first completion choice.
    * @return The message content, or {@code null} if the result has no choice or no message.
    */
   private static String extractFirstChoiceContent(ChatCompletionResult completionResult) {
      if (completionResult == null || completionResult.getChoices() == null
            || completionResult.getChoices().isEmpty()) {
         return null;
      }
      ChatCompletionChoice choice = completionResult.getChoices().get(0);
      if (choice == null || choice.getMessage() == null) {
         return null;
      }
      return choice.getMessage().getContent();
   }

   private static String preparePromptForOpenAPI(Operation operation, Resource contract, int number)
         throws Exception {
      StringBuilder prompt = new StringBuilder(
            AICopilotHelper.getOpenAPIOperationPromptIntro(operation.getName(), number));

      appendFormattingAndSpec(prompt, AICopilotHelper.getRequestResponseExampleYamlFormattingDirective(1),
            AICopilotHelper.removeTokensFromSpec(contract.getContent(), operation.getName()));
      return prompt.toString();
   }

   private static String preparePromptForGraphQL(Operation operation, Resource contract, int number) {
      StringBuilder prompt = new StringBuilder(
            AICopilotHelper.getGraphQLOperationPromptIntro(operation.getName(), number));

      // We need to indicate the name of the variables the QUERY_ARGS dispatcher will look for.
      String rules = operation.getDispatcherRules();
      if (DispatchStyles.QUERY_ARGS.equals(operation.getDispatcher()) && rules != null) {
         StringBuilder variablesList = new StringBuilder();
         // Rules are '&&'-separated variable names; a single name simply yields one element.
         for (String variable : rules.split("&&")) {
            String variableName = variable.trim();
            if (variableName.isEmpty()) {
               continue;
            }
            if (variablesList.length() > 0) {
               variablesList.append(", ");
            }
            variablesList.append('$').append(variableName);
         }
         if (variablesList.length() > 0) {
            prompt.append("Use only '").append(variablesList).append("' as variable identifiers.");
         }
      }

      appendFormattingAndSpec(prompt, AICopilotHelper.getRequestResponseExampleYamlFormattingDirective(1),
            contract.getContent());
      return prompt.toString();
   }

   private static String preparePromptForAsyncAPI(Operation operation, Resource contract, int number)
         throws Exception {
      StringBuilder prompt = new StringBuilder(
            AICopilotHelper.getAsyncAPIOperationPromptIntro(operation.getName(), number));

      appendFormattingAndSpec(prompt, AICopilotHelper.getUnidirectionalEventExampleYamlFormattingDirective(1),
            AICopilotHelper.removeTokensFromSpec(contract.getContent(), operation.getName()));
      return prompt.toString();
   }

   private static String preparePromptForGrpc(Service service, Operation operation, Resource contract, int number)
         throws Exception {
      StringBuilder prompt = new StringBuilder(
            AICopilotHelper.getGrpcOperationPromptIntro(service.getName(), operation.getName(), number));

      appendFormattingAndSpec(prompt, AICopilotHelper.getGrpcRequestResponseExampleYamlFormattingDirective(1),
            contract.getContent());
      return prompt.toString();
   }

   /**
    * Append the shared tail of every prompt: YAML formatting instructions, the example directive,
    * the section delimiter and finally the specification content.
    */
   private static void appendFormattingAndSpec(StringBuilder prompt, String exampleDirective, String specContent) {
      prompt.append("\n");
      prompt.append(AICopilotHelper.YAML_FORMATTING_PROMPT);
      prompt.append("\n");
      prompt.append(exampleDirective);
      prompt.append(SECTION_DELIMITER);
      prompt.append(specContent);
   }

   /**
    * Read a strictly positive integer from configuration.
    * @return The parsed value, or {@code defaultValue} if the key is absent, unparsable or not positive.
    */
   private static int readPositiveInt(Map<String, String> configuration, String key, String label,
         int defaultValue) {
      if (!configuration.containsKey(key)) {
         return defaultValue;
      }
      String rawValue = configuration.get(key);
      if (rawValue != null) {
         try {
            int value = Integer.parseInt(rawValue.trim());
            if (value > 0) {
               return value;
            }
         } catch (NumberFormatException e) {
            // Handled by the warning below.
         }
      }
      log.warn("{} was provided but cannot be parsed as a positive integer: '{}'. Sticking to the default.", label,
            rawValue);
      return defaultValue;
   }

   /**
    * Read a trimmed string from configuration.
    * @return The trimmed value, or {@code defaultValue} if the key is absent or blank.
    */
   private static String readString(Map<String, String> configuration, String key, String defaultValue) {
      String value = configuration.get(key);
      if (value == null || value.trim().isEmpty()) {
         return defaultValue;
      }
      return value.trim();
   }

   /** Remove any trailing '/' characters from a base URL. */
   private static String stripTrailingSlashes(String url) {
      int end = url.length();
      while (end > 0 && url.charAt(end - 1) == '/') {
         end--;
      }
      return url.substring(0, end);
   }

   private static MappingJackson2HttpMessageConverter mappingJacksonHttpMessageConverter() {
      MappingJackson2HttpMessageConverter converter = new MappingJackson2HttpMessageConverter();
      converter.setObjectMapper(customObjectMapper());
      return converter;
   }

   private static ObjectMapper customObjectMapper() {
      ObjectMapper mapper = new ObjectMapper();
      mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
      mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
      mapper.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE);
      mapper.addMixIn(ChatCompletionRequest.class, ChatCompletionRequestMixIn.class);
      return mapper;
   }

   private HttpHeaders createAuthenticationHeaders() {
      HttpHeaders headers = new HttpHeaders();
      // Equivalent to setting "Authorization: Bearer <apiKey>".
      headers.setBearerAuth(apiKey);
      return headers;
   }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy