All downloads are free. Search and download functionality uses the official Maven repository.

dev.langchain4j.rag.query.router.LanguageModelQueryRouter Maven / Gradle / Ivy

package dev.langchain4j.rag.query.router;

import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.rag.query.router.LanguageModelQueryRouter.FallbackStrategy.DO_NOT_ROUTE;
import static java.util.Arrays.stream;
import static java.util.Collections.emptyList;
import static java.util.stream.Collectors.toList;

/**
 * A {@link QueryRouter} that utilizes a {@link ChatLanguageModel} to make a routing decision.
 * 
* Each {@link ContentRetriever} provided in the constructor should be accompanied by a description which * should help the LLM to decide where to route a {@link Query}. *
* Refer to {@link #DEFAULT_PROMPT_TEMPLATE} and implementation for more details. *
*
* Configurable parameters (optional): *
* - {@link #promptTemplate}: The prompt template used to ask the LLM for routing decisions. *
* - {@link #fallbackStrategy}: The strategy applied if the call to the LLM fails of if LLM does not return a valid response. * Please check {@link FallbackStrategy} for more details. Default value: {@link FallbackStrategy#DO_NOT_ROUTE} * * @see DefaultQueryRouter */ public class LanguageModelQueryRouter implements QueryRouter { private static final Logger log = LoggerFactory.getLogger(LanguageModelQueryRouter.class); public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from( """ Based on the user query, determine the most suitable data source(s) \ to retrieve relevant information from the following options: {{options}} It is very important that your answer consists of either a single number \ or multiple numbers separated by commas and nothing else! User query: {{query}}""" ); protected final ChatLanguageModel chatLanguageModel; protected final PromptTemplate promptTemplate; protected final String options; protected final Map idToRetriever; protected final FallbackStrategy fallbackStrategy; public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel, Map retrieverToDescription) { this(chatLanguageModel, retrieverToDescription, DEFAULT_PROMPT_TEMPLATE, DO_NOT_ROUTE); } public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel, Map retrieverToDescription, PromptTemplate promptTemplate, FallbackStrategy fallbackStrategy) { this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel"); ensureNotEmpty(retrieverToDescription, "retrieverToDescription"); this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE); Map idToRetriever = new HashMap<>(); StringBuilder optionsBuilder = new StringBuilder(); int id = 1; for (Map.Entry entry : retrieverToDescription.entrySet()) { idToRetriever.put(id, ensureNotNull(entry.getKey(), "ContentRetriever")); if (id > 1) { optionsBuilder.append("\n"); } optionsBuilder.append(id); optionsBuilder.append(": "); optionsBuilder.append(ensureNotBlank(entry.getValue(), 
"ContentRetriever description")); id++; } this.idToRetriever = idToRetriever; this.options = optionsBuilder.toString(); this.fallbackStrategy = getOrDefault(fallbackStrategy, DO_NOT_ROUTE); } public static LanguageModelQueryRouterBuilder builder() { return new LanguageModelQueryRouterBuilder(); } @Override public Collection route(Query query) { Prompt prompt = createPrompt(query); try { String response = chatLanguageModel.generate(prompt.text()); return parse(response); } catch (Exception e) { log.warn("Failed to route query '{}'", query.text(), e); return fallback(query, e); } } protected Collection fallback(Query query, Exception e) { return switch (fallbackStrategy) { case DO_NOT_ROUTE -> { log.debug("Fallback: query '{}' will not be routed", query.text()); yield emptyList(); } case ROUTE_TO_ALL -> { log.debug("Fallback: query '{}' will be routed to all available content retrievers", query.text()); yield new ArrayList<>(idToRetriever.values()); } default -> throw new RuntimeException(e); }; } protected Prompt createPrompt(Query query) { Map variables = new HashMap<>(); variables.put("query", query.text()); variables.put("options", options); return promptTemplate.apply(variables); } protected Collection parse(String choices) { return stream(choices.split(",")) .map(String::trim) .map(Integer::parseInt) .map(idToRetriever::get) .collect(toList()); } /** * Strategy applied if the call to the LLM fails of if LLM does not return a valid response. * It could be because it was formatted improperly, or it is unclear where to route. */ public enum FallbackStrategy { /** * In this case, the {@link Query} will not be routed to any {@link ContentRetriever}, * thus skipping the RAG flow. No content will be appended to the original {@link UserMessage}. */ DO_NOT_ROUTE, /** * In this case, the {@link Query} will be routed to all {@link ContentRetriever}s. */ ROUTE_TO_ALL, /** * In this case, an original exception will be re-thrown, and the RAG flow will fail. 
*/ FAIL } public static class LanguageModelQueryRouterBuilder { private ChatLanguageModel chatLanguageModel; private Map retrieverToDescription; private PromptTemplate promptTemplate; private FallbackStrategy fallbackStrategy; LanguageModelQueryRouterBuilder() { } public LanguageModelQueryRouterBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) { this.chatLanguageModel = chatLanguageModel; return this; } public LanguageModelQueryRouterBuilder retrieverToDescription(Map retrieverToDescription) { this.retrieverToDescription = retrieverToDescription; return this; } public LanguageModelQueryRouterBuilder promptTemplate(PromptTemplate promptTemplate) { this.promptTemplate = promptTemplate; return this; } public LanguageModelQueryRouterBuilder fallbackStrategy(FallbackStrategy fallbackStrategy) { this.fallbackStrategy = fallbackStrategy; return this; } public LanguageModelQueryRouter build() { return new LanguageModelQueryRouter(this.chatLanguageModel, this.retrieverToDescription, this.promptTemplate, this.fallbackStrategy); } public String toString() { return "LanguageModelQueryRouter.LanguageModelQueryRouterBuilder(chatLanguageModel=" + this.chatLanguageModel + ", retrieverToDescription=" + this.retrieverToDescription + ", promptTemplate=" + this.promptTemplate + ", fallbackStrategy=" + this.fallbackStrategy + ")"; } } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy