// All downloads are free. Search and download functionality uses the official Maven repository.
//
// dev.langchain4j.model.ollama.OllamaClient Maven / Gradle / Ivy
//
// There is a newer version: 1.0.0-alpha1
// Show newest version
package dev.langchain4j.model.ollama;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;
import okhttp3.OkHttpClient;
import okhttp3.ResponseBody;
import retrofit2.Call;
import retrofit2.Callback;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.time.Duration;

import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES;
import static java.lang.Boolean.TRUE;

/**
 * Thin HTTP client for the Ollama REST API, built on Retrofit and OkHttp.
 * <p>
 * Blocking methods return the deserialized response body, or throw a {@link RuntimeException}
 * when the server answers with a non-2xx status or the call fails with an {@link IOException}.
 * Streaming methods never throw from the callback thread; every failure is routed to
 * {@link StreamingResponseHandler#onError(Throwable)}.
 */
class OllamaClient {

    private static final Gson GSON = new GsonBuilder()
            .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
            .create();

    private final OllamaApi ollamaApi;

    /**
     * Creates a client bound to one server.
     *
     * @param baseUrl root URL of the server; a trailing {@code /} is appended when missing
     * @param timeout value applied to the call, connect, read and write timeouts alike
     */
    @Builder
    public OllamaClient(String baseUrl, Duration timeout) {

        OkHttpClient okHttpClient = new OkHttpClient.Builder()
                .callTimeout(timeout)
                .connectTimeout(timeout)
                .readTimeout(timeout)
                .writeTimeout(timeout)
                .build();

        Retrofit retrofit = new Retrofit.Builder()
                // Retrofit rejects a base URL whose path does not end in "/".
                .baseUrl(baseUrl.endsWith("/") ? baseUrl : baseUrl + "/")
                .client(okHttpClient)
                .addConverterFactory(GsonConverterFactory.create(GSON))
                .build();

        ollamaApi = retrofit.create(OllamaApi.class);
    }

    /**
     * Sends a completion request and waits for the response.
     *
     * @param request the completion request
     * @return the deserialized response body
     * @throws RuntimeException on a non-2xx status or an I/O failure
     */
    public CompletionResponse completion(CompletionRequest request) {
        return execute(ollamaApi.completion(request));
    }

    /**
     * Sends a chat request and waits for the response.
     *
     * @param request the chat request
     * @return the deserialized response body
     * @throws RuntimeException on a non-2xx status or an I/O failure
     */
    public ChatResponse chat(ChatRequest request) {
        return execute(ollamaApi.chat(request));
    }

    /**
     * Sends a completion request asynchronously and streams the answer to {@code handler}.
     * <p>
     * The response body is read line by line, each non-blank line being parsed as one
     * {@link CompletionResponse}. Every partial text is passed to {@code onNext}; when a line
     * reports {@code done}, {@code onComplete} receives the concatenated text together with a
     * {@link TokenUsage} built from that line's prompt-eval and eval counts.
     *
     * @param request the completion request
     * @param handler receives partial responses, the final response, or the error
     */
    public void streamingCompletion(CompletionRequest request, StreamingResponseHandler<String> handler) {
        ollamaApi.streamingCompletion(request).enqueue(new Callback<ResponseBody>() {

            @Override
            public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
                try {
                    // body() is null for non-2xx responses, so bail out before touching it.
                    if (!retrofitResponse.isSuccessful()) {
                        handler.onError(toException(retrofitResponse));
                        return;
                    }

                    try (BufferedReader reader = lineReader(retrofitResponse.body())) {
                        StringBuilder contentBuilder = new StringBuilder();
                        String line;
                        while ((line = reader.readLine()) != null) {
                            if (line.trim().isEmpty()) {
                                continue;
                            }
                            CompletionResponse completionResponse = GSON.fromJson(line, CompletionResponse.class);

                            contentBuilder.append(completionResponse.getResponse());
                            handler.onNext(completionResponse.getResponse());

                            if (TRUE.equals(completionResponse.getDone())) {
                                Response<String> response = Response.from(
                                        contentBuilder.toString(),
                                        new TokenUsage(
                                                completionResponse.getPromptEvalCount(),
                                                completionResponse.getEvalCount()
                                        )
                                );
                                handler.onComplete(response);
                                return;
                            }
                        }
                        throw new IOException("stream ended before a response with done=true was received");
                    }
                } catch (Exception e) {
                    // Callback boundary: this runs on an OkHttp thread, so report instead of rethrowing.
                    handler.onError(e);
                }
            }

            @Override
            public void onFailure(Call<ResponseBody> call, Throwable throwable) {
                handler.onError(throwable);
            }
        });
    }

    /**
     * Sends a chat request asynchronously and streams the answer to {@code handler}.
     * <p>
     * The response body is read line by line, each non-blank line being parsed as one
     * {@link ChatResponse}. The content of every partial message is passed to {@code onNext};
     * when a line reports {@code done}, {@code onComplete} receives an {@link AiMessage} holding
     * the concatenated content together with a {@link TokenUsage} built from that line's
     * prompt-eval and eval counts.
     *
     * @param request the chat request
     * @param handler receives partial responses, the final response, or the error
     */
    public void streamingChat(ChatRequest request, StreamingResponseHandler<AiMessage> handler) {
        ollamaApi.streamingChat(request).enqueue(new Callback<ResponseBody>() {

            @Override
            public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
                try {
                    // body() is null for non-2xx responses, so bail out before touching it.
                    if (!retrofitResponse.isSuccessful()) {
                        handler.onError(toException(retrofitResponse));
                        return;
                    }

                    try (BufferedReader reader = lineReader(retrofitResponse.body())) {
                        StringBuilder contentBuilder = new StringBuilder();
                        String line;
                        while ((line = reader.readLine()) != null) {
                            if (line.trim().isEmpty()) {
                                continue;
                            }
                            ChatResponse chatResponse = GSON.fromJson(line, ChatResponse.class);

                            String content = chatResponse.getMessage().getContent();
                            contentBuilder.append(content);
                            handler.onNext(content);

                            if (TRUE.equals(chatResponse.getDone())) {
                                Response<AiMessage> response = Response.from(
                                        AiMessage.from(contentBuilder.toString()),
                                        new TokenUsage(
                                                chatResponse.getPromptEvalCount(),
                                                chatResponse.getEvalCount()
                                        )
                                );
                                handler.onComplete(response);
                                return;
                            }
                        }
                        throw new IOException("stream ended before a response with done=true was received");
                    }
                } catch (Exception e) {
                    // Callback boundary: this runs on an OkHttp thread, so report instead of rethrowing.
                    handler.onError(e);
                }
            }

            @Override
            public void onFailure(Call<ResponseBody> call, Throwable throwable) {
                handler.onError(throwable);
            }
        });
    }

    /**
     * Sends an embedding request and waits for the response.
     *
     * @param request the embedding request
     * @return the deserialized response body
     * @throws RuntimeException on a non-2xx status or an I/O failure
     */
    public EmbeddingResponse embed(EmbeddingRequest request) {
        return execute(ollamaApi.embedd(request));
    }

    /**
     * Requests the list of models and waits for the response.
     *
     * @return the deserialized response body
     * @throws RuntimeException on a non-2xx status or an I/O failure
     */
    public ModelsListResponse listModels() {
        return execute(ollamaApi.listModels());
    }

    /**
     * Requests information about a model and waits for the response.
     *
     * @param showInformationRequest the request identifying what to show
     * @return the deserialized response body
     * @throws RuntimeException on a non-2xx status or an I/O failure
     */
    public OllamaModelCard showInformation(ShowModelInformationRequest showInformationRequest) {
        return execute(ollamaApi.showInformation(showInformationRequest));
    }

    /** Runs a call synchronously; returns its body on success, throws a RuntimeException otherwise. */
    private <T> T execute(Call<T> call) {
        try {
            retrofit2.Response<T> retrofitResponse = call.execute();
            if (retrofitResponse.isSuccessful()) {
                return retrofitResponse.body();
            }
            throw toException(retrofitResponse);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Wraps a response body in a UTF-8 line reader; closing the reader closes the body stream. */
    private static BufferedReader lineReader(ResponseBody body) {
        return new BufferedReader(new InputStreamReader(body.byteStream(), StandardCharsets.UTF_8));
    }

    /** Builds an exception carrying the status code and error body of an unsuccessful response. */
    private RuntimeException toException(retrofit2.Response<?> response) throws IOException {
        int code = response.code();
        ResponseBody errorBody = response.errorBody();
        String body = errorBody == null ? "" : errorBody.string();

        String errorMessage = String.format("status code: %s; body: %s", code, body);
        return new RuntimeException(errorMessage);
    }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy