All downloads are free. Search and download functionality uses the official Maven repository.

dev.langchain4j.model.ollama.OllamaMessagesUtils Maven / Gradle / Ivy

There is a newer version: 0.36.0
Show newest version
package dev.langchain4j.model.ollama;

import com.fasterxml.jackson.core.type.TypeReference;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import static dev.langchain4j.data.message.ContentType.IMAGE;
import static dev.langchain4j.data.message.ContentType.TEXT;
import static dev.langchain4j.model.ollama.OllamaJsonUtils.toJson;
import static dev.langchain4j.model.ollama.OllamaJsonUtils.toObject;

class OllamaMessagesUtils {

    private static final Predicate isUserMessage =
            UserMessage.class::isInstance;
    private static final Predicate hasImages =
            userMessage -> userMessage.contents().stream()
                    .anyMatch(content -> IMAGE.equals(content.type()));

    static List toOllamaMessages(List messages) {
        return messages.stream()
                .map(message -> isUserMessage.test(message) && hasImages.test((UserMessage) message) ?
                        messagesWithImageSupport((UserMessage) message)
                        : otherMessages(message)
                ).collect(Collectors.toList());
    }

    static List toOllamaTools(List toolSpecifications) {
        if (toolSpecifications == null) {
            return null;
        }
        return toolSpecifications.stream().map(toolSpecification ->
                        Tool.builder()
                                .function(Function.builder()
                                        .name(toolSpecification.name())
                                        .description(toolSpecification.description())
                                        .parameters(toolSpecification.parameters() == null ? null : Parameters.builder()
                                                .properties(toolSpecification.parameters().properties())
                                                .required(toolSpecification.parameters().required())
                                                .build())
                                        .build())
                                .build())
                .collect(Collectors.toList());
    }

    static List toToolExecutionRequest(List toolCalls) {
        return toolCalls.stream().map(toolCall ->
                        ToolExecutionRequest.builder()
                                .name(toolCall.getFunction().getName())
                                .arguments(toJson(toolCall.getFunction().getArguments()))
                                .build())
                .collect(Collectors.toList());
    }

    private static Message messagesWithImageSupport(UserMessage userMessage) {
        Map> groupedContents = userMessage.contents().stream()
                .collect(Collectors.groupingBy(Content::type));

        if (groupedContents.get(TEXT).size() != 1) {
            throw new RuntimeException("Expecting single text content, but got: " + userMessage.contents());
        }

        String text = ((TextContent) groupedContents.get(TEXT).get(0)).text();

        List imageContents = groupedContents.get(IMAGE).stream()
                .map(content -> (ImageContent) content)
                .collect(Collectors.toList());

        return Message.builder()
                .role(toOllamaRole(userMessage.type()))
                .content(text)
                .images(ImageUtils.base64EncodeImageList(imageContents))
                .build();
    }

    private static Message otherMessages(ChatMessage chatMessage) {
        List toolCalls = null;
        if (ChatMessageType.AI == chatMessage.type()) {
            AiMessage aiMessage = (AiMessage) chatMessage;
            List toolExecutionRequests = aiMessage.toolExecutionRequests();
            toolCalls = Optional.ofNullable(toolExecutionRequests)
                    .map(reqs -> reqs.stream()
                            .map(toolExecutionRequest -> {
                                TypeReference> typeReference = new TypeReference>() {
                                };
                                FunctionCall functionCall = FunctionCall.builder()
                                        .name(toolExecutionRequest.name())
                                        .arguments(toObject(toolExecutionRequest.arguments(), typeReference))
                                        .build();
                                return ToolCall.builder()
                                        .function(functionCall).build();
                            }).collect(Collectors.toList()))
                    .orElse(null);

        }
        return Message.builder()
                .role(toOllamaRole(chatMessage.type()))
                .content(chatMessage.text())
                .toolCalls(toolCalls)
                .build();
    }

    private static Role toOllamaRole(ChatMessageType chatMessageType) {
        switch (chatMessageType) {
            case SYSTEM:
                return Role.SYSTEM;
            case USER:
                return Role.USER;
            case AI:
                return Role.ASSISTANT;
            case TOOL_EXECUTION_RESULT:
                return Role.TOOL;
            default:
                throw new IllegalArgumentException("Unknown ChatMessageType: " + chatMessageType);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy