All downloads are free. Search and download functionality uses the official Maven repository.

fi.evolver.ai.spring.embedding.EmbeddingVectors Maven / Gradle / Ivy

package fi.evolver.ai.spring.embedding;

import java.time.Duration;
import java.time.OffsetDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.commons.codec.digest.DigestUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.embedding.entity.EmbeddingVector;
import fi.evolver.basics.spring.lock.LockException;
import fi.evolver.basics.spring.lock.LockHandle;
import fi.evolver.basics.spring.lock.LockService;

/**
 * Caches embedding vectors of a single {@link Model}, keyed by the SHA-256 hex digest of the embedded text.
 * <p>
 * Lookups consult an in-memory cache first, then the database through {@link EmbeddingVectorRepository}; only texts
 * found in neither are sent to {@link EmbeddingVectorApi}. {@link #persist} writes new and recently accessed vectors
 * back to the database and {@link #clearStale} removes vectors that have not been accessed recently. Both maintenance
 * methods run while holding the named lock {@code EmbeddingVectors_UpdateLock} obtained from {@link LockService}.
 */
public class EmbeddingVectors {
	private static final Logger LOG = LoggerFactory.getLogger(EmbeddingVectors.class);

	private static final String UPDATE_LOCK_NAME = EmbeddingVectors.class.getSimpleName() + "_UpdateLock";
	private static final Duration UPDATE_LOCK_VALIDITY = Duration.ofMinutes(1);
	private static final Duration UPDATE_LOCK_TRY_FOR = Duration.ofMinutes(2);


	private final EmbeddingVectorApi embeddingVectorApi;
	private final EmbeddingVectorRepository embeddingVectorRepository;
	private final LockService lockService;
	private final Model model;
	/** Vectors known to this instance, keyed by text hash. */
	private final Map<String, EmbeddingVector> memCache;

	/**
	 * For each hash this instance has loaded from or written to the database, the lastAccessed timestamp it last
	 * read or wrote. A missing key means the vector has not been persisted by (or loaded into) this instance.
	 */
	private final Map<String, OffsetDateTime> persistedTimestampsByHash;
	private final Duration timeout;
	private final String provider;

	/**
	 * @param embeddingVectorApi API used to calculate vectors that are not cached
	 * @param embeddingVectorRepository database storage for vectors
	 * @param lockService provides the lock guarding {@link #persist} and {@link #clearStale}
	 * @param model the embedding model whose vectors this instance manages
	 * @param timeout passed to the embedding API when calculating new vectors
	 * @param provider passed to the embedding API when calculating new vectors
	 */
	public EmbeddingVectors(EmbeddingVectorApi embeddingVectorApi, EmbeddingVectorRepository embeddingVectorRepository, LockService lockService, Model model, Duration timeout, String provider) {
		this.embeddingVectorApi = embeddingVectorApi;
		this.embeddingVectorRepository = embeddingVectorRepository;
		this.lockService = lockService;
		this.model = model;
		this.timeout = timeout;
		this.memCache = new ConcurrentHashMap<>();
		this.persistedTimestampsByHash = new ConcurrentHashMap<>();
		this.provider = provider;
	}

	/**
	 * Persists vectors from memory to the database. This serves two purposes:
	 * 1) persists calculated embeddings so they can be used on other instances and after a restart
	 * 2) updates the lastAccessed timestamps that are used for cleaning up stale items from the database. The interval
	 * for calling this should be considerably shorter than the duration after which non-accessed items are considered
	 * stale, to avoid needless recalculation.
	 * <p>
	 * Vectors this instance has never persisted or loaded are handed to the repository as new items. Already persisted
	 * vectors are handed over as updates only if their lastAccessed timestamp has advanced by more than
	 * {@code timestampRefreshThreshold} beyond the timestamp last recorded for them.
	 * @param timestampRefreshThreshold minimum advance of lastAccessed before an already persisted vector is rewritten.
	 *        Use {@link Duration#ZERO} to rewrite every vector accessed since it was last persisted.
	 * @return true if success, false if exclusive access could not be acquired
	 * @see #clearStale
	 */
	public boolean persist(Duration timestampRefreshThreshold) {
		try (LockHandle handle = lockService.takeLock(UPDATE_LOCK_NAME, UPDATE_LOCK_VALIDITY, UPDATE_LOCK_TRY_FOR)) {
			List<EmbeddingVector> vectorsToAdd = new ArrayList<>();
			List<EmbeddingVector> vectorsToUpdate = new ArrayList<>();
			// Snapshot lastAccessed before writing: getEmbeddings() may bump it concurrently, and recording a value
			// newer than the one actually written would postpone the next refresh of that row.
			Map<String, OffsetDateTime> timestampsBeingPersisted = new HashMap<>();

			for (EmbeddingVector memVector : memCache.values()) {
				String hash = memVector.getHash();
				OffsetDateTime lastAccessed = memVector.getLastAccessed();
				OffsetDateTime persistedTimestamp = persistedTimestampsByHash.get(hash);
				if (persistedTimestamp == null)
					vectorsToAdd.add(memVector);
				else if (persistedTimestamp.plus(timestampRefreshThreshold).isBefore(lastAccessed))
					vectorsToUpdate.add(memVector);
				else
					continue; // persisted recently enough, nothing to write
				timestampsBeingPersisted.put(hash, lastAccessed);
			}

			embeddingVectorRepository.persistChanges(vectorsToAdd, vectorsToUpdate);
			persistedTimestampsByHash.putAll(timestampsBeingPersisted);

			return true;
		} catch (LockException e) {
			LOG.info("Unable to get lock for embedding vector update");
			return false;
		}
	}

	/**
	 * Deletes stale items from the database via {@link EmbeddingVectorRepository#deleteStaleData}, intended to remove
	 * items whose lastAccessed timestamp is before stalenessCutoff. Also evicts such vectors from memory. Depends on
	 * {@link #persist} to update the database timestamps frequently enough.
	 * @param stalenessCutoff the cut-off date to decide which items are considered stale
	 * @return true if success, false if exclusive access could not be acquired
	 */
	public boolean clearStale(OffsetDateTime stalenessCutoff) {
		try (LockHandle handle = lockService.takeLock(UPDATE_LOCK_NAME, UPDATE_LOCK_VALIDITY, UPDATE_LOCK_TRY_FOR)) {
			embeddingVectorRepository.deleteStaleData(stalenessCutoff);
			for (Map.Entry<String, EmbeddingVector> entry : memCache.entrySet()) {
				String hash = entry.getKey();
				EmbeddingVector vector = entry.getValue();
				// Conditional remove: leave a vector alone if it was replaced concurrently.
				if (vector.getLastAccessed().isBefore(stalenessCutoff) && memCache.remove(hash, vector)) {
					// Forget the persisted timestamp as well. Otherwise the map grows without bound, and if the row
					// was just deleted, a re-created vector would be sent to the repository as an update instead of
					// as a new item.
					persistedTimestampsByHash.remove(hash);
				}
			}
			return true;
		} catch (LockException e) {
			LOG.info("Unable to get lock for embedding vector clean-up");
			return false;
		}
	}

	/**
	 * Returns the embedding vector of a single text, calculating it through the API if it is not cached.
	 * @param text the text to embed
	 * @return the embedding vector
	 */
	public double[] getEmbedding(String text) {
		return getEmbeddings(List.of(text)).get(text);
	}

	/**
	 * Returns embedding vectors for the given texts. Cached vectors (memory, then database) are reused. All missing
	 * texts are calculated through the API in one call and added to the in-memory cache.
	 * @param texts the texts to embed
	 * @return vectors keyed by text, in the iteration order of {@code texts} (duplicates collapsed)
	 */
	public Map<String, double[]> getEmbeddings(Collection<String> texts) {
		Map<String, double[]> vectorsByText = new LinkedHashMap<>();
		for (String text : texts)
			vectorsByText.put(text, getCachedVectorOrNull(calculateHash(text)));

		List<String> missingTexts = vectorsByText.entrySet().stream()
				.filter(e -> e.getValue() == null)
				.map(Map.Entry::getKey)
				.toList();

		if (!missingTexts.isEmpty())
			vectorsByText.putAll(createNewEntities(missingTexts));

		return vectorsByText;
	}

	/**
	 * Calculates vectors for the given texts through the API and stores them in the in-memory cache. They reach the
	 * database on the next {@link #persist}.
	 * @param missingTexts texts without a cached vector
	 * @return the new vectors keyed by text
	 * @throws IllegalStateException if the API does not return exactly one vector per text
	 */
	private Map<String, double[]> createNewEntities(List<String> missingTexts) {
		var newVectors = embeddingVectorApi.createEmbeddingVectorsInBatches(provider, model, missingTexts, timeout);
		// Vectors are matched to texts by position, so a count mismatch would cache wrong vectors.
		if (newVectors.size() != missingTexts.size())
			throw new IllegalStateException("Embedding API returned %d vectors for %d texts".formatted(newVectors.size(), missingTexts.size()));

		Map<String, double[]> results = new HashMap<>();
		for (int i = 0; i < missingTexts.size(); i++) {
			String hash = calculateHash(missingTexts.get(i));
			EmbeddingVector newEntity = new EmbeddingVector(model, hash, newVectors.get(i));
			memCache.put(hash, newEntity);
			results.put(missingTexts.get(i), newEntity.getVector());
		}
		return results;
	}

	/** Returns the lowercase hex SHA-256 digest used as the cache key of a text. */
	private static String calculateHash(String x) {
		return DigestUtils.sha256Hex(x);
	}

	/** Returns the cached vector for the hash, or null if neither memory nor the database has it. */
	private double[] getCachedVectorOrNull(String hash) {
		return getCachedVector(hash)
				.map(EmbeddingVector::getVector)
				.orElse(null);
	}

	/**
	 * Looks up the vector in memory, falling back to the database, and marks it accessed now.
	 * NOTE(review): the database lookup runs inside ConcurrentHashMap.computeIfAbsent, which blocks other updates to
	 * the same map bin while it runs.
	 */
	private Optional<EmbeddingVector> getCachedVector(String hash) {
		Optional<EmbeddingVector> result = Optional.ofNullable(
				memCache.computeIfAbsent(hash, this::findPersistedVector)
		);
		result.ifPresent(x -> x.setLastAccessed(OffsetDateTime.now()));
		return result;
	}

	/** Loads the vector from the database and records its stored timestamp; returns null if it is not there. */
	private EmbeddingVector findPersistedVector(String hash) {
		Optional<EmbeddingVector> entity = embeddingVectorRepository.findByModelAndHash(model.name(), hash);
		entity.ifPresent(e -> persistedTimestampsByHash.put(e.getHash(), e.getLastAccessed()));

		return entity.orElse(null);
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy