All downloads are free. The search and download functionality uses the official Maven repository.

dev.langchain4j.store.embedding.neo4j.Neo4jEmbeddingUtils Maven / Gradle / Ivy

There is a newer version: 0.36.2
Show newest version
package dev.langchain4j.store.embedding.neo4j;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import org.neo4j.driver.Record;
import org.neo4j.driver.Value;
import org.neo4j.driver.Values;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static org.neo4j.cypherdsl.support.schema_name.SchemaNames.sanitize;

class Neo4jEmbeddingUtils {
    
    /* not-configurable strings, just used under-the-hood in `UNWIND $rows ...` statement */
    public static final String EMBEDDINGS_ROW_KEY = "embeddingRow";
    
    /* default configs */
    public static final String DEFAULT_ID_PROP = "id";
    public static final String DEFAULT_DATABASE_NAME = "neo4j";
    public static final String DEFAULT_EMBEDDING_PROP = "embedding";
    public static final String PROPS = "props";
    public static final String DEFAULT_IDX_NAME = "vector";
    public static final String DEFAULT_LABEL = "Document";
    public static final String DEFAULT_TEXT_PROP = "text";
    public static final long DEFAULT_AWAIT_INDEX_TIMEOUT = 60L;

    public static EmbeddingMatch toEmbeddingMatch(Neo4jEmbeddingStore store, Record neo4jRecord) {
        Map metaData = new HashMap<>();
        neo4jRecord.get("metadata").asMap().forEach((key, value) -> {
            if (!store.getNotMetaKeys().contains(key)) {
                String stringValue = value == null ? null : value.toString();
                metaData.put(key.replace(store.getMetadataPrefix(), ""), stringValue);
            }
        });

        Metadata metadata = new Metadata(metaData);

        Value text = neo4jRecord.get(store.getTextProperty());
        TextSegment textSegment = text.isNull()
                ? null
                : TextSegment.from(text.asString(), metadata);

        List embeddingList = neo4jRecord.get(store.getEmbeddingProperty())
                .asList(Value::asFloat);

        Embedding embedding = Embedding.from(embeddingList);

        return new EmbeddingMatch<>(neo4jRecord.get("score").asDouble(),
                neo4jRecord.get(store.getIdProperty()).asString(),
                embedding,
                textSegment);
    }

    public static Map toRecord(Neo4jEmbeddingStore store, int idx, List ids, List embeddings, List embedded) {
        String id = ids.get(idx);
        Embedding embedding = embeddings.get(idx);

        Map row = new HashMap<>();
        row.put(store.getIdProperty(), id);

        Map properties = new HashMap<>();
        if (embedded != null) {
            TextSegment segment = embedded.get(idx);
            properties.put(store.getTextProperty(), segment.text());
            Map metadata = segment.metadata().asMap();
            metadata.forEach((k, v) -> properties.put(store.getMetadataPrefix() + k, Values.value(v)));
        }

        row.put(EMBEDDINGS_ROW_KEY, Values.value(embedding.vector()));
        row.put(PROPS, properties);
        return row;
    }

    public static Stream>> getRowsBatched(Neo4jEmbeddingStore store, List ids, List embeddings, List embedded) {
        int batchSize = 10_000;
        AtomicInteger batchCounter = new AtomicInteger();
        int total = ids.size();
        int batchNumber = (int) Math.ceil( (double) total / batchSize );
        return IntStream.range(0, batchNumber)
                .mapToObj(part -> {
                    List> maps = ids.subList(Math.min(part * batchSize, total), Math.min((part + 1) * batchSize, total))
                                    .stream()
                                    .map(i -> toRecord(store, batchCounter.getAndIncrement(), ids, embeddings, embedded))
                                    .toList();
                            return maps;
                        }
                );
    }

    public static String sanitizeOrThrows(String value, String config) {
        return sanitize(value)
                .orElseThrow(() -> {
                    String invalidSanitizeValue = String.format("The value %s, to assign to configuration %s, cannot be safely quoted",
                            value,
                            config);
                    return new RuntimeException(invalidSanitizeValue);
                });
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy