
// org.neo4j.gds.BaseProcTest (Maven / Gradle / Ivy source listing)
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds;
import org.assertj.core.api.HamcrestCondition;
import org.hamcrest.Matcher;
import org.intellij.lang.annotations.Language;
import org.intellij.lang.annotations.RegExp;
import org.jetbrains.annotations.Nullable;
import org.junit.jupiter.api.AfterEach;
import org.neo4j.gds.api.DatabaseId;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.compat.CompatUserAggregationFunction;
import org.neo4j.gds.compat.GraphDatabaseApiProxy;
import org.neo4j.gds.compat.Neo4jProxy;
import org.neo4j.gds.core.ExceptionMessageMatcher;
import org.neo4j.gds.core.Username;
import org.neo4j.gds.core.loading.GraphStoreCatalog;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.ResourceIterator;
import org.neo4j.graphdb.Result;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import static java.util.Collections.emptyMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.allOf;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
public class BaseProcTest extends BaseTest {
/**
 * JUnit {@code @AfterEach} hook: once a test method has finished, drops every
 * graph from the catalog via {@link GraphStoreCatalog#removeAllLoadedGraphs()}.
 */
@AfterEach
void cleanupGraphStoreCatalog() {
    GraphStoreCatalog.removeAllLoadedGraphs();
}
protected void registerFunctions(Class>... functionClasses) throws Exception {
GraphDatabaseApiProxy.registerFunctions(db, functionClasses);
}
protected void registerAggregationFunctions(Class>... functionClasses) throws Exception {
GraphDatabaseApiProxy.registerAggregationFunctions(db, functionClasses);
}
/**
 * Wraps a single aggregation function with
 * {@link Neo4jProxy#callableUserAggregationFunction} and registers the result
 * on the test database {@code db}.
 *
 * @param function the aggregation function to register
 * @throws Exception whatever the underlying registration call throws
 */
protected void registerAggregationFunction(CompatUserAggregationFunction function) throws Exception {
    var callableFunction = Neo4jProxy.callableUserAggregationFunction(function);
    GraphDatabaseApiProxy.register(db, callableFunction);
}
protected void registerFunctions(GraphDatabaseService db, Class>... functionClasses) throws Exception {
GraphDatabaseApiProxy.registerFunctions(db, functionClasses);
}
protected void registerProcedures(Class>... procedureClasses) throws Exception {
registerProcedures(db, procedureClasses);
}
protected void registerProcedures(GraphDatabaseService db, Class>... procedureClasses) throws Exception {
GraphDatabaseApiProxy.registerProcedures(db, procedureClasses);
}
/**
 * Looks up a dependency of the test database {@code db} by its type.
 *
 * @param dependency class token of the dependency to resolve
 * @param <T>        type of the dependency
 * @return whatever {@link GraphDatabaseApiProxy#resolveDependency} returns for that type
 */
<T> T resolveDependency(Class<T> dependency) {
    return GraphDatabaseApiProxy.resolveDependency(db, dependency);
}
/**
 * Username used by this test base.
 *
 * @return the name carried by {@link Username#EMPTY_USERNAME}
 */
protected String getUsername() {
    var emptyUser = Username.EMPTY_USERNAME;
    return emptyUser.username();
}
/**
 * Asserts that a parameterless query fails with an error whose message
 * contains the given substring. Delegates to the overload that takes query
 * parameters, passing an empty map.
 *
 * @param query            Cypher query expected to fail
 * @param messageSubstring text the error message must contain
 */
protected void assertError(
    @Language("Cypher") String query,
    String messageSubstring
) {
    Map<String, Object> noParameters = emptyMap();
    assertError(query, noParameters, messageSubstring);
}
/**
 * Asserts that a query fails with an error whose message contains the given
 * substring. The substring is turned into a matcher with
 * {@link ExceptionMessageMatcher#containsMessage} and handed to the
 * matcher-based overload.
 *
 * @param query            Cypher query expected to fail
 * @param queryParameters  parameters passed along with the query
 * @param messageSubstring text the error message must contain
 */
protected void assertError(
    @Language("Cypher") String query,
    Map<String, Object> queryParameters,
    String messageSubstring
) {
    assertError(query, queryParameters, ExceptionMessageMatcher.containsMessage(messageSubstring));
}
/**
 * Asserts that a query fails with an error whose message contains every one
 * of the given substrings. Each substring becomes a
 * {@link ExceptionMessageMatcher#containsMessage} matcher; the matchers are
 * combined with Hamcrest's {@code allOf}.
 *
 * @param query             Cypher query expected to fail
 * @param queryParameters   parameters passed along with the query
 * @param messageSubstrings texts that must all appear in the error message
 */
protected void assertError(
    @Language("Cypher") String query,
    Map<String, Object> queryParameters,
    List<String> messageSubstrings
) {
    // The element type must be String: with a raw List the method reference
    // below cannot be resolved against containsMessage(String).
    assertError(
        query,
        queryParameters,
        allOf(messageSubstrings.stream()
            .map(ExceptionMessageMatcher::containsMessage)
            .collect(Collectors.toList()))
    );
}
/**
 * Asserts that a parameterless query fails with an error whose message
 * matches the given regular expression. Delegates to the overload that takes
 * query parameters, passing an empty map.
 *
 * @param query Cypher query expected to fail
 * @param regex regular expression the error message is checked against
 */
protected void assertErrorRegex(
    @Language("Cypher") String query,
    @RegExp String regex
) {
    Map<String, Object> noParameters = emptyMap();
    assertErrorRegex(query, noParameters, regex);
}
/**
 * Asserts that a query fails with an error whose message matches the given
 * regular expression. The regex is turned into a matcher with
 * {@link ExceptionMessageMatcher#containsMessageRegex} and handed to the
 * matcher-based {@code assertError} overload.
 *
 * @param query           Cypher query expected to fail
 * @param queryParameters parameters passed along with the query
 * @param regex           regular expression the error message is checked against
 */
private void assertErrorRegex(
    @Language("Cypher") String query,
    Map<String, Object> queryParameters,
    @RegExp String regex
) {
    assertError(query, queryParameters, ExceptionMessageMatcher.containsMessageRegex(regex));
}
/**
 * Runs the query, consuming its result, and asserts that doing so throws
 * something that satisfies the given matcher.
 *
 * @param query           Cypher query expected to fail
 * @param queryParameters parameters passed along with the query
 * @param matcher         Hamcrest matcher the thrown {@link Throwable} must satisfy
 */
private void assertError(
    @Language("Cypher") String query,
    Map<String, Object> queryParameters,
    Matcher<Throwable> matcher
) {
    Throwable thrown = null;
    try {
        runQueryWithResultConsumer(query, queryParameters, BaseProcTest::consume);
    } catch (Throwable e) {
        thrown = e;
    }
    // The "nothing was thrown" failure is raised outside the try block on purpose:
    // raised inside, it would be caught by the catch above and checked against the
    // matcher itself. Its message embeds the query text, so it could even satisfy
    // the matcher and let a non-failing query pass the assertion.
    if (thrown == null) {
        fail(formatWithLocale("Expected an exception to be thrown by query:\n%s", query));
    }
    assertThat(thrown).has(new HamcrestCondition<>(matcher));
}
/**
 * Asserts that the map obtained from {@code extractUserInput(row)} contains
 * {@code key} and that the value stored under it equals {@code expected}.
 *
 * @param row      result row to inspect
 * @param key      key that must be present
 * @param expected expected value for the key; may be {@code null}
 */
protected void assertUserInput(Result.ResultRow row, String key, @Nullable Object expected) {
    Map<String, Object> configMap = extractUserInput(row);
    // Check presence first so a missing key is not mistaken for a null value.
    assertTrue(configMap.containsKey(key), formatWithLocale("Key %s is not present in config", key));
    assertEquals(expected, configMap.get(key));
}
/**
 * Reads the {@code "configuration"} column of a result row as a map.
 *
 * @param row result row to read from
 * @return the value of the {@code "configuration"} column, cast to a map
 */
@SuppressWarnings("unchecked") // row.get returns Object; the cast to a parameterized map cannot be checked at runtime
private Map<String, Object> extractUserInput(Result.ResultRow row) {
    return (Map<String, Object>) row.get("configuration");
}
/**
 * Projects the complete graph under the given name using
 * {@link Orientation#NATURAL}.
 *
 * @param graphName name to project the graph under
 */
protected void loadCompleteGraph(String graphName) {
    var defaultOrientation = Orientation.NATURAL;
    loadCompleteGraph(graphName, defaultOrientation);
}
/**
 * Builds a graph-project query with {@link GdsCypher} that loads everything
 * in the given orientation, then runs it.
 *
 * @param graphName   name to project the graph under
 * @param orientation orientation passed to {@code loadEverything}
 */
protected void loadCompleteGraph(String graphName, Orientation orientation) {
    runQuery(
        GdsCypher.call(graphName)
            .graphProject()
            .loadEverything(orientation)
            .yields()
    );
}
/**
 * Asserts that {@code getLoadedGraphs(graphName)} yields exactly one graph.
 *
 * @param graphName name of the graph expected to be loaded
 */
protected void assertGraphExists(String graphName) {
    assertEquals(1, getLoadedGraphs(graphName).size());
}
/**
 * Asserts that {@code getLoadedGraphs(graphName)} yields no graph at all.
 *
 * @param graphName name of the graph expected to be absent
 */
protected void assertGraphDoesNotExist(String graphName) {
    assertTrue(getLoadedGraphs(graphName).isEmpty());
}
/**
 * Finds a graph by name among the catalog entries returned by
 * {@code GraphStoreCatalog.getGraphStores("", DatabaseId.of(db))}.
 *
 * @param graphName name to look for
 * @return the union graph of the first catalog entry whose name equals {@code graphName}
 * @throws RuntimeException if no entry carries that name
 */
protected Graph findLoadedGraph(String graphName) {
    var catalogEntries = GraphStoreCatalog
        .getGraphStores("", DatabaseId.of(db))
        .entrySet();
    var firstMatch = catalogEntries
        .stream()
        .filter(entry -> entry.getKey().graphName().equals(graphName))
        .map(entry -> entry.getValue().getUnion())
        .findFirst();
    return firstMatch.orElseThrow(
        () -> new RuntimeException(formatWithLocale("Graph %s not found.", graphName))
    );
}
/**
 * Collects the union graphs of all catalog entries returned by
 * {@code GraphStoreCatalog.getGraphStores("", DatabaseId.of(db))} whose name
 * equals {@code graphName}.
 *
 * @param graphName name to look for
 * @return the matching graphs; empty if none matches
 */
private Set<Graph> getLoadedGraphs(String graphName) {
    return GraphStoreCatalog
        .getGraphStores("", DatabaseId.of(db))
        .entrySet()
        .stream()
        .filter(entry -> entry.getKey().graphName().equals(graphName))
        .map(entry -> entry.getValue().getUnion())
        .collect(Collectors.toSet());
}
private static void consume(ResourceIterator
© 2015 - 2025 Weber Informatics LLC | Privacy Policy