dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of langchain4j-weaviate Show documentation
Show all versions of langchain4j-weaviate Show documentation
Uses io.weaviate.client library which has a BSD 3-Clause license:
https://github.com/weaviate/java-client/blob/main/LICENSE
package dev.langchain4j.store.embedding.weaviate;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateAuthClient;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.base.WeaviateErrorMessage;
import io.weaviate.client.v1.auth.exception.AuthException;
import io.weaviate.client.v1.data.model.WeaviateObject;
import io.weaviate.client.v1.filters.Operator;
import io.weaviate.client.v1.filters.WhereFilter;
import io.weaviate.client.v1.graphql.model.GraphQLError;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
import io.weaviate.client.v1.graphql.query.fields.Field;
import lombok.Builder;
import org.apache.commons.lang3.ArrayUtils;
import java.util.*;
import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static io.weaviate.client.v1.data.replication.model.ConsistencyLevel.QUORUM;
import static java.util.Arrays.stream;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
/**
* Represents the Weaviate vector database.
* Current implementation assumes the cosine distance metric is used.
*/
public class WeaviateEmbeddingStore implements EmbeddingStore {
// GraphQL field name requested for the stored text of a segment (see findRelevant).
private static final String METADATA_TEXT_SEGMENT = "text";
// GraphQL field under which findRelevant requests the sub-fields "id", "certainty" and "vector".
private static final String ADDITIONALS = "_additional";
// GraphQL field under which findRelevant requests one sub-field per configured metadata key.
private static final String METADATA = "_metadata";
// NOTE(review): presumably a placeholder written for absent metadata values; its usage is not visible here - confirm.
private static final String NULL_VALUE = "";
// Weaviate client created in the constructor; API-key authenticated when an API key was supplied.
private final WeaviateClient client;
// Weaviate class name used for queries and batch deletes; the constructor defaults it to "Default".
private final String objectClass;
// Duplicate-avoidance flag; the constructor defaults it to true.
private final boolean avoidDups;
// Consistency level; the constructor defaults it to QUORUM.
private final String consistencyLevel;
// Metadata keys requested under METADATA in findRelevant; the constructor defaults it to an empty list.
private final Collection metadataKeys;
/**
* Creates a new WeaviateEmbeddingStore instance.
*
* @param apiKey Your Weaviate API key. Not required for local deployment.
* @param scheme The scheme, e.g. "https" of cluster URL. Find in under Details of your Weaviate cluster.
* @param host The host, e.g. "langchain4j-4jw7ufd9.weaviate.network" of cluster URL.
* Find in under Details of your Weaviate cluster.
* @param port The port, e.g. 8080. This parameter is optional.
* @param objectClass The object class you want to store, e.g. "MyGreatClass". Must start from an uppercase letter.
* @param avoidDups If true (default), then WeaviateEmbeddingStore
will generate a hashed ID based on
* provided text segment, which avoids duplicated entries in DB.
* If false, then random ID will be generated.
* @param consistencyLevel Consistency level: ONE, QUORUM (default) or ALL. Find more details here.
* @param metadataKeys Metadata keys that should be persisted (optional)
* @param useGrpcForInserts Use GRPC instead of HTTP for batch inserts only. You still need HTTP configured for search
* @param securedGrpc The GRPC connection is secured
* @param grpcPort The port, e.g. 50051. This parameter is optional.
*/
@Builder
public WeaviateEmbeddingStore(
String apiKey,
String scheme,
String host,
Integer port,
Boolean useGrpcForInserts,
Boolean securedGrpc,
Integer grpcPort,
String objectClass,
Boolean avoidDups,
String consistencyLevel,
Collection metadataKeys
) {
try {
Config config = new Config(
ensureNotBlank(scheme, "scheme"),
concatenate(ensureNotBlank(host, "host"), port)
);
if (getOrDefault(useGrpcForInserts, Boolean.FALSE)) {
config.setGRPCSecured(getOrDefault(securedGrpc, Boolean.FALSE));
config.setGRPCHost(host + ":" + getOrDefault(grpcPort, 50051));
}
if (isNullOrBlank(apiKey)) {
this.client = new WeaviateClient(config);
} else {
this.client = WeaviateAuthClient.apiKey(config, apiKey);
}
} catch (AuthException e) {
throw new IllegalArgumentException(e);
}
this.objectClass = getOrDefault(objectClass, "Default");
this.avoidDups = getOrDefault(avoidDups, true);
this.consistencyLevel = getOrDefault(consistencyLevel, QUORUM);
this.metadataKeys = getOrDefault(metadataKeys, Collections.emptyList());
}
/**
 * Joins a host name and an optional port into a single address string.
 *
 * @param host the host name
 * @param port the port, or {@code null} when no port should be appended
 * @return {@code host} unchanged when {@code port} is {@code null}, otherwise {@code host:port}
 */
private static String concatenate(String host, Integer port) {
    return port == null ? host : host + ":" + port;
}
/**
 * Stores the given embedding under a freshly generated ID.
 *
 * @param embedding the embedding to store
 * @return the ID obtained from {@code randomUUID()} under which the embedding was added
 */
@Override
public String add(Embedding embedding) {
    final String generatedId = randomUUID();
    add(generatedId, embedding);
    return generatedId;
}
/**
 * Stores a single embedding under a caller-supplied ID, without an accompanying text segment.
 *
 * @param id        the ID to store the embedding under; it must be in UUID format,
 *                  because Weaviate requires object IDs to be UUIDs
 * @param embedding the embedding to store
 */
@Override
public void add(String id, Embedding embedding) {
    final List<String> ids = singletonList(id);
    final List<Embedding> embeddings = singletonList(embedding);
    // No text segment is attached to this embedding, hence the null third argument.
    addAll(ids, embeddings, null);
}
/**
 * Stores a single embedding together with the text segment it was computed from.
 *
 * @param embedding   the embedding to store
 * @param textSegment the text segment associated with the embedding
 * @return the first ID returned by the batch insert, or {@code null} when no ID was returned
 */
@Override
public String add(Embedding embedding, TextSegment textSegment) {
    final List<String> ids = addAll(singletonList(embedding), singletonList(textSegment));
    return ids.stream().findFirst().orElse(null);
}
/**
 * Stores several embeddings that have no accompanying text segments.
 *
 * @param embeddings the embeddings to store
 * @return whatever the two-argument {@code addAll} returns for these embeddings
 */
@Override
public List addAll(List embeddings) {
    // Delegate with an explicit "no text segments" argument.
    final List<TextSegment> noSegments = null;
    return addAll(embeddings, noSegments);
}
/**
 * Stores several embeddings, optionally paired with the text segments they were computed from.
 *
 * @param embeddings the embeddings to store
 * @param embedded   the text segments belonging to the embeddings; may be {@code null}
 * @return whatever the three-argument {@code addAll} returns when no IDs are supplied
 */
@Override
public List addAll(List embeddings, List embedded) {
    // No caller-supplied IDs: the three-argument overload receives null in the ID position.
    final List<String> noIds = null;
    return addAll(noIds, embeddings, embedded);
}
/**
 * Issues a batch delete for all objects of this store's class whose ID is among the given IDs.
 * The result returned by the batch deleter is not inspected.
 *
 * @param ids the IDs of the objects to delete; checked with {@code ensureNotEmpty}
 */
@Override
public void removeAll(Collection ids) {
    ensureNotEmpty(ids, "ids");

    // Matches every object whose "id" equals any of the supplied values.
    final WhereFilter matchAnyGivenId = WhereFilter.builder()
            .path("id")
            .operator(Operator.ContainsAny)
            .valueText(ids.toArray(new String[0]))
            .build();

    client.batch()
            .objectsBatchDeleter()
            .withClassName(objectClass)
            .withWhere(matchAnyGivenId)
            .run();
}
/**
 * Issues a batch delete restricted only by this store's object class.
 * The result returned by the batch deleter is not inspected.
 */
@Override
public void removeAll() {
    // NOTE(review): no where-filter is set here - confirm that Weaviate accepts a
    // batch delete without one and that it removes every object of the class.
    client.batch()
            .objectsBatchDeleter()
            .withClassName(objectClass)
            .run();
}
/**
* {@inheritDoc}
* The score inside {@link EmbeddingMatch} is Weaviate's certainty.
*/
@Override
public List> findRelevant(
Embedding referenceEmbedding,
int maxResults,
double minCertainty
) {
List fields = new ArrayList<>();
fields.add(Field.builder().name(METADATA_TEXT_SEGMENT).build());
fields.add(Field
.builder()
.name(ADDITIONALS)
.fields(
Field.builder().name("id").build(),
Field.builder().name("certainty").build(),
Field.builder().name("vector").build()
)
.build());
if (!metadataKeys.isEmpty()) {
List metadataFields = new ArrayList<>();
for (String property : metadataKeys) {
metadataFields.add(Field.builder().name(property).build());
}
fields.add(Field.builder().name(METADATA).fields(metadataFields.toArray(new Field[0])).build());
}
Result result = client
.graphQL()
.get()
.withClassName(objectClass)
.withFields(fields.toArray(new Field[0]))
.withNearVector(
NearVectorArgument
.builder()
.vector(referenceEmbedding.vectorAsList().toArray(new Float[0]))
.certainty((float) minCertainty)
.build()
)
.withLimit(maxResults)
.run();
if (result.hasErrors()) {
throw new IllegalArgumentException(
result.getError().getMessages().stream().map(WeaviateErrorMessage::getMessage).collect(joining("\n"))
);
}
GraphQLError[] errors = result.getResult().getErrors();
if (errors != null && errors.length > 0) {
throw new IllegalArgumentException(stream(errors).map(GraphQLError::getMessage).collect(joining("\n")));
}
Optional> resGetPart =
((Map) result.getResult().getData()).entrySet().stream().findFirst();
if (!resGetPart.isPresent()) {
return emptyList();
}
Optional resItemsPart = resGetPart.get().getValue().entrySet().stream().findFirst();
if (!resItemsPart.isPresent()) {
return emptyList();
}
List