All Downloads are FREE. Search and download functionality uses the official Maven repository.

ee.carlrobert.llm.client.google.GoogleClient Maven / Gradle / Ivy

There is a newer version: 0.8.28
Show newest version
package ee.carlrobert.llm.client.google;

import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;
import static java.lang.String.format;
import static java.util.stream.Collectors.toList;

import com.fasterxml.jackson.core.JacksonException;
import com.fasterxml.jackson.core.JsonProcessingException;
import ee.carlrobert.llm.PropertiesLoader;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.google.completion.ApiResponseError;
import ee.carlrobert.llm.client.google.completion.GoogleCompletionContent;
import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest;
import ee.carlrobert.llm.client.google.completion.GoogleCompletionResponse;
import ee.carlrobert.llm.client.google.completion.GoogleCompletionResponse.Candidate;
import ee.carlrobert.llm.client.google.completion.GoogleContentPart;
import ee.carlrobert.llm.client.google.embedding.ContentEmbedding;
import ee.carlrobert.llm.client.google.embedding.GoogleBatchEmbeddingResponse;
import ee.carlrobert.llm.client.google.embedding.GoogleEmbeddingContentRequest;
import ee.carlrobert.llm.client.google.embedding.GoogleEmbeddingRequest;
import ee.carlrobert.llm.client.google.embedding.GoogleEmbeddingResponse;
import ee.carlrobert.llm.client.google.models.GoogleModel;
import ee.carlrobert.llm.client.google.models.GoogleModelsResponse;
import ee.carlrobert.llm.client.google.models.GoogleModelsResponse.GeminiModelDetails;
import ee.carlrobert.llm.client.google.models.GoogleTokensResponse;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.completion.CompletionEventListener;
import ee.carlrobert.llm.completion.CompletionEventSourceListener;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;
import okhttp3.HttpUrl;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;

public class GoogleClient {

  private static final MediaType APPLICATION_JSON = MediaType.parse("application/json");

  private final OkHttpClient httpClient;
  private final String host;
  private final String apiKey;

  protected GoogleClient(Builder builder, OkHttpClient.Builder httpClientBuilder) {
    this.httpClient = httpClientBuilder.build();
    this.host = builder.host;
    this.apiKey = builder.apiKey;
  }

  public EventSource getChatCompletionAsync(
      GoogleCompletionRequest request,
      GoogleModel model,
      CompletionEventListener eventListener) {
    return getChatCompletionAsync(request, model.getCode(), eventListener);
  }

  public EventSource getChatCompletionAsync(
      GoogleCompletionRequest request,
      String model,
      CompletionEventListener eventListener) {
    return EventSources.createFactory(httpClient)
        .newEventSource(buildPostRequest(request, model, "streamGenerateContent", true),
            getEventSourceListener(eventListener));
  }

  /**
   * GenerateContent.
   */
  public GoogleCompletionResponse getChatCompletion(GoogleCompletionRequest request,
      GoogleModel model) {
    return getChatCompletion(request, model.getCode());
  }

  /**
   * GenerateContent.
   */
  public GoogleCompletionResponse getChatCompletion(GoogleCompletionRequest request, String model) {
    try (var response = httpClient.newCall(
        buildPostRequest(request, model, "generateContent", false)).execute()) {
      return DeserializationUtil.mapResponse(response, GoogleCompletionResponse.class);
    } catch (IOException e) {
      throw new RuntimeException(
          "Could not get llama completion for the given request:\n" + request, e);
    }
  }

  public double[] getEmbedding(String text, GoogleModel model) {
    return getEmbedding(List.of(text), model.getCode());
  }

  public double[] getEmbedding(String text, String model) {
    return getEmbedding(List.of(text), model);
  }

  public double[] getEmbedding(List texts, GoogleModel model) {
    return getEmbedding(texts, model.getCode());
  }

  public double[] getEmbedding(List texts, String model) {
    return getEmbedding(new GoogleEmbeddingRequest.Builder(new GoogleCompletionContent(texts))
        .build(), model);
  }

  /**
   * EmbedContent.
   */
  public double[] getEmbedding(GoogleEmbeddingRequest request, GoogleModel model) {
    return getEmbedding(request, model.getCode());
  }

  /**
   * EmbedContent.
   */
  public double[] getEmbedding(GoogleEmbeddingRequest request, String model) {
    try (var response = httpClient
        .newCall(buildPostRequest(request, model, "embedContent", false))
        .execute()) {

      return Optional.ofNullable(
              DeserializationUtil.mapResponse(response, GoogleEmbeddingResponse.class))
          .map(GoogleEmbeddingResponse::getEmbedding)
          .map(ContentEmbedding::getValues)
          .orElse(null);

    } catch (IOException e) {
      throw new RuntimeException("Unable to fetch embedding", e);
    }
  }

  /**
   * BatchEmbedContents.
   */
  public List getBatchEmbeddings(
      List requests,
      GoogleModel model) {
    return getBatchEmbeddings(requests, model.getCode());
  }

  public List getBatchEmbeddings(
      List requests,
      String model) {
    try (var response = httpClient
        .newCall(buildPostRequest(Map.of("requests", requests), model, "batchEmbedContents", false))
        .execute()) {

      var embeddings = Optional.ofNullable(
              DeserializationUtil.mapResponse(response, GoogleBatchEmbeddingResponse.class))
          .map(GoogleBatchEmbeddingResponse::getEmbeddings)
          .stream()
          .flatMap(Collection::stream)
          .filter(Objects::nonNull)
          .map(ContentEmbedding::getValues)
          .filter(Objects::nonNull)
          .collect(toList());
      return embeddings.isEmpty() ? null : embeddings;

    } catch (IOException e) {
      throw new RuntimeException("Unable to fetch embedding", e);
    }
  }


  /**
   * Models List.
   */
  public GoogleModelsResponse getModels(Integer pageSize, String pageToken) {
    String url = host + "/v1/models";
    HttpUrl.Builder urlBuilder = HttpUrl.parse(url).newBuilder();
    if (pageSize != null) {
      urlBuilder.addQueryParameter("pageSize", pageSize.toString());
    }
    if (pageToken != null) {
      urlBuilder.addQueryParameter("pageToken", pageToken);
    }
    try (var response = httpClient
        .newCall(defaultRequestBuilder(urlBuilder, false).get().build())
        .execute()) {
      return DeserializationUtil.mapResponse(response, GoogleModelsResponse.class);

    } catch (IOException e) {
      throw new RuntimeException("Unable to fetch models", e);
    }
  }

  /**
   * Get Model.
   */
  public GeminiModelDetails getModel(String name) {
    String url = host + "/v1/models/" + name;
    try (var response = httpClient.newCall(defaultRequestBuilder(url, false).get().build())
        .execute()) {
      return DeserializationUtil.mapResponse(response, GeminiModelDetails.class);
    } catch (IOException e) {
      throw new RuntimeException("Unable to fetch model", e);
    }
  }

  /**
   * CountTokens.
   */
  public GoogleTokensResponse getCountTokens(List contents,
      GoogleModel model) {
    return getCountTokens(contents, model.getCode());
  }

  /**
   * CountTokens.
   */
  public GoogleTokensResponse getCountTokens(List contents, String model) {
    try (var response = httpClient
        .newCall(buildPostRequest(Map.of("contents", contents), model, "countTokens", false))
        .execute()) {
      return DeserializationUtil.mapResponse(response, GoogleTokensResponse.class);
    } catch (IOException e) {
      throw new RuntimeException("Unable to fetch tokens count", e);
    }
  }

  private Request buildPostRequest(Object request, String model, String path,
      boolean stream) {
    try {
      Request.Builder builder = defaultRequestBuilder(
          host + format("/v1/models/%s:%s", model, path), stream)
          .post(RequestBody.create(OBJECT_MAPPER.writeValueAsString(request), APPLICATION_JSON));
      return builder.build();
    } catch (JsonProcessingException e) {
      throw new RuntimeException(e);
    }
  }

  private Request.Builder defaultRequestBuilder(String url, boolean stream) {
    return defaultRequestBuilder(HttpUrl.parse(url).newBuilder(), stream);
  }

  private Request.Builder defaultRequestBuilder(HttpUrl.Builder url, boolean stream) {
    if (apiKey != null && !apiKey.isEmpty()) {
      url.addQueryParameter("key", apiKey);
    }
    // see https://ai.google.dev/gemini-api/docs/get-started/rest#stream_generate_content
    if (stream) {
      url.addQueryParameter("alt", "sse");
    }
    return new Request.Builder()
        .url(url.build())
        .header("Cache-Control", "no-cache")
        .header("Content-Type", "application/json")
        .header("Accept", stream ? "text/event-stream" : "text/json");
  }

  private CompletionEventSourceListener getEventSourceListener(
      CompletionEventListener eventListener) {
    return new CompletionEventSourceListener<>(eventListener) {
      @Override
      protected String getMessage(String data) {
        try {
          var candidates = OBJECT_MAPPER.readValue(data, GoogleCompletionResponse.class)
              .getCandidates();
          return (candidates == null ? Stream.empty() : candidates.stream())
              .filter(Objects::nonNull)
              .flatMap(candidate -> candidate.getContent().getParts().stream())
              .filter(Objects::nonNull)
              .findFirst()
              .map(GoogleContentPart::getText)
              .orElse("");
        } catch (JacksonException e) {
          // ignore
          System.out.println();
        }
        return "";
      }

      @Override
      protected ErrorDetails getErrorDetails(String data) throws JsonProcessingException {
        var googleError = OBJECT_MAPPER.readValue(data, ApiResponseError.class).getError();
        return googleError == null ? null
            : new ErrorDetails(googleError.getMessage(), googleError.getStatus(), null,
                googleError.getCode());
      }
    };
  }

  public static class Builder {

    private String host = PropertiesLoader.getValue("google.baseUrl");
    private String apiKey;

    public Builder(String apiKey) {
      this.apiKey = apiKey;
    }

    public Builder setHost(String host) {
      this.host = host;
      return this;
    }

    public Builder setApiKey(String apiKey) {
      this.apiKey = apiKey;
      return this;
    }

    public GoogleClient build(OkHttpClient.Builder builder) {
      return new GoogleClient(this, builder);
    }

    public GoogleClient build() {
      return build(new OkHttpClient.Builder());
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy