All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ee.carlrobert.llm.client.llama.LlamaClient Maven / Gradle / Ivy

There is a newer version: 0.8.28
Show newest version
package ee.carlrobert.llm.client.llama;

import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;
import static java.lang.String.format;

import com.fasterxml.jackson.core.JacksonException;
import com.fasterxml.jackson.core.JsonProcessingException;
import ee.carlrobert.llm.PropertiesLoader;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionResponse;
import ee.carlrobert.llm.client.llama.completion.LlamaInfillRequest;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.completion.CompletionEventListener;
import ee.carlrobert.llm.completion.CompletionEventSourceListener;
import java.io.IOException;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;

public class LlamaClient {

  private static final String BASE_URL = PropertiesLoader.getValue("llama.baseUrl");
  private static final MediaType APPLICATION_JSON = MediaType.parse("application/json");

  private final OkHttpClient httpClient;
  private final String host;
  private final Integer port;
  private final String apiKey;

  protected LlamaClient(Builder builder, OkHttpClient.Builder httpClientBuilder) {
    this.httpClient = httpClientBuilder.build();
    this.host = builder.host;
    this.port = builder.port;
    this.apiKey = builder.apiKey;
  }

  public EventSource getChatCompletionAsync(
      LlamaCompletionRequest request,
      CompletionEventListener eventListener) {
    return EventSources.createFactory(httpClient)
        .newEventSource(buildCompletionHttpRequest(request), getEventSourceListener(eventListener));
  }

  public LlamaCompletionResponse getChatCompletion(LlamaCompletionRequest request) {
    try (var response = httpClient.newCall(buildCompletionHttpRequest(request)).execute()) {
      return DeserializationUtil.mapResponse(response, LlamaCompletionResponse.class);
    } catch (IOException e) {
      throw new RuntimeException(
          "Could not get llama completion for the given request:\n" + request, e);
    }
  }

  public LlamaCompletionResponse getInfill(LlamaInfillRequest request) {
    try (var response = httpClient.newCall(buildHttpRequest(request, "/infill")).execute()) {
      return DeserializationUtil.mapResponse(response, LlamaCompletionResponse.class);
    } catch (IOException e) {
      throw new RuntimeException(
          "Could not get llama completion for the given request:\n" + request, e);
    }
  }

  public EventSource getInfillAsync(
      LlamaInfillRequest request,
      CompletionEventListener eventListener) {
    return EventSources.createFactory(httpClient).newEventSource(
        buildHttpRequest(request, "/infill"),
        getEventSourceListener(eventListener));
  }

  private Request buildCompletionHttpRequest(LlamaCompletionRequest request) {
    return buildHttpRequest(request, "/completion");
  }

  private Request buildHttpRequest(LlamaCompletionRequest request, String path) {
    try {
      var baseHost = port == null ? BASE_URL : format("http://localhost:%d", port);
      Request.Builder builder = new Request.Builder()
          .url((host == null ? baseHost : host) + path)
          .header("Cache-Control", "no-cache")
          .header("Content-Type", "application/json")
          .header("Accept", request.isStream() ? "text/event-stream" : "text/json")
          .post(RequestBody.create(OBJECT_MAPPER.writeValueAsString(request), APPLICATION_JSON));
      if (apiKey != null) {
        builder.header("Authorization", "Bearer " + apiKey);
      }
      return builder.build();
    } catch (JsonProcessingException e) {
      throw new RuntimeException(e);
    }
  }

  private CompletionEventSourceListener getEventSourceListener(
      CompletionEventListener eventListener) {
    return new CompletionEventSourceListener(eventListener) {
      @Override
      protected String getMessage(String data) {
        try {
          var response = OBJECT_MAPPER.readValue(data, LlamaCompletionResponse.class);
          return response.getContent();
        } catch (JacksonException e) {
          // ignore
        }
        return "";
      }

      @Override
      protected ErrorDetails getErrorDetails(String error) {
        return new ErrorDetails(error);
      }
    };
  }

  public static class Builder {

    private String host;
    private Integer port;
    private String apiKey;

    public Builder setHost(String host) {
      this.host = host;
      return this;
    }

    public Builder setPort(Integer port) {
      this.port = port;
      return this;
    }

    public Builder setApiKey(String apiKey) {
      this.apiKey = apiKey;
      return this;
    }

    public LlamaClient build(OkHttpClient.Builder builder) {
      return new LlamaClient(this, builder);
    }

    public LlamaClient build() {
      return build(new OkHttpClient.Builder());
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy