fi.evolver.ai.spring.embedding.EmbeddingCache Maven / Gradle / Ivy
package fi.evolver.ai.spring.embedding;
import java.time.ZonedDateTime;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.embedding.model.EmbeddingData;
/**
 * In-memory index over a list of {@link EmbeddingCacheItem}s associated with a {@link Model}.
 *
 * <p>Items are indexed by their data identifier (for metadata and change detection) and
 * their vectors are indexed by item id (for similarity search).
 */
public class EmbeddingCache {
    private final Model model;
    private final Map<String, EmbeddingCacheItem> itemsByIdentifier;
    private final Map<Long, double[]> vectorsById;

    /**
     * Builds the lookup maps from the given items.
     *
     * @param model the model returned by {@link #getModel()}
     * @param items the cached items to index
     * @throws IllegalStateException if two items share the same data identifier or the same id
     *         ({@link Collectors#toMap} without a merge function rejects duplicate keys)
     */
    public EmbeddingCache(Model model, List<EmbeddingCacheItem> items) {
        this.model = model;
        itemsByIdentifier = items.stream().collect(Collectors.toMap(
                EmbeddingCacheItem::dataIdentifier,
                Function.identity()));
        vectorsById = items.stream().collect(Collectors.toMap(
                EmbeddingCacheItem::id,
                EmbeddingCacheItem::vector));
    }

    /**
     * @return the model this cache was created with
     */
    public Model getModel() {
        return model;
    }

    /**
     * Looks up the creation time of the cached item with the given data identifier.
     *
     * @param dataIdentifier the identifier of the cached data
     * @return the creation time, or empty if there is no such item or it has no creation time
     */
    public Optional<ZonedDateTime> getCreationTime(String dataIdentifier) {
        return Optional.ofNullable(itemsByIdentifier.get(dataIdentifier))
                .map(EmbeddingCacheItem::creationTime);
    }

    /**
     * Checks whether the given value differs from what is cached for the identifier, by
     * comparing the hash of the new value with the hash stored in the cached item.
     *
     * @param dataIdentifier the identifier of the cached data
     * @param newValue the value to compare against the cached hash
     * @return {@code false} only when an item is cached for the identifier and its hash equals
     *         the hash of {@code newValue}; {@code true} otherwise (including when nothing is
     *         cached for the identifier)
     */
    public boolean hasChanged(String dataIdentifier, String newValue) {
        String newHash = EmbeddingService.calculateHash(newValue);
        // A missing item (or a missing hash) maps to "not equal", which negates to "changed".
        return !Optional.ofNullable(itemsByIdentifier.get(dataIdentifier))
                .map(EmbeddingCacheItem::hash)
                .map(newHash::equals)
                .orElse(false);
    }

    /**
     * Finds the ids of the cached vectors most similar to the given embedding.
     *
     * <p>Every cached vector is scored individually, so vectors that happen to have the same
     * similarity are all retained; the relative order among equally similar ids is unspecified.
     *
     * @param input the embedding to compare the cached vectors against
     * @param maxCount the maximum number of ids to return
     * @return at most {@code maxCount} ids, ordered from most to least similar
     * @throws IllegalArgumentException if {@code maxCount} is negative or a cached vector has a
     *         different length than the input embedding
     */
    List<Long> findClosestMatches(EmbeddingData input, int maxCount) {
        // Pairs an id with its score; a map keyed by score would drop ids with equal scores.
        record ScoredId(Long id, double similarity) {}

        double[] query = input.embedding();
        return vectorsById.entrySet().stream()
                .map(entry -> new ScoredId(entry.getKey(), cosineSimilarity(query, entry.getValue())))
                .sorted((left, right) -> Double.compare(right.similarity(), left.similarity()))
                .limit(maxCount)
                .map(ScoredId::id)
                .toList();
    }

    /**
     * Calculates cosine similarity between two embeddings.
     *
     * @param a the first vector
     * @param b the second vector
     * @return similarity score between -1.0 and 1.0 where 1.0 means the vectors point in the
     *         same direction; 0.0 if either vector has zero magnitude
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    private static double cosineSimilarity(double[] a, double[] b) {
        if (a.length != b.length)
            throw new IllegalArgumentException("Vectors must be of same length");

        double dotProduct = 0;
        double aSquaredLength = 0;
        double bSquaredLength = 0;
        for (int i = 0; i < a.length; i++) {
            dotProduct += a[i] * b[i];
            aSquaredLength += a[i] * a[i];
            bSquaredLength += b[i] * b[i];
        }

        double denominator = Math.sqrt(aSquaredLength) * Math.sqrt(bSquaredLength);
        // A zero-magnitude vector has no direction; dividing would yield NaN, which sorts
        // ahead of every real score and would be reported as the best match.
        if (denominator == 0)
            return 0.0;
        return dotProduct / denominator;
    }

    /**
     * A single cached embedding.
     *
     * @param id the item id used as the key for vector lookups
     * @param dataIdentifier the identifier of the data the embedding was created for
     * @param hash the hash compared against {@code EmbeddingService.calculateHash} results
     * @param vector the embedding vector
     * @param creationTime the creation time of the item
     */
    public record EmbeddingCacheItem(
            Long id,
            String dataIdentifier,
            String hash,
            double[] vector,
            ZonedDateTime creationTime) {}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy