All downloads are free. Search and download functionality uses the official Maven repository.

com.composum.ai.backend.slingbase.impl.RAGServiceImpl Maven / Gradle / Ivy

Go to download

Common functionality for Composum AI that is specific to Sling but usable in both Composum and AEM and similar platforms.

There is a newer version: 1.2.1
Show newest version
package com.composum.ai.backend.slingbase.impl;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.jcr.RepositoryException;
import javax.jcr.Session;
import javax.jcr.query.Query;
import javax.jcr.query.QueryManager;
import javax.jcr.query.QueryResult;
import javax.jcr.query.Row;
import javax.jcr.query.RowIterator;

import org.apache.sling.api.SlingHttpServletRequest;
import org.apache.sling.api.SlingHttpServletResponse;
import org.apache.sling.api.resource.Resource;
import org.apache.sling.api.resource.ResourceResolver;
import org.jetbrains.annotations.NotNull;
import org.osgi.service.component.annotations.Component;
import org.osgi.service.component.annotations.Reference;
import org.osgi.service.component.annotations.ReferenceCardinality;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.composum.ai.backend.base.service.GPTException;
import com.composum.ai.backend.base.service.chat.GPTChatCompletionService;
import com.composum.ai.backend.base.service.chat.GPTChatRequest;
import com.composum.ai.backend.base.service.chat.GPTConfiguration;
import com.composum.ai.backend.base.service.chat.GPTEmbeddingService;
import com.composum.ai.backend.base.service.chat.GPTMessageRole;
import com.composum.ai.backend.slingbase.AIConfigurationService;
import com.composum.ai.backend.slingbase.ApproximateMarkdownService;
import com.composum.ai.backend.slingbase.PageCachedValueService;
import com.composum.ai.backend.slingbase.RAGService;

/**
 * Basic services for retrieval augmented generation (RAG): JCR fulltext search for candidate pages,
 * ordering pages by embedding similarity to a query, answering a question from the markdown approximation
 * of given pages, and letting the AI suggest search keywords for a question.
 */
@Component(service = RAGService.class)
public class RAGServiceImpl implements RAGService {

    private static final Logger LOG = LoggerFactory.getLogger(RAGServiceImpl.class);

    @Reference
    protected ApproximateMarkdownService markdownService;

    @Reference
    protected GPTEmbeddingService embeddingService;

    @Reference
    protected AIConfigurationService aiConfigurationService;

    @Reference
    protected GPTChatCompletionService chatCompletionService;

    /**
     * Optional storage for embeddings of page markdown; if it is not available,
     * {@link #getEmbeddingsCache(Map)} returns null.
     */
    @Reference(cardinality = ReferenceCardinality.OPTIONAL)
    protected PageCachedValueService pageCachedValueService;

    /** Generates the ids that correlate the log messages of a single {@link #ragAnswer} call. */
    protected final AtomicLong requestCounter = new AtomicLong(System.currentTimeMillis() / 2);

    /**
     * Fulltext search for jcr:content nodes below {@code root} that are related to {@code querytext}.
     * First searches for the query text as an exact phrase, then for nodes containing any of its words;
     * exact phrase matches come first in the result.
     *
     * @param root      the resource below which to search; if null, nothing is found
     * @param querytext the text to search for; if null or blank, nothing is found
     * @param limit     the maximum number of results; if not positive, nothing is found
     * @return the distinct paths of the found jcr:content nodes, at most {@code limit} entries
     */
    @Override
    @Nonnull
    public List<String> searchRelated(@Nullable Resource root, @Nullable String querytext, int limit) {
        if (root == null || querytext == null || querytext.trim().isEmpty() || limit <= 0) {
            return Collections.emptyList();
        }
        // a little larger than limit since exact and normalized matches can overlap and are deduplicated below
        int restOfLimit = limit * 5 / 4 + 3;

        String exactQuery = "\"" + querytext.replace("\"", "") + "\"";
        String normalizedQuery = normalize(querytext);

        List<String> exactResult = Collections.emptyList();
        try {
            exactResult = containsQuery(root, exactQuery, restOfLimit);
        } catch (RepositoryException e) {
            LOG.error("Error searching for exact query {}", exactQuery, e);
        }
        restOfLimit -= exactResult.size();
        LOG.trace("Exact query result: {}", exactResult);

        List<String> normalizedResult = Collections.emptyList();
        // skip the second query if it cannot contribute results or would be an empty (invalid) fulltext expression
        if (restOfLimit > 0 && !normalizedQuery.isEmpty()) {
            try {
                normalizedResult = containsQuery(root, normalizedQuery, restOfLimit);
            } catch (RepositoryException e) {
                LOG.error("Error searching for normalized query {}", normalizedQuery, e);
            }
        }
        LOG.trace("Normalized query result: {}", normalizedResult);

        List<String> result = new ArrayList<>(exactResult);
        result.addAll(normalizedResult);
        return result.stream().distinct().limit(limit).collect(Collectors.toList());
    }

    /**
     * Executes a JCR-SQL2 fulltext query for nodes named jcr:content below {@code root} that match
     * {@code querytext}, ordered by descending score.
     *
     * @param root        the resource below which to search; its resolver must be adaptable to a JCR session
     * @param querytext   the fulltext expression for CONTAINS; it is bound as a query variable
     * @param restOfLimit the maximum number of result rows to consider
     * @return the distinct paths of the found nodes in score order; empty if {@code restOfLimit} is not positive
     * @throws RepositoryException if the query is invalid or cannot be executed
     */
    protected @NotNull List<String> containsQuery(@NotNull Resource root, @NotNull String querytext, int restOfLimit) throws RepositoryException {
        List<String> result = new ArrayList<>();
        if (restOfLimit <= 0) {
            return result;
        }
        ResourceResolver resolver = root.getResourceResolver();
        final Session session = Objects.requireNonNull(resolver.adaptTo(Session.class));
        final QueryManager queryManager = session.getWorkspace().getQueryManager();
        // The path in ISDESCENDANTNODE cannot be a bind variable, so it is inlined as a string literal;
        // single quotes within it have to be doubled to keep the statement valid.
        String escapedRootPath = root.getPath().replace("'", "''");
        String statement = "SELECT [jcr:path], [jcr:score] FROM [nt:base] AS content WHERE " +
                "ISDESCENDANTNODE(content, '" + escapedRootPath + "') " +
                "AND NAME(content) = 'jcr:content' " +
                "AND CONTAINS(content.*, $queryText) " +
                "ORDER BY [jcr:score] DESC";
        // equivalent Composum Nodes query template for testing
        // SELECT [jcr:path], [jcr:score] FROM [nt:base] AS content WHERE ISDESCENDANTNODE(content, '${root_path.path}')  AND NAME(content) = 'jcr:content' AND CONTAINS(content.*, '${text.3}') ORDER BY [jcr:score] DESC
        Query query = queryManager.createQuery(statement, Query.JCR_SQL2);
        query.bindValue("queryText", session.getValueFactory().createValue(querytext));
        query.setLimit(restOfLimit);
        LOG.trace("Executing query:\n{}\nwith\n{}", query.getStatement(), querytext);
        QueryResult queryResult = query.execute();
        for (RowIterator rowIterator = queryResult.getRows(); rowIterator.hasNext(); ) {
            if (restOfLimit-- <= 0) {
                return result;
            }
            Row row = rowIterator.nextRow();
            String path = row.getValue("jcr:path").getString();
            if (LOG.isTraceEnabled()) { // avoid reading the score when it is not logged
                LOG.trace("Found path {} with score {}", path, row.getValue("jcr:score").getDouble());
            }
            if (!result.contains(path)) {
                result.add(path);
            }
        }
        return result;
    }

    /**
     * Turn it into a query for the words mentioned in there - that is, remove all meta characters for CONTAINS queries:
     * AND, OR, leading - of words, quotes, backslashes. We use an OR query to find pages with as many words as possible.
     *
     * @param querytext the text to normalize
     * @return the remaining words joined with " OR "; empty if no words remain
     */
    @Nonnull
    protected String normalize(@Nonnull String querytext) {
        return Arrays.stream(querytext.trim().split("\\s+"))
                .map(s -> s.replaceAll("[\"\\\\']", ""))
                .map(s -> s.replaceAll("^-+", ""))
                // tokens consisting only of removed characters would produce invalid expressions like "a OR  OR b"
                .filter(s -> !s.isEmpty())
                .filter(s -> !s.equals("OR"))
                .filter(s -> !s.equals("AND"))
                .collect(Collectors.joining(" OR "));
    }

    /**
     * Finds the resources whose markdown approximation has embeddings that are the most similar to the querytext embedding.
     * Useable e.g. as filter after {@link #searchRelated(Resource, String, int)}.
     * Resources whose markdown approximations are identical are collapsed into one entry.
     *
     * @param querytext    the query to compare the resources with
     * @param resources    the resources to order
     * @param request      the request to use when determining the markdown approximation - not modified
     * @param response     the response to use when determining the markdown approximation - not modified
     * @param rootResource the root resource to find the GPT configuration from
     * @return the resources in the order returned by the embedding service (presumably most similar first)
     */
    @Override
    @Nonnull
    public List<Resource> orderByEmbedding(@Nullable String querytext, @Nonnull List<Resource> resources,
                                           @NotNull SlingHttpServletRequest request, @NotNull SlingHttpServletResponse response,
                                           @NotNull Resource rootResource) {
        Map<String, Resource> textToResource = new TreeMap<>();
        for (Resource resource : resources) {
            String markdown = markdownService.approximateMarkdown(resource, request, response);
            textToResource.put(markdown, resource);
        }
        GPTConfiguration config = aiConfigurationService.getGPTConfiguration(rootResource.getResourceResolver(), rootResource.getPath());
        List<String> relatedTexts = embeddingService.findMostRelated(querytext, new ArrayList<>(textToResource.keySet()),
                Integer.MAX_VALUE, config, getEmbeddingsCache(textToResource));
        // Map the texts straight back to their resources; a path-keyed Collectors.toMap would throw
        // IllegalStateException if the same resource was passed more than once.
        return relatedTexts.stream()
                .map(textToResource::get)
                .collect(Collectors.toList());
    }

    /**
     * Answer a question with RAG from the given resources, e.g. found with {@link #searchRelated(Resource, String, int)}.
     * If the AI reports that the context length is exceeded, the request is retried with about 2/3 of the texts,
     * dropping the least relevant ones first.
     *
     * @param querytext     the query text
     * @param resources     the list of resources to answer from
     * @param request       the request to use when determining the markdown approximation - not modified
     * @param response      the response to use when determining the markdown approximation - not modified
     * @param rootResource  the root resource to find GPT configuration from
     * @param limitRagTexts the maximum number of RAG texts to consider
     * @return the answer text, or a parenthesized message if no answer could be determined
     */
    @Override
    public String ragAnswer(@Nullable String querytext, @Nonnull List<Resource> resources,
                            @Nonnull SlingHttpServletRequest request, @Nonnull SlingHttpServletResponse response,
                            @NotNull Resource rootResource, int limitRagTexts) {
        long id = requestCounter.incrementAndGet();
        Map<String, String> textToPath = new TreeMap<>();
        Map<String, Resource> textToResource = new TreeMap<>();
        for (Resource resource : resources) {
            String markdown = markdownService.approximateMarkdown(resource, request, response);
            textToPath.put(markdown, resource.getPath());
            textToResource.put(markdown, resource);
        }
        GPTConfiguration config = aiConfigurationService.getGPTConfiguration(rootResource.getResourceResolver(), rootResource.getPath());
        // copied since it is reordered below and the service's list might not be modifiable
        List<String> bestMatches = new ArrayList<>(embeddingService.findMostRelated(querytext,
                new ArrayList<>(textToPath.keySet()), limitRagTexts, config, getEmbeddingsCache(textToResource)));
        LOG.debug("ragAnswer: query for {} is {}", id, querytext);
        Collections.reverse(bestMatches); // make the most relevant last, near the actual question
        int limit = bestMatches.size();
        while (limit >= 1) {
            try {
                // a fresh request per attempt - otherwise a retry would still contain all messages of the failed attempt
                GPTChatRequest chatRequest = new GPTChatRequest(config);
                // the most relevant texts are at the end of the list, so a reduced limit drops texts from the start
                for (String text : bestMatches.subList(bestMatches.size() - limit, bestMatches.size())) {
                    String textPath = textToPath.get(text);
                    chatRequest.addMessage(GPTMessageRole.USER, "For answering my question later, retrieve the text of the possibly relevant page: "
                            + textPath.replace("/jcr:content", ".html"));
                    chatRequest.addMessage(GPTMessageRole.ASSISTANT, text);
                    LOG.debug("ragAnswer: Using for {} path {}", id, textPath);
                }
                chatRequest.addMessage(GPTMessageRole.USER, "Considering this information, please answer the following as Markdown text without enumeration, including links to the relevant retrieved pages above:\n\n" + querytext);
                LOG.debug("ragAnswer: request {} : {}", id, chatRequest);
                String answer = chatCompletionService.getSingleChatCompletion(chatRequest);
                LOG.debug("ragAnswer: response {} : {}", id, answer);
                return answer;
            } catch (GPTException.GPTContextLengthExceededException e) {
                // retry with lower number of texts
                limit = limit * 2 / 3;
                LOG.info("ragAnswer: retrying with lower number of texts because of content length exceeded exception: {}", limit);
            }
        }
        if (limit == 0 && !bestMatches.isEmpty()) {
            return "(No answer: context length exceeded.)";
        }
        return "(No answer found).";
    }

    /**
     * Creates an embeddings cache that stores the embedding of a text at the resource the text was created from.
     *
     * @param textToResource maps the markdown texts to the resources they were created from
     * @return the cache, or null if no {@link PageCachedValueService} is available
     */
    @Nullable
    protected GPTEmbeddingService.EmbeddingsCache getEmbeddingsCache(final Map<String, Resource> textToResource) {
        if (pageCachedValueService == null) {
            return null;
        }
        // the key includes the model since embeddings of different models are not interchangeable
        final String key = "pagemarkdown-embedding-" + chatCompletionService.getEmbeddingsModel();
        return new GPTEmbeddingService.EmbeddingsCache() {

            @Override
            public String getCachedEmbedding(String text) {
                Resource resource = textToResource.get(text);
                return resource != null ? pageCachedValueService.getPageCachedValue(key, resource) : null;
            }

            @Override
            public void putCachedEmbedding(String text, String embedding) {
                Resource resource = textToResource.get(text);
                if (resource != null) {
                    pageCachedValueService.putPageCachedValue(key, resource, embedding);
                }
            }
        };
    }


    /**
     * Processes a query to have the AI suggest a couple of search keywords for use with the other methods that might find the most relevant results.
     *
     * @param querytext    the query text for which we find keywords
     * @param rootResource the root resource to find GPT configuration from
     * @return a list of non-empty keywords; empty if the AI returned nothing
     * @throws RepositoryException declared by the interface; not thrown by this implementation
     */
    @Override
    @Nonnull
    public List<String> collectSearchKeywords(@Nullable String querytext, @Nonnull Resource rootResource) throws RepositoryException {
        GPTConfiguration config = aiConfigurationService.getGPTConfiguration(rootResource.getResourceResolver(), rootResource.getPath());
        GPTChatRequest request = new GPTChatRequest(config)
                .addMessage(GPTMessageRole.SYSTEM, "Print up to 7 keywords to search for in documents with a BM25 algorithm which are likely to appear in documents answering the users question, but not in documents irrelevant to that.\n" +
                        "The keywords should be selected to maximize the relevance of the retrieved high scoring documents, specifically aiming to answer the user's question.\n" +
                        "The keywords can be words from the users question, synonyms or other words you would expect to be present especially in a document answering the question.\n" +
                        "Print the keywords (single words) as comma separated list.")
                .addMessage(GPTMessageRole.USER, querytext);
        String result = chatCompletionService.getSingleChatCompletion(request);
        LOG.debug("collectSearchKeywords: for '{}' got '{}'", querytext, result);
        if (result == null || result.trim().isEmpty()) {
            return Collections.emptyList();
        }
        // drop empty entries caused by stray commas in the AI answer
        return Arrays.stream(result.trim().split("\\s*,\\s*"))
                .filter(keyword -> !keyword.isEmpty())
                .collect(Collectors.toList());
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy