All downloads are free. The search and download functionality uses the official Maven repository.

com.composum.ai.backend.base.service.chat.impl.GPTEmbeddingServiceImpl Maven / Gradle / Ivy

package com.composum.ai.backend.base.service.chat.impl;


import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.osgi.service.component.annotations.Component;
import org.osgi.service.component.annotations.Reference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.composum.ai.backend.base.service.GPTException;
import com.composum.ai.backend.base.service.chat.GPTChatCompletionService;
import com.composum.ai.backend.base.service.chat.GPTConfiguration;
import com.composum.ai.backend.base.service.chat.GPTEmbeddingService;

/**
 * {@link GPTEmbeddingService} implementation that obtains embeddings from the
 * {@link GPTChatCompletionService} and optionally stores / looks them up in an {@link EmbeddingsCache}.
 * Cached embeddings are serialized as Base64 encoded float arrays, see {@link #encodeFloatArray(float[])}.
 */
@Component(service = GPTEmbeddingService.class)
public class GPTEmbeddingServiceImpl implements GPTEmbeddingService {

    private static final Logger LOG = LoggerFactory.getLogger(GPTEmbeddingServiceImpl.class);

    /** Backend that computes the embeddings which are not available from the cache. */
    @Reference
    protected GPTChatCompletionService chatCompletionService;

    /**
     * Generates a map of the stored embeddings for texts.
     *
     * @param texts the texts to look up; null entries and duplicates are ignored
     * @param cache the cache to read from
     * @return a new mutable map from text to its decoded cached embedding; contains only texts for which
     * the cache delivered a decodable value, and is empty if {@code texts} or {@code cache} is null
     */
    @Nonnull
    Map<String, float[]> embeddingsMap(@Nullable List<String> texts, @Nullable EmbeddingsCache cache) {
        Map<String, float[]> result = new HashMap<>();
        if (cache != null && texts != null) {
            texts.stream()
                    .filter(Objects::nonNull)
                    .distinct()
                    .forEach(text -> {
                        String cachedEmbedding = cache.getCachedEmbedding(text);
                        float[] embedding = decodeFloatArray(cachedEmbedding);
                        if (embedding != null) {
                            result.put(text, embedding);
                        }
                    });
        }
        return result;
    }

    /**
     * Determines the embeddings for the given texts, using the cache where possible.
     *
     * @param texts         the texts to embed; null or empty yields an empty list without calling the backend
     * @param configuration passed through to the backend
     * @param cache         if given, embeddings are read from it and newly fetched ones are written to it
     * @return one embedding per entry of {@code texts}, in the same order. When a cache is used, null
     * entries of {@code texts} result in null entries in the returned list.
     * @throws GPTException if the backend does not return exactly one embedding per requested text
     */
    @Override
    public List<float[]> getEmbeddings(List<String> texts, @Nullable GPTConfiguration configuration, @Nullable EmbeddingsCache cache) throws GPTException {
        int textCount = texts == null ? 0 : texts.size();
        LOG.debug("Getting embeddings for {} texts", textCount);
        if (textCount == 0) {
            return new ArrayList<>();
        }
        if (cache == null) { // uncached
            return chatCompletionService.getEmbeddings(texts, configuration);
        }

        Map<String, float[]> cached = embeddingsMap(texts, cache);
        List<String> toFetch = texts.stream()
                .filter(Objects::nonNull)
                .filter(text -> !cached.containsKey(text))
                .distinct()
                .collect(Collectors.toList());

        // Only talk to the backend if there is actually something missing from the cache.
        if (!toFetch.isEmpty()) {
            List<float[]> fetched = chatCompletionService.getEmbeddings(toFetch, configuration);
            if (fetched == null || toFetch.size() != fetched.size()) {
                throw new GPTException("BUG: Expected " + toFetch.size() + " embeddings, got " + describeSize(fetched));
            }
            for (int i = 0; i < toFetch.size(); i++) {
                cache.putCachedEmbedding(toFetch.get(i), encodeFloatArray(fetched.get(i)));
                cached.put(toFetch.get(i), fetched.get(i));
            }
        }

        return texts.stream()
                .map(cached::get)
                .collect(Collectors.toList());
    }

    /**
     * Returns those of the {@code comparedStrings} whose embeddings are most similar to the embedding of
     * the query, ordered by descending cosine similarity.
     *
     * @param query           the text to compare against
     * @param comparedStrings the candidates; null entries are skipped and duplicates are reported only once
     * @param limit           maximum number of results
     * @param configuration   passed through to the backend
     * @param thecache        optional cache for the embeddings
     * @return at most {@code limit} candidates, most similar first; empty if there are no candidates
     * @throws GPTException if the embeddings for the candidates or the query could not be determined
     */
    @Override
    public List<String> findMostRelated(String query, List<String> comparedStrings, int limit, @Nullable GPTConfiguration configuration, @Nullable EmbeddingsCache thecache) throws GPTException {
        if (comparedStrings == null || comparedStrings.isEmpty()) {
            return new ArrayList<>();
        }
        List<float[]> embeddings = getEmbeddings(comparedStrings, configuration, thecache);
        if (embeddings == null || embeddings.size() != comparedStrings.size()) {
            throw new GPTException("BUG: Expected " + comparedStrings.size() + " embeddings, got " + describeSize(embeddings));
        }
        List<float[]> queryEmbeddings = getEmbeddings(Collections.singletonList(query), configuration, thecache);
        float[] queryEmbedding = queryEmbeddings == null || queryEmbeddings.isEmpty() ? null : queryEmbeddings.get(0);
        if (queryEmbedding == null) {
            throw new GPTException("BUG: No embedding available for the query");
        }

        // Walk by index so that each candidate is paired with its own embedding without an indexOf lookup.
        // Insertion order is kept so that the (stable) sort below leaves ties in input order.
        Map<String, Double> withSimilarity = new LinkedHashMap<>();
        for (int i = 0; i < comparedStrings.size(); i++) {
            String candidate = comparedStrings.get(i);
            float[] embedding = embeddings.get(i);
            if (candidate == null || embedding == null || withSimilarity.containsKey(candidate)) {
                continue;
            }
            withSimilarity.put(candidate, cosineSimilarity(queryEmbedding, embedding));
        }

        List<Map.Entry<String, Double>> entries = withSimilarity.entrySet().stream()
                .sorted((e1, e2) -> Double.compare(e2.getValue(), e1.getValue()))
                .collect(Collectors.toList());
        if (LOG.isDebugEnabled()) {
            LOG.debug("cosine similarities: {}", entries.stream().map(e -> e.getValue().floatValue()).collect(Collectors.toList()));
        }
        return entries.stream()
                .limit(limit)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }

    /**
     * Computes the cosine similarity of two vectors.
     *
     * @param a first vector
     * @param b second vector, must have the same length as {@code a}
     * @return the cosine of the angle between the vectors, or 0 if one of them has zero length / zero norm
     * (a NaN would otherwise be ranked as the best match by the descending sort in
     * {@link #findMostRelated})
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    protected double cosineSimilarity(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Embedding dimensions differ: " + a.length + " vs. " + b.length);
        }
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            // widen before multiplying so the products are not rounded to float precision
            double x = a[i];
            double y = b[i];
            dotProduct += x * y;
            normA += x * x;
            normB += y * y;
        }
        double denominator = Math.sqrt(normA) * Math.sqrt(normB);
        return denominator == 0.0 ? 0.0 : dotProduct / denominator;
    }

    /**
     * Serializes a float array as Base64 of the concatenated 4-byte representations of its values
     * (written with {@link ByteBuffer#putFloat(float)} in the buffer's default byte order).
     *
     * @param floatArray the values to encode
     * @return the Base64 string, or null if {@code floatArray} is null
     */
    protected static String encodeFloatArray(float[] floatArray) {
        if (floatArray == null) {
            return null;
        }
        ByteBuffer byteBuffer = ByteBuffer.allocate(floatArray.length * Float.BYTES);
        for (float value : floatArray) {
            byteBuffer.putFloat(value);
        }
        return Base64.getEncoder().encodeToString(byteBuffer.array());
    }

    /**
     * Inverse of {@link #encodeFloatArray(float[])}.
     *
     * @param encodedString the Base64 string to decode
     * @return the decoded values, or null if the input is null, empty, not valid Base64, or does not
     * contain a whole number of floats (all of which are treated like a cache miss by the callers)
     */
    protected static float[] decodeFloatArray(String encodedString) {
        if (encodedString == null || encodedString.isEmpty()) {
            return null;
        }
        try {
            byte[] byteArray = Base64.getDecoder().decode(encodedString);
            if (byteArray.length % Float.BYTES != 0) {
                // a truncated / corrupted entry would otherwise silently yield a vector of the wrong content
                LOG.warn("Could not decode float array: {} bytes is not a multiple of {}", byteArray.length, Float.BYTES);
                return null;
            }
            ByteBuffer byteBuffer = ByteBuffer.wrap(byteArray);
            float[] floatArray = new float[byteArray.length / Float.BYTES];
            for (int i = 0; i < floatArray.length; i++) {
                floatArray[i] = byteBuffer.getFloat();
            }
            return floatArray;
        } catch (IllegalArgumentException e) {
            LOG.warn("Could not decode float array from {}", encodedString.substring(0, Math.min(encodedString.length(), 80)), e);
            return null;
        }
    }

    /** Renders the size of a possibly null list for error messages. */
    private static String describeSize(@Nullable List<?> list) {
        return list == null ? "null" : String.valueOf(list.size());
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy