dev.langchain4j.store.embedding.redis.RedisEmbeddingStore Maven / Gradle / Ivy
The newest version!
package dev.langchain4j.store.embedding.redis;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.json.Path2;
import redis.clients.jedis.search.Document;
import redis.clients.jedis.search.FTCreateParams;
import redis.clients.jedis.search.IndexDataType;
import redis.clients.jedis.search.Query;
import redis.clients.jedis.search.SearchResult;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.Utils.randomUUID;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.internal.ValidationUtils.ensureTrue;
import static dev.langchain4j.store.embedding.redis.RedisSchema.SCORE_FIELD_NAME;
import static java.lang.String.format;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static redis.clients.jedis.search.RediSearchUtil.ToByteArray;
/**
* Represents a Redis index as an embedding store.
* Current implementation assumes the index uses the cosine distance metric.
*/
public class RedisEmbeddingStore implements EmbeddingStore {
// Class-wide SLF4J logger.
private static final Logger log = LoggerFactory.getLogger(RedisEmbeddingStore.class);
// Shared Jackson mapper; NOTE(review): its usages fall outside the visible part of this file.
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
// Pooled Jedis client through which the index and search commands (ftList, ftCreate, ftSearch) are issued.
private final JedisPooled client;
// Index description built in the constructor: index name, key prefix, vector dimension and metadata keys.
private final RedisSchema schema;
/**
* Creates an instance of RedisEmbeddingStore
*
* @param host Redis Stack Server host
* @param port Redis Stack Server port
* @param user Redis Stack username (optional)
* @param password Redis Stack password (optional)
* @param indexName The name of the index (optional). Default value: "embedding-index".
* @param prefix The prefix of the key, should end with a colon (e.g., "embedding:") (optional). Default value: "embedding:".
* @param dimension Embedding vector dimension
* @param metadataKeys Metadata keys that should be persisted (optional)
*/
public RedisEmbeddingStore(String host,
Integer port,
String user,
String password,
String indexName,
String prefix,
Integer dimension,
Collection metadataKeys) {
ensureNotBlank(host, "host");
ensureNotNull(port, "port");
ensureNotNull(dimension, "dimension");
this.client = user == null ? new JedisPooled(host, port) : new JedisPooled(host, port, user, password);
this.schema = RedisSchema.builder()
.indexName(getOrDefault(indexName, "embedding-index"))
.prefix(getOrDefault(prefix, "embedding:"))
.dimension(dimension)
.metadataKeys(metadataKeys)
.build();
if (!isIndexExist(schema.indexName())) {
createIndex(schema.indexName());
}
}
/**
 * Stores the embedding under an id obtained from {@code randomUUID()}.
 *
 * @param embedding the embedding to store
 * @return the id that was generated for the embedding
 */
@Override
public String add(Embedding embedding) {
    final String generatedId = randomUUID();
    add(generatedId, embedding);
    return generatedId;
}
/**
 * Stores the embedding under the caller-supplied id, without an attached text segment.
 *
 * @param id        the id to store the embedding under
 * @param embedding the embedding to store
 */
@Override
public void add(String id, Embedding embedding) {
    // Single-element batch; the null third argument means "no text segments".
    addAllInternal(singletonList(id), singletonList(embedding), null);
}
/**
 * Stores the embedding together with its text segment under an id obtained from {@code randomUUID()}.
 *
 * @param embedding   the embedding to store
 * @param textSegment the segment associated with the embedding
 * @return the id that was generated for the embedding
 */
@Override
public String add(Embedding embedding, TextSegment textSegment) {
    final String generatedId = randomUUID();
    addInternal(generatedId, embedding, textSegment);
    return generatedId;
}
/**
 * Stores all given embeddings, each under its own id obtained from {@code randomUUID()}.
 *
 * @param embeddings the embeddings to store
 * @return the generated ids, in the same order as {@code embeddings}
 */
@Override
public List addAll(List embeddings) {
    List<String> generatedIds = new ArrayList<>();
    // One fresh id per embedding; the element itself is not needed here.
    for (Object ignored : embeddings) {
        generatedIds.add(randomUUID());
    }
    addAllInternal(generatedIds, embeddings, null);
    return generatedIds;
}
/**
 * Stores all given embeddings together with their text segments, each under its own id
 * obtained from {@code randomUUID()}.
 *
 * @param embeddings the embeddings to store
 * @param embedded   the text segments belonging to the embeddings
 * @return the generated ids, in the same order as {@code embeddings}
 */
@Override
public List addAll(List embeddings, List embedded) {
    List<String> generatedIds = new ArrayList<>();
    // One fresh id per embedding; the element itself is not needed here.
    for (Object ignored : embeddings) {
        generatedIds.add(randomUUID());
    }
    addAllInternal(generatedIds, embeddings, embedded);
    return generatedIds;
}
@Override
public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
// Using KNN query on @vector field
String queryTemplate = "*=>[ KNN %d @%s $BLOB AS %s ]";
List returnFields = new ArrayList<>(schema.metadataKeys());
returnFields.addAll(asList(schema.vectorFieldName(), schema.scalarFieldName(), SCORE_FIELD_NAME));
Query query = new Query(format(queryTemplate, maxResults, schema.vectorFieldName(), SCORE_FIELD_NAME))
.addParam("BLOB", ToByteArray(referenceEmbedding.vector()))
.returnFields(returnFields.toArray(new String[0]))
.setSortBy(SCORE_FIELD_NAME, true)
.dialect(2);
SearchResult result = client.ftSearch(schema.indexName(), query);
List documents = result.getDocuments();
return toEmbeddingMatch(documents, minScore);
}
/**
 * Creates a JSON index with the given name, restricted to keys starting with the schema prefix
 * and using the fields supplied by the schema.
 *
 * @param indexName name of the index to create
 * @throws RedisRequestFailedException if the reply to the create command is anything other than "OK"
 */
private void createIndex(String indexName) {
    FTCreateParams indexParams = FTCreateParams.createParams()
            .on(IndexDataType.JSON)
            .addPrefix(schema.prefix());
    String reply = client.ftCreate(indexName, indexParams, schema.toSchemaFields());

    if ("OK".equals(reply)) {
        return;
    }

    // Any other reply is treated as a failure: log it, then surface it to the caller.
    if (log.isErrorEnabled()) {
        log.error("create index error, msg={}", reply);
    }
    throw new RedisRequestFailedException("create index error, msg=" + reply);
}
/**
 * Checks whether the index list reported by the client contains the given name.
 *
 * @param indexName name of the index to look for
 * @return {@code true} if the name is among the listed indexes
 */
private boolean isIndexExist(String indexName) {
    return client.ftList().contains(indexName);
}
/**
 * Stores a single embedding by delegating to the batch method with one-element lists.
 *
 * @param id        id to store the embedding under
 * @param embedding the embedding to store
 * @param embedded  the associated text segment, or {@code null} if there is none
 */
private void addInternal(String id, Embedding embedding, TextSegment embedded) {
    // A missing segment is forwarded as a null list rather than a list holding null.
    List<TextSegment> segments = null;
    if (embedded != null) {
        segments = singletonList(embedded);
    }
    addAllInternal(singletonList(id), singletonList(embedding), segments);
}
private void addAllInternal(List ids, List embeddings, List embedded) {
if (isNullOrEmpty(ids) || isNullOrEmpty(embeddings)) {
log.info("do not add empty embeddings to redis");
return;
}
ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");
List
© 2015 - 2024 Weber Informatics LLC | Privacy Policy