All Downloads are FREE. Search and download functionalities are using the official Maven repository.

dev.langchain4j.model.ollama.OllamaLanguageModel Maven / Gradle / Ivy

There is a newer version: 0.36.0
Show newest version
package dev.langchain4j.model.ollama;

import dev.langchain4j.model.language.LanguageModel;
import dev.langchain4j.model.ollama.spi.OllamaLanguageModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;

import java.time.Duration;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.time.Duration.ofSeconds;

/**
 * Ollama API reference
 * 
* Ollama API parameters. */ public class OllamaLanguageModel implements LanguageModel { private final OllamaClient client; private final String modelName; private final Options options; private final String format; private final Integer maxRetries; public OllamaLanguageModel(String baseUrl, String modelName, Double temperature, Integer topK, Double topP, Double repeatPenalty, Integer seed, Integer numPredict, Integer numCtx, List stop, String format, Duration timeout, Integer maxRetries, Boolean logRequests, Boolean logResponses, Map customHeaders ) { this.client = OllamaClient.builder() .baseUrl(baseUrl) .timeout(getOrDefault(timeout, ofSeconds(60))) .logRequests(logRequests) .logResponses(logResponses) .customHeaders(customHeaders) .build(); this.modelName = ensureNotBlank(modelName, "modelName"); this.options = Options.builder() .temperature(temperature) .topK(topK) .topP(topP) .repeatPenalty(repeatPenalty) .seed(seed) .numPredict(numPredict) .numCtx(numCtx) .stop(stop) .build(); this.format = format; this.maxRetries = getOrDefault(maxRetries, 3); } public static OllamaLanguageModelBuilder builder() { for (OllamaLanguageModelBuilderFactory factory : loadFactories(OllamaLanguageModelBuilderFactory.class)) { return factory.get(); } return new OllamaLanguageModelBuilder(); } @Override public Response generate(String prompt) { CompletionRequest request = CompletionRequest.builder() .model(modelName) .prompt(prompt) .options(options) .format(format) .stream(false) .build(); CompletionResponse response = withRetry(() -> client.completion(request), maxRetries); return Response.from( response.getResponse(), new TokenUsage(response.getPromptEvalCount(), response.getEvalCount()) ); } public static class OllamaLanguageModelBuilder { private String baseUrl; private String modelName; private Double temperature; private Integer topK; private Double topP; private Double repeatPenalty; private Integer seed; private Integer numPredict; private Integer numCtx; private List stop; private String format; private Duration timeout; private Integer maxRetries; private Boolean logRequests; private Boolean logResponses; private Map customHeaders; public OllamaLanguageModelBuilder() { // This is public so it can be extended // By default with Lombok it becomes package private } public OllamaLanguageModelBuilder baseUrl(String baseUrl) { this.baseUrl = baseUrl; return this; } public OllamaLanguageModelBuilder modelName(String modelName) { this.modelName = modelName; return this; } public OllamaLanguageModelBuilder temperature(Double temperature) { this.temperature = temperature; return this; } public OllamaLanguageModelBuilder topK(Integer topK) { this.topK = topK; return this; } public OllamaLanguageModelBuilder topP(Double topP) { this.topP = topP; return this; } public OllamaLanguageModelBuilder repeatPenalty(Double repeatPenalty) { this.repeatPenalty = repeatPenalty; return this; } public OllamaLanguageModelBuilder seed(Integer seed) { this.seed = seed; return this; } public OllamaLanguageModelBuilder numPredict(Integer numPredict) { this.numPredict = numPredict; return this; } public OllamaLanguageModelBuilder numCtx(Integer numCtx) { this.numCtx = numCtx; return this; } public OllamaLanguageModelBuilder stop(List stop) { this.stop = stop; return this; } public OllamaLanguageModelBuilder format(String format) { this.format = format; return this; } public OllamaLanguageModelBuilder timeout(Duration timeout) { this.timeout = timeout; return this; } public OllamaLanguageModelBuilder maxRetries(Integer maxRetries) { this.maxRetries = maxRetries; return this; } public OllamaLanguageModelBuilder logRequests(Boolean logRequests) { this.logRequests = logRequests; return this; } public OllamaLanguageModelBuilder logResponses(Boolean logResponses) { this.logResponses = logResponses; return this; } public OllamaLanguageModelBuilder customHeaders(Map customHeaders) { this.customHeaders = customHeaders; return this; } public OllamaLanguageModel build() { return new OllamaLanguageModel( baseUrl, modelName, temperature, topK, topP, repeatPenalty, seed, numPredict, numCtx, stop, format, timeout, maxRetries, logRequests, logResponses, customHeaders ); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy