
org.neo4j.gds.extension.Neo4jSupportExtension Maven / Gradle / Ivy
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.extension;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.neo4j.configuration.GraphDatabaseSettings;
import org.neo4j.dbms.api.DatabaseManagementService;
import org.neo4j.gds.QueryRunner;
import org.neo4j.gds.TestSupport;
import org.neo4j.gds.compat.GraphDatabaseApiProxy;
import org.neo4j.gds.compat.Neo4jProxy;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Result;
import org.neo4j.internal.id.IdGeneratorFactory;
import org.neo4j.kernel.impl.core.NodeEntity;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;

import static java.util.Arrays.stream;
import static org.neo4j.gds.extension.ExtensionUtil.injectInstance;
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
public class Neo4jSupportExtension implements BeforeEachCallback {
private static final String RETURN_STATEMENT = "RETURN *";
// taken from org.neo4j.test.extension.DbmsSupportController
private static final ExtensionContext.Namespace DBMS_NAMESPACE = ExtensionContext.Namespace.create(
"org",
"neo4j",
"dbms"
);
// taken from org.neo4j.test.extension.DbmsSupportController
private static final String DBMS_KEY = "service";
@Override
public void beforeEach(ExtensionContext context) {
GraphDatabaseService db = getDbms(context)
.map(dbms -> dbms.database(GraphDatabaseSettings.DEFAULT_DATABASE_NAME))
.orElseThrow(() -> new IllegalStateException("No database was found."));
Class> requiredTestClass = context.getRequiredTestClass();
Optional> createQuery = createQueryAndIdOffset(requiredTestClass);
Map idMap = neo4jGraphSetup(db, createQuery);
injectFields(context, db, idMap);
}
private Optional getDbms(ExtensionContext context) {
return Optional.ofNullable(context.getStore(DBMS_NAMESPACE).get(DBMS_KEY, DatabaseManagementService.class));
}
private Optional> createQueryAndIdOffset(Class> testClass) {
return Stream.>iterate(testClass, c -> c.getSuperclass() != null, Class::getSuperclass)
.flatMap(clazz -> stream(clazz.getDeclaredFields()))
.filter(field -> field.isAnnotationPresent(Neo4jGraph.class))
.findFirst()
.map(field -> Pair.of(
ExtensionUtil.getStringValueOfField(field),
field.getAnnotation(Neo4jGraph.class).offsetIds()
));
}
private Map neo4jGraphSetup(GraphDatabaseService db, Optional> createQueryAndOffset) {
offsetNodeIds(db, createQueryAndOffset.map(Pair::getRight).orElse(false));
return createQueryAndOffset
.map(Pair::getLeft)
.map(query -> formatWithLocale("%s %s", query, RETURN_STATEMENT))
.map(query -> QueryRunner.runQuery(db, query, Neo4jSupportExtension::extractVariableIds))
.orElseGet(Map::of);
}
private static Map extractVariableIds(Result result) {
if (!result.hasNext()) {
throw new IllegalArgumentException("Result of create query was empty");
}
List columns = result.columns();
Map row = result.next();
Map idMap = new HashMap<>();
columns.forEach(column -> {
Object value = row.get(column);
if (value instanceof NodeEntity) {
idMap.put(column, (NodeEntity) value);
}
});
return idMap;
}
private void offsetNodeIds(GraphDatabaseService db, boolean offsetIds) {
if (!offsetIds) {
return;
}
// try to convince the db that `idOffset` number of nodes have already been allocated
var idGeneratorFactory = GraphDatabaseApiProxy.resolveDependency(db, IdGeneratorFactory.class);
TestSupport.fullAccessTransaction(db).accept((tx, ktx) -> Neo4jProxy.reserveNeo4jIds(idGeneratorFactory, 42, ktx.cursorContext()));
}
private void injectFields(ExtensionContext context, GraphDatabaseService db, Map idMap) {
NodeFunction nodeFunction = idMap::get;
IdFunction idFunction = variable -> nodeFunction.of(variable).getId();
context.getRequiredTestInstances().getAllInstances().forEach(testInstance -> {
injectInstance(testInstance, nodeFunction, NodeFunction.class);
injectInstance(testInstance, idFunction, IdFunction.class);
injectInstance(testInstance, db, GraphDatabaseService.class);
});
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy