org.neo4j.gds.extension.Neo4jSupportExtension Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of test-utils Show documentation
Show all versions of test-utils Show documentation
Neo4j Graph Data Science :: Test Utils
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.extension;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.neo4j.configuration.GraphDatabaseSettings;
import org.neo4j.dbms.api.DatabaseManagementService;
import org.neo4j.gds.QueryRunner;
import org.neo4j.gds.TestSupport;
import org.neo4j.gds.compat.GraphDatabaseApiProxy;
import org.neo4j.gds.compat.Neo4jProxy;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Result;
import org.neo4j.internal.id.IdGeneratorFactory;
import org.neo4j.kernel.impl.core.NodeEntity;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import static java.util.Arrays.stream;
import static org.neo4j.gds.extension.ExtensionUtil.injectInstance;
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
public class Neo4jSupportExtension implements BeforeEachCallback {
private static final String RETURN_STATEMENT = "RETURN *";
// taken from org.neo4j.test.extension.DbmsSupportController
private static final ExtensionContext.Namespace COMUNITY_NAMESPACE = ExtensionContext.Namespace.create(
"org",
"neo4j",
"dbms"
);
// taken from com.neo4j.test.extension.EnterpriseDbmsSupportExtension
private static final ExtensionContext.Namespace ENTERPRISE_NAMESPACE = ExtensionContext.Namespace.create(
"org",
"neo4j",
"dbms",
"support"
);
// taken from org.neo4j.test.extension.DbmsSupportController
private static final String DBMS_KEY = "service";
private static final String DATABASE_NAME_KEY = "database";
@Override
public void beforeEach(ExtensionContext context) {
GraphDatabaseService db;
var enterpriseStore = context.getStore(ENTERPRISE_NAMESPACE);
var communityStore = context.getStore(COMUNITY_NAMESPACE);
String databaseName = enterpriseStore.get(DATABASE_NAME_KEY) != null
? enterpriseStore.get(
DATABASE_NAME_KEY,
String.class
)
: GraphDatabaseSettings.DEFAULT_DATABASE_NAME;
var dbms = communityStore.get(DBMS_KEY, DatabaseManagementService.class);
db = dbms.database(databaseName);
Class> requiredTestClass = context.getRequiredTestClass();
Optional> createQuery = createQueryAndIdOffset(requiredTestClass);
var idFunctions = neo4jGraphSetup(db, createQuery);
injectFields(context, db, idFunctions);
}
private Optional> createQueryAndIdOffset(Class> testClass) {
return Stream.>iterate(testClass, c -> c.getSuperclass() != null, Class::getSuperclass)
.flatMap(clazz -> stream(clazz.getDeclaredFields()))
.filter(field -> field.isAnnotationPresent(Neo4jGraph.class))
.findFirst()
.map(
field -> Pair.of(
ExtensionUtil.getStringValueOfField(field),
field.getAnnotation(Neo4jGraph.class).offsetIds()
)
);
}
private IdFunctions neo4jGraphSetup(GraphDatabaseService db, Optional> createQueryAndOffset) {
offsetNodeIds(db, createQueryAndOffset.map(Pair::getRight).orElse(false));
return createQueryAndOffset
.map(Pair::getLeft)
.map(query -> formatWithLocale("%s %s", query, RETURN_STATEMENT))
.map(query -> QueryRunner.runQuery(db, query, Neo4jSupportExtension::extractVariableIds))
.orElse(IdFunctions.EMPTY);
}
private static IdFunctions extractVariableIds(Result result) {
if (!result.hasNext()) {
throw new IllegalArgumentException("Result of create query was empty");
}
List columns = result.columns();
Map row = result.next();
Map idMap = new HashMap<>();
Map variableToIddMap = new HashMap<>();
columns.forEach(column -> {
Object value = row.get(column);
if (value instanceof NodeEntity) {
idMap.put(column, (NodeEntity) value);
variableToIddMap.put(((NodeEntity) value).getId(), column);
}
});
return new IdFunctions(variableToIddMap, idMap);
}
private void offsetNodeIds(GraphDatabaseService db, boolean offsetIds) {
if (!offsetIds) {
return;
}
// try to convince the db that `idOffset` number of nodes have already been allocated
var idGeneratorFactory = GraphDatabaseApiProxy.resolveDependency(db, IdGeneratorFactory.class);
TestSupport.fullAccessTransaction(db)
.accept((tx, ktx) -> Neo4jProxy.reserveNeo4jIds(idGeneratorFactory, 42, ktx.cursorContext()));
}
private void injectFields(ExtensionContext context, GraphDatabaseService db, IdFunctions idFunctions) {
NodeFunction nodeFunction = idFunctions.variableToId::get;
IdFunction idFunction = variable -> nodeFunction.of(variable).getId();
context.getRequiredTestInstances().getAllInstances().forEach(testInstance -> {
injectInstance(testInstance, nodeFunction, NodeFunction.class);
injectInstance(testInstance, idFunction, IdFunction.class);
injectInstance(testInstance, idFunctions.idToVariable::get, IdToVariable.class);
injectInstance(testInstance, db, GraphDatabaseService.class);
});
}
// Inverse Id mapping
private static class IdFunctions {
static final IdFunctions EMPTY = new IdFunctions(Map.of(), Map.of());
final Map idToVariable;
final Map variableToId;
IdFunctions(Map idToVariable, Map variableToId) {
this.idToVariable = idToVariable;
this.variableToId = variableToId;
}
}
}