Please wait. This can take a few minutes...
Downloading a project requires many resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download or modify it as often as you want.
apoc.vectordb.Weaviate Maven / Gradle / Ivy
package apoc.vectordb;
import apoc.Extended;
import apoc.ml.RestAPIConfig;
import apoc.result.ListResult;
import apoc.result.MapResult;
import apoc.util.UrlResolver;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Transaction;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import static apoc.ml.RestAPIConfig.METHOD_KEY;
import static apoc.util.Util.map;
import static apoc.vectordb.VectorDb.executeRequest;
import static apoc.vectordb.VectorDb.getEmbeddingResult;
import static apoc.vectordb.VectorDb.getEmbeddingResultStream;
import static apoc.vectordb.VectorDbHandler.Type.WEAVIATE;
import static apoc.vectordb.VectorDbUtil.*;
/**
 * Procedures under the {@code apoc.vectordb.weaviate.*} namespace that call a Weaviate
 * instance through its HTTP API (schema, objects and GraphQL endpoints).
 */
@Extended
public class Weaviate {
// Handler for the WEAVIATE entry of VectorDbHandler.Type; used to resolve connection info
// and to build the embedding configs for get/query.
public static final VectorDbHandler DB_HANDLER = WEAVIATE.get();
// The @Context fields below are not assigned anywhere in this class; they are expected to be
// injected by the procedure framework, which is why they are public and non-final.
// Used to read the output fields requested by the caller (see getCommon).
@Context
public ProcedureCallContext procedureCallContext;
// Transaction handed to the embedding-result helpers.
@Context
public Transaction tx;
// NOTE(review): not referenced by any method of this class.
@Context
public GraphDatabaseService db;
// Passed to every executeRequest call.
@Context
public URLAccessChecker urlAccessChecker;
/**
 * Retrieves the schema definition of an existing Weaviate collection.
 *
 * <p>Exposed as the procedure {@code apoc.vectordb.weaviate.info}. The Java method name is an
 * overload of {@code createCollection}; it is kept unchanged for source compatibility.
 *
 * @param hostOrKey     host (or key) identifying the Weaviate instance, resolved by {@code getVectorDbInfo}
 * @param collection    name of the collection, appended to the {@code %s/schema/%s} endpoint template
 * @param configuration optional extra configuration merged into the request config
 * @return a stream with every response element wrapped in a {@link MapResult}
 * @throws Exception if {@code executeRequest} fails
 */
// The response elements are untyped objects; the cast to Map<String, Object> cannot be checked at runtime.
@SuppressWarnings("unchecked")
@Procedure("apoc.vectordb.weaviate.info")
@Description("apoc.vectordb.weaviate.info(hostOrKey, collection, $configuration) - Get information about the specified existing collection or throws an error if it does not exist")
public Stream<MapResult> createCollection(@Name("hostOrKey") String hostOrKey,
        @Name("collection") String collection,
        @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
    Map<String, Object> config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema/%s");
    // Helper from VectorDbUtil; presumably clears the HTTP method and payload entries - confirm.
    methodAndPayloadNull(config);

    Map<String, Object> additionalBodies = Map.of("class", collection);
    RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies);

    return executeRequest(restAPIConfig, urlAccessChecker)
            .map(v -> (Map<String, Object>) v)
            .map(MapResult::new);
}
/**
 * Creates a Weaviate collection by POSTing a class definition to the {@code %s/schema} endpoint.
 *
 * @param hostOrKey     host (or key) identifying the Weaviate instance, resolved by {@code getVectorDbInfo}
 * @param collection    name of the collection to create (sent as {@code class})
 * @param similarity    distance metric, sent as {@code vectorIndexConfig.distance}; must not be null
 * @param size          vector size, sent as {@code vectorIndexConfig.size}; must not be null
 * @param configuration optional extra configuration; a user-supplied HTTP method takes precedence over POST
 * @return a stream with every response element wrapped in a {@link MapResult}
 * @throws Exception if {@code executeRequest} fails
 */
// The response elements are untyped objects; the cast to Map<String, Object> cannot be checked at runtime.
@SuppressWarnings("unchecked")
@Procedure("apoc.vectordb.weaviate.createCollection")
@Description("apoc.vectordb.weaviate.createCollection(hostOrKey, collection, similarity, size, $configuration) - Creates a collection, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`")
public Stream<MapResult> createCollection(@Name("hostOrKey") String hostOrKey,
        @Name("collection") String collection,
        @Name("similarity") String similarity,
        @Name("size") Long size,
        @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
    Map<String, Object> config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema");
    config.putIfAbsent(METHOD_KEY, "POST");

    // Map.of rejects null keys/values, hence the non-null requirement on the arguments.
    Map<String, Object> vectorIndexConfig = Map.of("distance", similarity, "size", size);
    Map<String, Object> additionalBodies = Map.of(
            "class", collection,
            "vectorIndexConfig", vectorIndexConfig);
    RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies);

    return executeRequest(restAPIConfig, urlAccessChecker)
            .map(v -> (Map<String, Object>) v)
            .map(MapResult::new);
}
/**
 * Deletes a Weaviate collection by sending a DELETE request to the {@code %s/schema/%s} endpoint.
 *
 * @param hostOrKey     host (or key) identifying the Weaviate instance, resolved by {@code getVectorDbInfo}
 * @param collection    name of the collection to delete
 * @param configuration optional extra configuration; a user-supplied HTTP method takes precedence over DELETE
 * @return a stream with every response element wrapped in a {@link MapResult}
 * @throws Exception if {@code executeRequest} fails
 */
// The response elements are untyped objects; the cast to Map<String, Object> cannot be checked at runtime.
@SuppressWarnings("unchecked")
@Procedure("apoc.vectordb.weaviate.deleteCollection")
@Description("apoc.vectordb.weaviate.deleteCollection(hostOrKey, collection, $configuration) - Deletes a collection with the name specified in the 2nd parameter")
public Stream<MapResult> deleteCollection(
        @Name("hostOrKey") String hostOrKey,
        @Name("collection") String collection,
        @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
    Map<String, Object> config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema/%s");
    config.putIfAbsent(METHOD_KEY, "DELETE");

    RestAPIConfig restAPIConfig = new RestAPIConfig(config);

    return executeRequest(restAPIConfig, urlAccessChecker)
            .map(v -> (Map<String, Object>) v)
            .map(MapResult::new);
}
@Procedure("apoc.vectordb.weaviate.upsert")
@Description("apoc.vectordb.weaviate.upsert(hostOrKey, collection, vectors, $configuration) - Upserts, in the collection with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '', medatada: ''}]")
public Stream upsert(
@Name("hostOrKey") String hostOrKey,
@Name("collection") String collection,
@Name("vectors") List> vectors,
@Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception {
var config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/objects");
config.putIfAbsent(METHOD_KEY, "POST");
Map body = new HashMap<>();
body.put("class", collection);
RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), body);
return vectors.stream()
.flatMap(vector -> {
try {
Map configBody = new HashMap<>(restAPIConfig.getBody());
configBody.putAll(vector);
configBody.put("properties", vector.remove("metadata"));
restAPIConfig.setBody(configBody);
Stream objectStream = executeRequest(restAPIConfig, urlAccessChecker);
return objectStream;
} catch (Exception e) {
throw new RuntimeException(e);
}
})
.map(v -> (Map) v)
.map(MapResult::new);
}
/**
 * Deletes the objects with the given ids, issuing one DELETE to
 * {@code <baseUrl>/objects/<collection>/<id>} per id.
 *
 * <p>All requests are executed before this method returns. The response of each request is
 * not inspected; the stream returned by {@code executeRequest} is closed right away so that
 * whatever it holds open is released.
 *
 * @param hostOrKey     host (or key) identifying the Weaviate instance, resolved by {@code getVectorDbInfo}
 * @param collection    name of the collection the objects belong to
 * @param ids           ids of the objects to delete
 * @param configuration optional extra configuration; a user-supplied HTTP method takes precedence over DELETE
 * @return a single {@link ListResult} holding the ids that were submitted for deletion
 * @throws Exception if the request configuration cannot be built; request failures are rethrown
 *                   wrapped in a {@link RuntimeException}
 */
@Procedure(value = "apoc.vectordb.weaviate.delete")
@Description("apoc.vectordb.weaviate.delete(hostOrKey, collection, ids, $configuration) - Deletes the vectors with the specified `ids`")
public Stream<ListResult> delete(@Name("hostOrKey") String hostOrKey,
        @Name("collection") String collection,
        @Name("ids") List<Object> ids,
        @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
    // The template only seeds the config; the actual endpoint is set per id below.
    Map<String, Object> config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema");
    config.putIfAbsent(METHOD_KEY, "DELETE");

    RestAPIConfig restAPIConfig = new RestAPIConfig(config, map(), map());

    // Explicit loop instead of Stream.peek: peek is meant for debugging and gives no guarantee
    // that its action runs for every element.
    for (Object id : ids) {
        String endpoint = "%s/objects/%s/%s".formatted(restAPIConfig.getBaseUrl(), collection, id);
        restAPIConfig.setEndpoint(endpoint);
        try (var ignored = executeRequest(restAPIConfig, urlAccessChecker)) {
            // Nothing to read: only the side effect of the request matters.
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    // Unmodifiable copy of the submitted ids, as before.
    List<Object> objects = ids.stream().toList();
    return Stream.of(new ListResult(objects));
}
/**
 * Write-mode "get by ids" procedure: forwards its arguments unchanged to {@code getCommon}.
 *
 * NOTE(review): {@code Stream}, {@code List} and {@code Map} are raw here; their type
 * arguments appear to have been lost and should be restored from the upstream source.
 *
 * @param hostOrKey     host (or key) identifying the Weaviate instance
 * @param collection    name of the collection to read from
 * @param ids           ids of the objects to fetch
 * @param configuration optional extra configuration
 * @return the stream produced by {@code getCommon}
 */
@Procedure(value = "apoc.vectordb.weaviate.getAndUpdate", mode = Mode.WRITE)
@Description("apoc.vectordb.weaviate.getAndUpdate(hostOrKey, collection, ids, $configuration) - Gets the vectors with the specified `ids`")
public Stream getAndUpdate(@Name("hostOrKey") String hostOrKey,
        @Name("collection") String collection,
        @Name("ids") List ids,
        @Name(value = "configuration", defaultValue = "{}") Map configuration) {
    final Stream results = getCommon(hostOrKey, collection, ids, configuration);
    return results;
}
/**
 * "Get by ids" procedure declared without an explicit mode: applies
 * {@code setReadOnlyMappingMode} to the configuration, then forwards to {@code getCommon}.
 *
 * NOTE(review): {@code Stream}, {@code List} and {@code Map} are raw here; their type
 * arguments appear to have been lost and should be restored from the upstream source.
 *
 * @param hostOrKey     host (or key) identifying the Weaviate instance
 * @param collection    name of the collection to read from
 * @param ids           ids of the objects to fetch
 * @param configuration optional extra configuration; passed to {@code setReadOnlyMappingMode} first
 * @return the stream produced by {@code getCommon}
 */
@Procedure(value = "apoc.vectordb.weaviate.get")
@Description("apoc.vectordb.weaviate.get(hostOrKey, collection, ids, $configuration) - Gets the vectors with the specified `ids`")
public Stream get(@Name("hostOrKey") String hostOrKey,
        @Name("collection") String collection,
        @Name("ids") List ids,
        @Name(value = "configuration", defaultValue = "{}") Map configuration) {
    // Helper from VectorDbUtil; presumably forces the mapping config into read-only mode - confirm.
    setReadOnlyMappingMode(configuration);

    final Stream results = getCommon(hostOrKey, collection, ids, configuration);
    return results;
}
/**
 * Shared implementation of the get / getAndUpdate procedures: fetches each id from
 * {@code <baseUrl>/objects/<collection>/<id>} and converts every response through
 * {@code getEmbeddingResult}.
 *
 * <p>The returned stream is lazy: the request for an id is executed when that element is consumed.
 *
 * @param hostOrKey     host (or key) identifying the Weaviate instance
 * @param collection    name of the collection to read from
 * @param ids           ids of the objects to fetch
 * @param configuration optional extra configuration
 * @return a stream with one converted result per response element
 */
private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration) {
    // The template only seeds the config; the actual endpoint is set per id below.
    Map dbInfo = getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema");

    /*
     * TODO: we put method: null as a workaround, it should be "GET": https://weaviate.io/developers/weaviate/api/rest#tag/objects/get/objects/{className}/{id}
     * Since with `method: GET` the {@link apoc.util.Util#openUrlConnection(URL, Map)} has a `setChunkedStreamingMode`
     * that makes the request to respond with an error 405 Method Not Allowed
     */
    dbInfo.putIfAbsent(METHOD_KEY, null);

    List requestedFields = procedureCallContext.outputFields().toList();
    VectorEmbeddingConfig embeddingConfig =
            DB_HANDLER.getEmbedding().fromGet(dbInfo, procedureCallContext, ids, collection);

    // Only ask Weaviate for the vector when the caller yields it and all results are requested.
    boolean wantsVector = requestedFields.contains("vector") && embeddingConfig.isAllResults();
    boolean wantsMetadata = requestedFields.contains("metadata");
    VectorMappingConfig mappingConfig = embeddingConfig.getMapping();

    String querySuffix;
    if (wantsVector) {
        querySuffix = "?include=vector";
    } else {
        querySuffix = "";
    }

    return ids.stream()
            .flatMap(id -> {
                String objectUrl = String.format("%s/objects/%s/%s",
                        embeddingConfig.getApiConfig().getBaseUrl(), collection, id);
                embeddingConfig.getApiConfig().setEndpoint(objectUrl + querySuffix);
                try {
                    return executeRequest(embeddingConfig.getApiConfig(), urlAccessChecker)
                            .map(raw -> (Map) raw)
                            .map(row -> getEmbeddingResult(
                                    embeddingConfig, tx, wantsVector, wantsMetadata, mappingConfig, row));
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            });
}
/**
 * Similarity-query procedure declared without an explicit mode: applies
 * {@code setReadOnlyMappingMode} to the configuration, then forwards to {@code queryCommon}.
 *
 * NOTE(review): {@code Stream}, {@code List} and {@code Map} are raw here; their type
 * arguments appear to have been lost and should be restored from the upstream source.
 *
 * @param hostOrKey     host (or key) identifying the Weaviate instance
 * @param collection    name of the collection to search
 * @param vector        the query vector (defaults to an empty list)
 * @param filter        optional filter (defaults to null)
 * @param limit         maximum number of results (defaults to 10)
 * @param configuration optional extra configuration; passed to {@code setReadOnlyMappingMode} first
 * @return the stream produced by {@code queryCommon}
 * @throws Exception propagated from {@code queryCommon}
 */
@Procedure(value = "apoc.vectordb.weaviate.query")
@Description("apoc.vectordb.weaviate.query(hostOrKey, collection, vector, filter, limit, $configuration) - Retrieves closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter")
public Stream query(@Name("hostOrKey") String hostOrKey,
        @Name("collection") String collection,
        @Name(value = "vector", defaultValue = "[]") List vector,
        @Name(value = "filter", defaultValue = "null") Object filter,
        @Name(value = "limit", defaultValue = "10") long limit,
        @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception {
    // Helper from VectorDbUtil; presumably forces the mapping config into read-only mode - confirm.
    setReadOnlyMappingMode(configuration);

    final Stream results = queryCommon(hostOrKey, collection, vector, filter, limit, configuration);
    return results;
}
/**
 * Write-mode similarity-query procedure: forwards its arguments unchanged to {@code queryCommon}.
 *
 * NOTE(review): {@code Stream}, {@code List} and {@code Map} are raw here; their type
 * arguments appear to have been lost and should be restored from the upstream source.
 *
 * @param hostOrKey     host (or key) identifying the Weaviate instance
 * @param collection    name of the collection to search
 * @param vector        the query vector (defaults to an empty list)
 * @param filter        optional filter (defaults to null)
 * @param limit         maximum number of results (defaults to 10)
 * @param configuration optional extra configuration
 * @return the stream produced by {@code queryCommon}
 * @throws Exception propagated from {@code queryCommon}
 */
@Procedure(value = "apoc.vectordb.weaviate.queryAndUpdate", mode = Mode.WRITE)
@Description("apoc.vectordb.weaviate.queryAndUpdate(hostOrKey, collection, vector, filter, limit, $configuration) - Retrieves closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter")
public Stream queryAndUpdate(@Name("hostOrKey") String hostOrKey,
        @Name("collection") String collection,
        @Name(value = "vector", defaultValue = "[]") List vector,
        @Name(value = "filter", defaultValue = "null") Object filter,
        @Name(value = "limit", defaultValue = "10") long limit,
        @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception {
    final Stream results = queryCommon(hostOrKey, collection, vector, filter, limit, configuration);
    return results;
}
/**
 * Shared implementation of the query / queryAndUpdate procedures: sends the query to the
 * {@code %s/graphql} endpoint and reshapes every entry found under
 * {@code data.Get.<collection>} of the response.
 *
 * <p>For each entry, the {@code _additional} sub-map is removed and its {@code distance},
 * {@code id} and {@code vector} values are stored under the score, id and vector keys of the
 * embedding config; what remains of the entry is stored under the metadata key.
 *
 * NOTE(review): the signature uses raw {@code Stream}, {@code List} and {@code Map}; their
 * type arguments appear to have been lost and should be restored from the upstream source.
 *
 * @param hostOrKey     host (or key) identifying the Weaviate instance
 * @param collection    name of the collection to search; also the key read under {@code data.Get}
 * @param vector        the query vector
 * @param filter        optional filter
 * @param limit         maximum number of results
 * @param configuration optional extra configuration
 * @return the stream produced by {@code getEmbeddingResultStream}
 * @throws Exception propagated from {@code getEmbeddingResultStream}
 * @throws RuntimeException when the response has no list under {@code data.Get.<collection>};
 *                          the message carries the response's {@code errors} entry
 */
// The parsed JSON response is untyped, so the Map/List casts below cannot be checked at runtime.
@SuppressWarnings({"unchecked", "rawtypes"})
private Stream queryCommon(String hostOrKey, String collection, List vector, Object filter, long limit, Map configuration) throws Exception {
    Map config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/graphql");
    VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection);
    return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx,
            v -> {
                Map<String, Object> response = (Map<String, Object>) v;
                Object data = response.get("data");
                Object getValue = data instanceof Map ? ((Map<String, Object>) data).get("Get") : null;
                Object collectionValue = getValue instanceof Map
                        ? ((Map<String, Object>) getValue).get(collection)
                        : null;
                // Fail with the server-side errors instead of an anonymous NullPointerException.
                if (!(collectionValue instanceof List)) {
                    throw new RuntimeException(
                            "Unexpected Weaviate GraphQL response for collection `%s`, errors: %s"
                                    .formatted(collection, response.get("errors")));
                }
                return ((List<Map>) collectionValue).stream()
                        .map(i -> {
                            // Detach the technical fields so that only user properties remain in `i`.
                            Map additional = (Map) i.remove("_additional");
                            Map map = new HashMap<>();
                            map.put(conf.getMetadataKey(), i);
                            map.put(conf.getScoreKey(), additional.get("distance"));
                            map.put(conf.getIdKey(), additional.get("id"));
                            map.put(conf.getVectorKey(), additional.get("vector"));
                            return map;
                        });
            }
    );
}
/**
 * Convenience wrapper around {@code getCommonVectorDbInfo} that always supplies this
 * class's {@code DB_HANDLER}.
 *
 * @param hostOrKey     host (or key) identifying the Weaviate instance
 * @param collection    name of the collection
 * @param configuration optional extra configuration
 * @param templateUrl   endpoint template handed through unchanged
 * @return whatever {@code getCommonVectorDbInfo} returns for these arguments
 */
private Map getVectorDbInfo(String hostOrKey, String collection, Map configuration, String templateUrl) {
    final Map info = getCommonVectorDbInfo(hostOrKey, collection, configuration, templateUrl, DB_HANDLER);
    return info;
}
/**
 * Resolves the URL for {@code hostOrKey} through a {@code UrlResolver} created with
 * {@code "http"}, {@code "localhost"} and {@code 8000}, and appends the {@code /v1} path.
 *
 * NOTE(review): 8000 is used as the resolver's port argument; confirm it matches the port the
 * target Weaviate deployments actually listen on.
 *
 * @param hostOrKey host (or key) passed to {@code UrlResolver.getUrl} together with {@code "weaviate"}
 * @return the resolved URL followed by {@code /v1}
 */
protected String getWeaviateUrl(String hostOrKey) {
    UrlResolver resolver = new UrlResolver("http", "localhost", 8000);
    String resolved = resolver.getUrl("weaviate", hostOrKey);
    return resolved + "/v1";
}
}