![JAR search and dependency download from the Maven repository](/logo.png)
dev.langchain4j.model.ollama.OllamaEmbeddingModel Maven / Gradle / Ivy
package dev.langchain4j.model.ollama;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
import dev.langchain4j.model.ollama.spi.OllamaEmbeddingModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.time.Duration.ofSeconds;
/**
* Ollama API reference
*/
public class OllamaEmbeddingModel extends DimensionAwareEmbeddingModel {
private final OllamaClient client;
private final String modelName;
private final Integer maxRetries;
public OllamaEmbeddingModel(String baseUrl,
String modelName,
Duration timeout,
Integer maxRetries,
Boolean logRequests,
Boolean logResponses,
Map customHeaders) {
this.client = OllamaClient.builder()
.baseUrl(baseUrl)
.timeout(getOrDefault(timeout, ofSeconds(60)))
.logRequests(logRequests)
.logResponses(logResponses)
.customHeaders(customHeaders)
.build();
this.modelName = ensureNotBlank(modelName, "modelName");
this.maxRetries = getOrDefault(maxRetries, 3);
}
public static OllamaEmbeddingModelBuilder builder() {
for (OllamaEmbeddingModelBuilderFactory factory : loadFactories(OllamaEmbeddingModelBuilderFactory.class)) {
return factory.get();
}
return new OllamaEmbeddingModelBuilder();
}
@Override
public Response> embedAll(List textSegments) {
List input = textSegments.stream()
.map(TextSegment::text)
.collect(Collectors.toList());
EmbeddingRequest request = EmbeddingRequest.builder()
.model(modelName)
.input(input)
.build();
EmbeddingResponse response = withRetry(() -> client.embed(request), maxRetries);
List embeddings = response.getEmbeddings()
.stream()
.map(Embedding::from)
.collect(Collectors.toList());
return Response.from(embeddings);
}
public static class OllamaEmbeddingModelBuilder {
private String baseUrl;
private String modelName;
private Duration timeout;
private Integer maxRetries;
private Boolean logRequests;
private Boolean logResponses;
private Map customHeaders;
public OllamaEmbeddingModelBuilder() {
// This is public so it can be extended
// By default with Lombok it becomes package private
}
public OllamaEmbeddingModelBuilder baseUrl(String baseUrl) {
this.baseUrl = baseUrl;
return this;
}
public OllamaEmbeddingModelBuilder modelName(String modelName) {
this.modelName = modelName;
return this;
}
public OllamaEmbeddingModelBuilder timeout(Duration timeout) {
this.timeout = timeout;
return this;
}
public OllamaEmbeddingModelBuilder maxRetries(Integer maxRetries) {
this.maxRetries = maxRetries;
return this;
}
public OllamaEmbeddingModelBuilder logRequests(Boolean logRequests) {
this.logRequests = logRequests;
return this;
}
public OllamaEmbeddingModelBuilder logResponses(Boolean logResponses) {
this.logResponses = logResponses;
return this;
}
public OllamaEmbeddingModelBuilder customHeaders(Map customHeaders) {
this.customHeaders = customHeaders;
return this;
}
public OllamaEmbeddingModel build() {
return new OllamaEmbeddingModel(baseUrl, modelName, timeout, maxRetries, logRequests, logResponses, customHeaders);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy