All downloads are free. Search and download functionality uses the official Maven repository.

io.quarkiverse.langchain4j.watsonx.WatsonxEmbeddingModel Maven / Gradle / Ivy

There is a newer version: 0.21.0
Show newest version
package io.quarkiverse.langchain4j.watsonx;

import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.TokenCountEstimator;
import dev.langchain4j.model.output.Response;
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest;
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingResponse;
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingResponse.Result;
import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest;

public class WatsonxEmbeddingModel extends WatsonxModel implements EmbeddingModel, TokenCountEstimator {

    public WatsonxEmbeddingModel(Builder config) {
        super(config);
    }

    @Override
    public Response> embedAll(List textSegments) {

        if (Objects.isNull(textSegments) || textSegments.isEmpty())
            return Response.from(List.of());

        var inputs = textSegments.stream()
                .map(TextSegment::text)
                .collect(Collectors.toList());

        EmbeddingRequest request = new EmbeddingRequest(modelId, projectId, inputs);
        EmbeddingResponse result = retryOn(new Callable() {
            @Override
            public EmbeddingResponse call() throws Exception {
                return client.embeddings(request, version);
            }
        });

        return Response.from(
                result.results()
                        .stream()
                        .map(Result::embedding)
                        .map(Embedding::from)
                        .collect(Collectors.toList()));
    }

    @Override
    public int estimateTokenCount(String text) {

        if (Objects.isNull(text) || text.isEmpty())
            return 0;

        var request = new TokenizationRequest(modelId, text, projectId);
        return retryOn(new Callable() {
            @Override
            public Integer call() throws Exception {
                return client.tokenization(request, version).result().tokenCount();
            }
        });
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy