All downloads are free. Search and download functionality uses the official Maven repository.

ee.carlrobert.llm.client.azure.AzureClient Maven / Gradle / Ivy

There is a newer version: 0.8.28
Show newest version
package ee.carlrobert.llm.client.azure;

import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import ee.carlrobert.llm.PropertiesLoader;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponse;
import ee.carlrobert.llm.client.openai.imagegen.request.OpenAIImageGenerationRequest;
import ee.carlrobert.llm.client.openai.imagegen.response.OpenAiImageGenerationResponse;
import ee.carlrobert.llm.completion.CompletionEventListener;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import okhttp3.Headers;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;

/**
 * HTTP client for the Azure OpenAI REST endpoints (chat completions and image generation).
 *
 * <p>Instances are created through {@link Builder}. The target URL is derived by formatting the
 * {@code azure.openai.baseUrl} property with the resource name taken from the request params.
 */
public class AzureClient {

  private static final MediaType APPLICATION_JSON = MediaType.parse("application/json");
  /** Format string for the service root; its single placeholder receives the resource name. */
  private static final String BASE_URL = PropertiesLoader.getValue("azure.openai.baseUrl");

  private final OkHttpClient httpClient;
  private final String apiKey;
  private final AzureCompletionRequestParams requestParams;
  private final boolean activeDirectoryAuthentication;
  private final String url;

  private AzureClient(Builder builder, OkHttpClient.Builder httpClientBuilder) {
    this.httpClient = httpClientBuilder.build();
    this.apiKey = builder.apiKey;
    this.requestParams = builder.requestParams;
    this.activeDirectoryAuthentication = builder.activeDirectoryAuthentication;
    this.url = String.format(BASE_URL, requestParams.getResourceName());
  }

  /**
   * Starts a chat completion as a server-sent-event stream.
   *
   * @param request the chat completion request to send
   * @param completionEventListener listener that receives the streamed events
   * @return the event source created for this request
   */
  public EventSource getChatCompletionAsync(
      OpenAIChatCompletionRequest request,
      CompletionEventListener completionEventListener) {
    return EventSources.createFactory(httpClient)
        .newEventSource(buildChatRequest(request), getEventSourceListener(completionEventListener));
  }

  /**
   * Executes a chat completion request synchronously.
   *
   * @param request the chat completion request to send
   * @return the deserialized response
   * @throws RuntimeException wrapping the {@link IOException} if the call fails
   */
  public OpenAIChatCompletionResponse getChatCompletion(OpenAIChatCompletionRequest request) {
    try (var response = httpClient.newCall(buildChatRequest(request)).execute()) {
      return DeserializationUtil.mapResponse(response, OpenAIChatCompletionResponse.class);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Executes an image generation request synchronously, using a client derived from the
   * configured one with 60 second read and call timeouts.
   *
   * @param request the image generation request to send
   * @return the deserialized response
   * @throws RuntimeException wrapping the {@link IOException} if the call fails
   */
  public OpenAiImageGenerationResponse getImage(OpenAIImageGenerationRequest request) {
    try (var response = httpClient.newBuilder()
        .readTimeout(60, TimeUnit.SECONDS)
        .callTimeout(60, TimeUnit.SECONDS)
        .build()
        .newCall(buildImageRequest(request)).execute()) {
      return DeserializationUtil.mapResponse(response, OpenAiImageGenerationResponse.class);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /** Builds the POST request for a chat completion, asking for an event stream when streaming. */
  private Request buildChatRequest(OpenAIChatCompletionRequest completionRequest) {
    var headers = new HashMap<>(getRequiredHeaders());
    if (completionRequest.isStream()) {
      headers.put("Accept", "text/event-stream");
    }
    return buildJsonPostRequest(
        getChatCompletionPath(completionRequest), headers, completionRequest);
  }

  /** Builds the POST request for image generation. */
  private Request buildImageRequest(OpenAIImageGenerationRequest imageRequest) {
    var headers = new HashMap<>(getRequiredHeaders());
    headers.put("Content-Type", "application/json");
    return buildJsonPostRequest(getImageGenerationPath(imageRequest), headers, imageRequest);
  }

  /**
   * Serializes {@code payload} to JSON and wraps it in a POST request against {@code url + path}.
   *
   * @param path request path (including query string) appended to the service root
   * @param headers headers to send with the request
   * @param payload object serialized as the JSON request body
   * @return the assembled request
   * @throws RuntimeException if the payload cannot be serialized
   */
  private Request buildJsonPostRequest(String path, Map<String, String> headers, Object payload) {
    try {
      return new Request.Builder()
          .url(url + path)
          .headers(Headers.of(headers))
          .post(RequestBody.create(
              // NOTE(review): this reconfigures the shared OBJECT_MAPPER on every call; kept
              // as-is because other users of the mapper may rely on it — confirm and consider
              // a dedicated pre-configured writer.
              OBJECT_MAPPER
                  .setSerializationInclusion(JsonInclude.Include.NON_NULL)
                  .writeValueAsString(payload),
              APPLICATION_JSON))
          .build();
    } catch (JsonProcessingException e) {
      throw new RuntimeException("Unable to process request", e);
    }
  }

  /**
   * Returns a fresh, mutable map with the headers every request carries: the application tag and
   * either a bearer token (Active Directory authentication) or the {@code api-key} header.
   */
  private Map<String, String> getRequiredHeaders() {
    var headers = new HashMap<String, String>();
    headers.put("X-LLM-Application-Tag", "codegpt");
    if (activeDirectoryAuthentication) {
      headers.put("Authorization", "Bearer " + apiKey);
    } else {
      headers.put("api-key", apiKey);
    }
    return headers;
  }

  /**
   * Resolves the chat completion path. The request's overridden path, when present, is used as
   * the format string in place of the default and receives the same two arguments
   * (deployment id, API version).
   */
  private String getChatCompletionPath(OpenAIChatCompletionRequest request) {
    return String.format(
        request.getOverriddenPath() == null
            ? "/openai/deployments/%s/chat/completions?api-version=%s"
            : request.getOverriddenPath(),
        requestParams.getDeploymentId(),
        requestParams.getApiVersion());
  }

  /**
   * Resolves the image generation path. The request's overridden path, when present, is used as
   * the format string in place of the default and receives the same two arguments
   * (request model, API version).
   */
  private String getImageGenerationPath(OpenAIImageGenerationRequest request) {
    return String.format(
        request.getOverriddenPath() == null
            ? "/openai/deployments/%s/images/generations?api-version=%s"
            : request.getOverriddenPath(),
        request.getModel(),
        requestParams.getApiVersion());
  }

  /** Wraps the listener so that error payloads are parsed as {@code AzureApiResponseError}. */
  private OpenAIChatCompletionEventSourceListener getEventSourceListener(
      CompletionEventListener listener) {
    return new OpenAIChatCompletionEventSourceListener(listener) {
      @Override
      protected ErrorDetails getErrorDetails(String data) throws JsonProcessingException {
        return OBJECT_MAPPER.readValue(data, AzureApiResponseError.class).getError();
      }
    };
  }

  /** Builder for {@link AzureClient}. */
  public static class Builder {

    private final String apiKey;
    private final AzureCompletionRequestParams requestParams;
    private boolean activeDirectoryAuthentication;

    /**
     * @param apiKey credential sent as the {@code api-key} header, or as a bearer token when
     *     Active Directory authentication is enabled
     * @param requestParams supplies the resource name, deployment id and API version
     */
    public Builder(String apiKey, AzureCompletionRequestParams requestParams) {
      this.apiKey = apiKey;
      this.requestParams = requestParams;
    }

    /**
     * Selects how the credential is sent: {@code true} for an {@code Authorization: Bearer}
     * header, {@code false} (the default) for the {@code api-key} header.
     *
     * @return this builder
     */
    public Builder setActiveDirectoryAuthentication(boolean activeDirectoryAuthentication) {
      this.activeDirectoryAuthentication = activeDirectoryAuthentication;
      return this;
    }

    /**
     * Creates the client on top of a caller-supplied HTTP client builder.
     *
     * @param builder OkHttp builder used to create the underlying HTTP client
     * @return the configured client
     */
    public AzureClient build(OkHttpClient.Builder builder) {
      return new AzureClient(this, builder);
    }

    /**
     * Creates the client on top of a default {@link OkHttpClient.Builder}.
     *
     * @return the configured client
     */
    public AzureClient build() {
      return build(new OkHttpClient.Builder());
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy