package fi.evolver.ai.spring.embedding;
import java.time.Duration;
import java.time.OffsetDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.codec.digest.DigestUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.embedding.entity.EmbeddingVector;
import fi.evolver.basics.spring.lock.LockException;
import fi.evolver.basics.spring.lock.LockHandle;
import fi.evolver.basics.spring.lock.LockService;
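/**
* Calculates and caches embedding vectors for texts, keyed by the SHA-256 hash of the text. Lookups are
* served from an in-memory cache, falling back to the database, and only texts missing from both are sent
* to the embedding API. {@link #persist} and {@link #clearStale} keep the database copy up to date.
*
* <p>A minimal usage sketch (the collaborators, timeout and provider value are illustrative and would
* normally come from the surrounding Spring configuration):
* <pre>{@code
* EmbeddingVectors vectors = new EmbeddingVectors(
*         embeddingVectorApi, embeddingVectorRepository, lockService,
*         model, Duration.ofSeconds(30), "example-provider");
* double[] embedding = vectors.getEmbedding("Hello world");
* Map<String, double[]> embeddings = vectors.getEmbeddings(List.of("foo", "bar"));
* }</pre>
*/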
public class EmbeddingVectors {
private static final Logger LOG = LoggerFactory.getLogger(EmbeddingVectors.class);
private static final String UPDATE_LOCK_NAME = EmbeddingVectors.class.getSimpleName() + "_UpdateLock";
private static final Duration UPDATE_LOCK_VALIDITY = Duration.ofMinutes(1);
private static final Duration UPDATE_LOCK_TRY_FOR = Duration.ofMinutes(2);
private final EmbeddingVectorApi embeddingVectorApi;
private final EmbeddingVectorRepository embeddingVectorRepository;
private final LockService lockService;
private final Model model;
private final Map<String, EmbeddingVector> memCache;
private final Map<String, OffsetDateTime> persistedTimestampsByHash;
private final Duration timeout;
private final String provider;
public EmbeddingVectors(EmbeddingVectorApi embeddingVectorApi, EmbeddingVectorRepository embeddingVectorRepository, LockService lockService, Model model, Duration timeout, String provider) {
this.embeddingVectorApi = embeddingVectorApi;
this.embeddingVectorRepository = embeddingVectorRepository;
this.lockService = lockService;
this.model = model;
this.timeout = timeout;
this.memCache = new ConcurrentHashMap<>();
this.persistedTimestampsByHash = new ConcurrentHashMap<>();
this.provider = provider;
}
/**
* Persists vectors from memory to the database. This serves two purposes:
* 1) persists calculated embeddings so they can be reused by other instances and after a restart
* 2) refreshes the lastAccessed timestamps that are used for cleaning up stale items from the database. The update
* interval should be considerably shorter than the duration after which non-accessed items are considered stale,
* to avoid needless recalculation.
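*
* <p>A sketch of how this might be driven, e.g. from a scheduled task (the interval and thresholds below
* are illustrative, not part of this class):
* <pre>{@code
* // run every few minutes, e.g. via Spring's @Scheduled
* embeddingVectors.persist(Duration.ofHours(1));
* embeddingVectors.clearStale(OffsetDateTime.now().minusDays(30));
* }</pre>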
* @param timestampRefreshThreshold already persisted vectors are re-persisted only if they have been accessed more than timestampRefreshThreshold after their previously persisted timestamp; use Duration.ZERO for an unconditional refresh
* @return true if success, false if exclusive access could not be acquired
* @see #clearStale
*/
public boolean persist(Duration timestampRefreshThreshold) {
try (LockHandle handle = lockService.takeLock(UPDATE_LOCK_NAME, UPDATE_LOCK_VALIDITY, UPDATE_LOCK_TRY_FOR)) {
List<EmbeddingVector> vectorsToAdd = new ArrayList<>();
List<EmbeddingVector> vectorsToUpdate = new ArrayList<>();
for (EmbeddingVector memVector : memCache.values()) {
OffsetDateTime persistedTimestamp = persistedTimestampsByHash.get(memVector.getHash());
if (persistedTimestamp == null)
vectorsToAdd.add(memVector);
else if (persistedTimestamp.plus(timestampRefreshThreshold).isBefore(memVector.getLastAccessed())) {
vectorsToUpdate.add(memVector);
}
}
embeddingVectorRepository.persistChanges(vectorsToAdd, vectorsToUpdate);
vectorsToAdd.forEach(v -> persistedTimestampsByHash.put(v.getHash(), v.getLastAccessed()));
vectorsToUpdate.forEach(v -> persistedTimestampsByHash.put(v.getHash(), v.getLastAccessed()));
return true;
} catch (LockException e) {
LOG.info("Unable to get lock for embedding vector update");
return false;
}
}
/**
* Deletes from the database and the in-memory cache items whose lastAccessed timestamp is before stalenessCutoff.
* Depends on {@link #persist} being called frequently enough to keep the timestamps up to date.
* @param stalenessCutoff the cut-off date to decide which items are considered stale
* @return true if success, false if exclusive access could not be acquired
*/
public boolean clearStale(OffsetDateTime stalenessCutoff) {
try (LockHandle handle = lockService.takeLock(UPDATE_LOCK_NAME, UPDATE_LOCK_VALIDITY, UPDATE_LOCK_TRY_FOR)) {
embeddingVectorRepository.deleteStaleData(stalenessCutoff);
memCache.entrySet().removeIf(e ->
e.getValue().getLastAccessed().isBefore(stalenessCutoff));
return true;
} catch (LockException e) {
LOG.info("Unable to get lock for embedding vector clean-up");
return false;
}
}
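/**
* Fetches the embedding vector for a single text, calculating and caching it if necessary.
*/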
public double[] getEmbedding(String text) {
return getEmbeddings(List.of(text)).get(text);
}
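/**
* Fetches embedding vectors for the given texts. Vectors already known are served from the in-memory cache
* or the database; only the remaining texts are sent to the embedding API.
* @return the embedding vectors keyed by the input texts, preserving their iteration order
*/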
public Map<String, double[]> getEmbeddings(Collection<String> texts) {
Map<String, double[]> vectorsByText = new LinkedHashMap<>();
for (String text : texts)
vectorsByText.put(text, getCachedVectorOrNull(calculateHash(text)));
List<String> missingTexts = vectorsByText.entrySet().stream()
.filter(e -> e.getValue() == null)
.map(Map.Entry::getKey)
.toList();
if (!missingTexts.isEmpty())
vectorsByText.putAll(createNewEntities(missingTexts));
return vectorsByText;
}
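/**
* Requests embedding vectors for the given texts from the embedding API and adds them to the in-memory
* cache. Relies on the API returning one vector per input text, in the same order.
*/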
private Map<String, double[]> createNewEntities(List<String> missingTexts) {
List<double[]> newVectors = embeddingVectorApi.createEmbeddingVectorsInBatches(provider, model, missingTexts, timeout);
Map<String, double[]> results = new HashMap<>();
for (int i = 0; i < missingTexts.size(); i++) {
String hash = calculateHash(missingTexts.get(i));
EmbeddingVector newEntity = new EmbeddingVector(model, hash, newVectors.get(i));
memCache.put(hash, newEntity);
results.put(missingTexts.get(i), newEntity.getVector());
}
return results;
}
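// Texts are cached and persisted by their SHA-256 hex digest rather than by the raw text.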
private static String calculateHash(String x) {
return DigestUtils.sha256Hex(x);
}
private double[] getCachedVectorOrNull(String hash) {
return getCachedVector(hash)
.map(EmbeddingVector::getVector)
.orElse(null);
}
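/**
* Looks up a vector from the in-memory cache, falling back to the database, and refreshes its lastAccessed
* timestamp on every hit.
*/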
private Optional<EmbeddingVector> getCachedVector(String hash) {
Optional<EmbeddingVector> result = Optional.ofNullable(
memCache.computeIfAbsent(hash, this::findPersistedVector)
);
result.ifPresent(x -> x.setLastAccessed(OffsetDateTime.now()));
return result;
}
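/**
* Loads a previously persisted vector from the database, recording its persisted timestamp, or returns null
* if none exists for this model and hash.
*/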
private EmbeddingVector findPersistedVector(String hash) {
Optional<EmbeddingVector> entity = embeddingVectorRepository.findByModelAndHash(model.name(), hash);
entity.ifPresent(e -> persistedTimestampsByHash.put(e.getHash(), e.getLastAccessed()));
return entity.orElse(null);
}
}