Please wait. This can take a few minutes ...
Significant resources are needed to serve a project download. Please understand that we have to cover our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
dev.langchain4j.model.ollama.OllamaChatModel Maven / Gradle / Ivy
package dev.langchain4j.model.ollama;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.ollama.spi.OllamaChatModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.model.ollama.OllamaChatModelListenerUtils.createModelListenerRequest;
import static dev.langchain4j.model.ollama.OllamaChatModelListenerUtils.onListenError;
import static dev.langchain4j.model.ollama.OllamaChatModelListenerUtils.onListenRequest;
import static dev.langchain4j.model.ollama.OllamaChatModelListenerUtils.onListenResponse;
import static dev.langchain4j.model.ollama.OllamaMessagesUtils.toOllamaMessages;
import static dev.langchain4j.model.ollama.OllamaMessagesUtils.toOllamaTools;
import static dev.langchain4j.model.ollama.OllamaMessagesUtils.toToolExecutionRequest;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.emptyList;
/**
* Ollama API reference
*
* Ollama API parameters .
*/
public class OllamaChatModel implements ChatLanguageModel {
private final OllamaClient client;
private final String modelName;
private final Options options;
private final String format;
private final Integer maxRetries;
private final List listeners;
public OllamaChatModel(String baseUrl,
String modelName,
Double temperature,
Integer topK,
Double topP,
Double repeatPenalty,
Integer seed,
Integer numPredict,
Integer numCtx,
List stop,
String format,
Duration timeout,
Integer maxRetries,
Map customHeaders,
Boolean logRequests,
Boolean logResponses,
List listeners) {
this.client = OllamaClient.builder()
.baseUrl(baseUrl)
.timeout(getOrDefault(timeout, ofSeconds(60)))
.customHeaders(customHeaders)
.logRequests(getOrDefault(logRequests, false))
.logResponses(logResponses)
.build();
this.modelName = ensureNotBlank(modelName, "modelName");
this.options = Options.builder()
.temperature(temperature)
.topK(topK)
.topP(topP)
.repeatPenalty(repeatPenalty)
.seed(seed)
.numPredict(numPredict)
.numCtx(numCtx)
.stop(stop)
.build();
this.format = format;
this.maxRetries = getOrDefault(maxRetries, 3);
this.listeners = new ArrayList<>(getOrDefault(listeners, emptyList()));
}
public static OllamaChatModelBuilder builder() {
for (OllamaChatModelBuilderFactory factory : loadFactories(OllamaChatModelBuilderFactory.class)) {
return factory.get();
}
return new OllamaChatModelBuilder();
}
@Override
public Response generate(List messages) {
ensureNotEmpty(messages, "messages");
return doGenerate(messages, null);
}
@Override
public Response generate(List messages, List toolSpecifications) {
ensureNotEmpty(messages, "messages");
return doGenerate(messages, toolSpecifications);
}
private Response doGenerate(List messages, List toolSpecifications) {
ChatRequest request = ChatRequest.builder()
.model(modelName)
.messages(toOllamaMessages(messages))
.options(options)
.format(format)
.stream(false)
.tools(toOllamaTools(toolSpecifications))
.build();
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
Map attributes = new ConcurrentHashMap<>();
onListenRequest(listeners, modelListenerRequest, attributes);
try {
ChatResponse chatResponse = withRetry(() -> client.chat(request), maxRetries);
Response response = Response.from(
chatResponse.getMessage().getToolCalls() != null ?
AiMessage.from(toToolExecutionRequest(chatResponse.getMessage().getToolCalls())) :
AiMessage.from(chatResponse.getMessage().getContent()),
new TokenUsage(chatResponse.getPromptEvalCount(), chatResponse.getEvalCount())
);
onListenResponse(listeners, response, modelListenerRequest, attributes);
return response;
} catch (Exception e) {
onListenError(listeners, e, modelListenerRequest, null, attributes);
throw e;
}
}
public static class OllamaChatModelBuilder {
private String baseUrl;
private String modelName;
private Double temperature;
private Integer topK;
private Double topP;
private Double repeatPenalty;
private Integer seed;
private Integer numPredict;
private Integer numCtx;
private List stop;
private String format;
private Duration timeout;
private Integer maxRetries;
private Map customHeaders;
private Boolean logRequests;
private Boolean logResponses;
private List listeners;
public OllamaChatModelBuilder() {
// This is public so it can be extended
// By default with Lombok it becomes package private
}
public OllamaChatModelBuilder baseUrl(String baseUrl) {
this.baseUrl = baseUrl;
return this;
}
public OllamaChatModelBuilder modelName(String modelName) {
this.modelName = modelName;
return this;
}
public OllamaChatModelBuilder temperature(Double temperature) {
this.temperature = temperature;
return this;
}
public OllamaChatModelBuilder topK(Integer topK) {
this.topK = topK;
return this;
}
public OllamaChatModelBuilder topP(Double topP) {
this.topP = topP;
return this;
}
public OllamaChatModelBuilder repeatPenalty(Double repeatPenalty) {
this.repeatPenalty = repeatPenalty;
return this;
}
public OllamaChatModelBuilder seed(Integer seed) {
this.seed = seed;
return this;
}
public OllamaChatModelBuilder numPredict(Integer numPredict) {
this.numPredict = numPredict;
return this;
}
public OllamaChatModelBuilder numCtx(Integer numCtx) {
this.numCtx = numCtx;
return this;
}
public OllamaChatModelBuilder stop(List stop) {
this.stop = stop;
return this;
}
public OllamaChatModelBuilder format(String format) {
this.format = format;
return this;
}
public OllamaChatModelBuilder timeout(Duration timeout) {
this.timeout = timeout;
return this;
}
public OllamaChatModelBuilder maxRetries(Integer maxRetries) {
this.maxRetries = maxRetries;
return this;
}
public OllamaChatModelBuilder customHeaders(Map customHeaders) {
this.customHeaders = customHeaders;
return this;
}
public OllamaChatModelBuilder logRequests(Boolean logRequests) {
this.logRequests = logRequests;
return this;
}
public OllamaChatModelBuilder logResponses(Boolean logResponses) {
this.logResponses = logResponses;
return this;
}
public OllamaChatModelBuilder listeners(List listeners) {
this.listeners = listeners;
return this;
}
public OllamaChatModel build() {
return new OllamaChatModel(
baseUrl,
modelName,
temperature,
topK,
topP,
repeatPenalty,
seed,
numPredict,
numCtx,
stop,
format,
timeout,
maxRetries,
customHeaders,
logRequests,
logResponses,
listeners
);
}
}
}