
dev.langchain4j.rag.query.router.LanguageModelQueryRouter Maven / Gradle / Ivy
package dev.langchain4j.rag.query.router;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.rag.query.router.LanguageModelQueryRouter.FallbackStrategy.DO_NOT_ROUTE;
import static java.util.Arrays.stream;
import static java.util.Collections.emptyList;
import static java.util.stream.Collectors.toList;
/**
 * A {@link QueryRouter} that utilizes a {@link ChatLanguageModel} to make a routing decision.
 * <p>
 * Each {@link ContentRetriever} provided in the constructor should be accompanied by a description which
 * should help the LLM to decide where to route a {@link Query}.
 * <p>
 * Refer to {@link #DEFAULT_PROMPT_TEMPLATE} and implementation for more details.
 * <p>
 * Configurable parameters (optional):
 * <ul>
 * <li>{@link #promptTemplate}: The prompt template used to ask the LLM for routing decisions.
 * Default value: {@link #DEFAULT_PROMPT_TEMPLATE}</li>
 * <li>{@link #fallbackStrategy}: The strategy applied if the call to the LLM fails or if the LLM does not
 * return a valid response. Please check {@link FallbackStrategy} for more details.
 * Default value: {@link FallbackStrategy#DO_NOT_ROUTE}</li>
 * </ul>
 *
 * @see DefaultQueryRouter
 */
public class LanguageModelQueryRouter implements QueryRouter {

    private static final Logger log = LoggerFactory.getLogger(LanguageModelQueryRouter.class);

    /**
     * Default routing prompt. {@code {{options}}} is filled with the numbered list of retriever
     * descriptions and {@code {{query}}} with the text of the query being routed
     * (see {@link #createPrompt(Query)}).
     */
    public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
            """
            Based on the user query, determine the most suitable data source(s) \
            to retrieve relevant information from the following options:
            {{options}}
            It is very important that your answer consists of either a single number \
            or multiple numbers separated by commas and nothing else!
            User query: {{query}}"""
    );

    /** The model asked to choose the retriever(s) for each query. */
    protected final ChatLanguageModel chatLanguageModel;

    /** Template rendered with the {@code options} and {@code query} variables. */
    protected final PromptTemplate promptTemplate;

    /** Newline-separated {@code "<id>: <description>"} lines, one per retriever, ids starting at 1. */
    protected final String options;

    /** Maps each option id shown to the LLM back to its retriever. */
    protected final Map<Integer, ContentRetriever> idToRetriever;

    /** What to do when the LLM call fails or its answer cannot be parsed. */
    protected final FallbackStrategy fallbackStrategy;

    /**
     * Creates a router that uses {@link #DEFAULT_PROMPT_TEMPLATE} and {@link FallbackStrategy#DO_NOT_ROUTE}.
     *
     * @param chatLanguageModel      the model used to make the routing decision; must not be null
     * @param retrieverToDescription the retrievers to route between, each mapped to a non-blank description;
     *                               must not be empty
     */
    public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel,
                                    Map<ContentRetriever, String> retrieverToDescription) {
        this(chatLanguageModel, retrieverToDescription, DEFAULT_PROMPT_TEMPLATE, DO_NOT_ROUTE);
    }

    /**
     * Creates a router.
     *
     * @param chatLanguageModel      the model used to make the routing decision; must not be null
     * @param retrieverToDescription the retrievers to route between, each mapped to a non-blank description;
     *                               must not be empty. Option ids are assigned in this map's iteration order.
     * @param promptTemplate         the routing prompt; {@link #DEFAULT_PROMPT_TEMPLATE} is used when null
     * @param fallbackStrategy       the fallback strategy; {@link FallbackStrategy#DO_NOT_ROUTE} is used when null
     */
    public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel,
                                    Map<ContentRetriever, String> retrieverToDescription,
                                    PromptTemplate promptTemplate,
                                    FallbackStrategy fallbackStrategy) {
        this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
        ensureNotEmpty(retrieverToDescription, "retrieverToDescription");
        this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);

        Map<Integer, ContentRetriever> idToRetriever = new HashMap<>();
        StringBuilder optionsBuilder = new StringBuilder();
        int id = 1;
        for (Map.Entry<ContentRetriever, String> entry : retrieverToDescription.entrySet()) {
            idToRetriever.put(id, ensureNotNull(entry.getKey(), "ContentRetriever"));
            if (id > 1) {
                optionsBuilder.append("\n");
            }
            optionsBuilder.append(id)
                    .append(": ")
                    .append(ensureNotBlank(entry.getValue(), "ContentRetriever description"));
            id++;
        }
        this.idToRetriever = idToRetriever;
        this.options = optionsBuilder.toString();
        this.fallbackStrategy = getOrDefault(fallbackStrategy, DO_NOT_ROUTE);
    }

    /**
     * Returns a new builder for {@link LanguageModelQueryRouter}.
     *
     * @return a fresh builder
     */
    public static LanguageModelQueryRouterBuilder builder() {
        return new LanguageModelQueryRouterBuilder();
    }

    /**
     * Asks the LLM which retriever(s) should handle {@code query}.
     * <p>
     * Any exception thrown while calling the LLM or parsing its answer is logged and passed to
     * {@link #fallback(Query, Exception)}. This includes an answer that is not a comma-separated
     * list of known option ids.
     *
     * @param query the query to route
     * @return the selected retrievers, or the result of the fallback strategy on failure
     */
    @Override
    public Collection<ContentRetriever> route(Query query) {
        Prompt prompt = createPrompt(query);
        try {
            String response = chatLanguageModel.generate(prompt.text());
            return parse(response);
        } catch (Exception e) {
            log.warn("Failed to route query '{}'", query.text(), e);
            return fallback(query, e);
        }
    }

    /**
     * Applies {@link #fallbackStrategy} after routing failed.
     *
     * @param query the query that could not be routed
     * @param e     the failure that triggered the fallback
     * @return an empty collection for {@link FallbackStrategy#DO_NOT_ROUTE}, or all retrievers for
     *         {@link FallbackStrategy#ROUTE_TO_ALL}
     * @throws RuntimeException wrapping {@code e}, for {@link FallbackStrategy#FAIL}
     */
    protected Collection<ContentRetriever> fallback(Query query, Exception e) {
        return switch (fallbackStrategy) {
            case DO_NOT_ROUTE -> {
                log.debug("Fallback: query '{}' will not be routed", query.text());
                yield emptyList();
            }
            case ROUTE_TO_ALL -> {
                log.debug("Fallback: query '{}' will be routed to all available content retrievers", query.text());
                yield new ArrayList<>(idToRetriever.values());
            }
            // Explicit case (instead of `default`) so the compiler checks the switch is exhaustive.
            case FAIL -> throw new RuntimeException(e);
        };
    }

    /**
     * Renders {@link #promptTemplate} with the {@code query} text and the numbered {@code options}.
     *
     * @param query the query being routed
     * @return the prompt sent to the LLM
     */
    protected Prompt createPrompt(Query query) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("query", query.text());
        variables.put("options", options);
        return promptTemplate.apply(variables);
    }

    /**
     * Parses the LLM answer into the selected retrievers.
     * <p>
     * The answer must be a comma-separated list of option ids, e.g. {@code "1"} or {@code "1, 3"}.
     * Repeated ids are collapsed so each retriever is selected at most once; the order of first
     * appearance is kept.
     *
     * @param choices the raw LLM answer
     * @return the selected retrievers, without duplicates
     * @throws NumberFormatException    if a token is not an integer
     * @throws IllegalArgumentException if a token is an integer that does not match any option id
     */
    protected Collection<ContentRetriever> parse(String choices) {
        return stream(choices.split(","))
                .map(String::trim)
                .map(Integer::parseInt)
                .map(this::retrieverById)
                .distinct()
                .collect(toList());
    }

    /**
     * Looks up the retriever for an option id.
     * <p>
     * Throwing here, rather than returning null, lets {@link #route(Query)} apply the fallback
     * strategy to an invalid answer.
     */
    private ContentRetriever retrieverById(int id) {
        ContentRetriever retriever = idToRetriever.get(id);
        if (retriever == null) {
            throw new IllegalArgumentException("Unknown content retriever id: " + id);
        }
        return retriever;
    }

    /**
     * Strategy applied if the call to the LLM fails or if the LLM does not return a valid response.
     * It could be because it was formatted improperly, or it is unclear where to route.
     */
    public enum FallbackStrategy {

        /**
         * In this case, the {@link Query} will not be routed to any {@link ContentRetriever},
         * thus skipping the RAG flow. No content will be appended to the original {@link UserMessage}.
         */
        DO_NOT_ROUTE,

        /**
         * In this case, the {@link Query} will be routed to all {@link ContentRetriever}s.
         */
        ROUTE_TO_ALL,

        /**
         * In this case, the original exception will be re-thrown wrapped in a {@link RuntimeException},
         * and the RAG flow will fail.
         */
        FAIL
    }

    /**
     * Builder for {@link LanguageModelQueryRouter}. Unset optional values fall back to the
     * constructor's defaults.
     */
    public static class LanguageModelQueryRouterBuilder {

        private ChatLanguageModel chatLanguageModel;
        private Map<ContentRetriever, String> retrieverToDescription;
        private PromptTemplate promptTemplate;
        private FallbackStrategy fallbackStrategy;

        LanguageModelQueryRouterBuilder() {
        }

        /** Sets the model used to make the routing decision (required). */
        public LanguageModelQueryRouterBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
            this.chatLanguageModel = chatLanguageModel;
            return this;
        }

        /** Sets the retrievers to route between and their descriptions (required, non-empty). */
        public LanguageModelQueryRouterBuilder retrieverToDescription(Map<ContentRetriever, String> retrieverToDescription) {
            this.retrieverToDescription = retrieverToDescription;
            return this;
        }

        /** Sets the routing prompt (optional). */
        public LanguageModelQueryRouterBuilder promptTemplate(PromptTemplate promptTemplate) {
            this.promptTemplate = promptTemplate;
            return this;
        }

        /** Sets the fallback strategy (optional). */
        public LanguageModelQueryRouterBuilder fallbackStrategy(FallbackStrategy fallbackStrategy) {
            this.fallbackStrategy = fallbackStrategy;
            return this;
        }

        /**
         * Builds the router; validation is performed by the constructor.
         *
         * @return a new {@link LanguageModelQueryRouter}
         */
        public LanguageModelQueryRouter build() {
            return new LanguageModelQueryRouter(this.chatLanguageModel, this.retrieverToDescription, this.promptTemplate, this.fallbackStrategy);
        }

        @Override
        public String toString() {
            return "LanguageModelQueryRouter.LanguageModelQueryRouterBuilder(chatLanguageModel=" + this.chatLanguageModel + ", retrieverToDescription=" + this.retrieverToDescription + ", promptTemplate=" + this.promptTemplate + ", fallbackStrategy=" + this.fallbackStrategy + ")";
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy