// All downloads are free. Search and download functionality uses the official Maven repository.
//
// org.neo4j.gds.extension.Neo4jSupportExtension (Maven / Gradle / Ivy)

/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.extension;

import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.neo4j.configuration.GraphDatabaseSettings;
import org.neo4j.dbms.api.DatabaseManagementService;
import org.neo4j.gds.QueryRunner;
import org.neo4j.gds.TestSupport;
import org.neo4j.gds.compat.GraphDatabaseApiProxy;
import org.neo4j.gds.compat.Neo4jProxy;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Result;
import org.neo4j.internal.id.IdGeneratorFactory;
import org.neo4j.kernel.impl.core.NodeEntity;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;

import static java.util.Arrays.stream;
import static org.neo4j.gds.extension.ExtensionUtil.injectInstance;
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;

public class Neo4jSupportExtension implements BeforeEachCallback {

    private static final String RETURN_STATEMENT = "RETURN *";

    // taken from org.neo4j.test.extension.DbmsSupportController
    private static final ExtensionContext.Namespace DBMS_NAMESPACE = ExtensionContext.Namespace.create(
        "org",
        "neo4j",
        "dbms"
    );

    // taken from org.neo4j.test.extension.DbmsSupportController
    private static final String DBMS_KEY = "service";

    @Override
    public void beforeEach(ExtensionContext context) {
        GraphDatabaseService db = getDbms(context)
            .map(dbms -> dbms.database(GraphDatabaseSettings.DEFAULT_DATABASE_NAME))
            .orElseThrow(() -> new IllegalStateException("No database was found."));

        Class requiredTestClass = context.getRequiredTestClass();
        Optional> createQuery = createQueryAndIdOffset(requiredTestClass);
        Map idMap = neo4jGraphSetup(db, createQuery);
        injectFields(context, db, idMap);
    }

    private Optional getDbms(ExtensionContext context) {
        return Optional.ofNullable(context.getStore(DBMS_NAMESPACE).get(DBMS_KEY, DatabaseManagementService.class));
    }

    private Optional> createQueryAndIdOffset(Class testClass) {
        return Stream.>iterate(testClass, c -> c.getSuperclass() != null, Class::getSuperclass)
            .flatMap(clazz -> stream(clazz.getDeclaredFields()))
            .filter(field -> field.isAnnotationPresent(Neo4jGraph.class))
            .findFirst()
            .map(field -> Pair.of(
                ExtensionUtil.getStringValueOfField(field),
                field.getAnnotation(Neo4jGraph.class).offsetIds()
            ));
    }

    private Map neo4jGraphSetup(GraphDatabaseService db, Optional> createQueryAndOffset) {
        offsetNodeIds(db, createQueryAndOffset.map(Pair::getRight).orElse(false));

        return createQueryAndOffset
            .map(Pair::getLeft)
            .map(query -> formatWithLocale("%s %s", query, RETURN_STATEMENT))
            .map(query -> QueryRunner.runQuery(db, query, Neo4jSupportExtension::extractVariableIds))
            .orElseGet(Map::of);
    }

    private static Map extractVariableIds(Result result) {
        if (!result.hasNext()) {
            throw new IllegalArgumentException("Result of create query was empty");
        }
        List columns = result.columns();
        Map row = result.next();

        Map idMap = new HashMap<>();
        columns.forEach(column -> {
            Object value = row.get(column);
            if (value instanceof NodeEntity) {
                idMap.put(column, (NodeEntity) value);
            }
        });

        return idMap;
    }

    private void offsetNodeIds(GraphDatabaseService db, boolean offsetIds) {
        if (!offsetIds) {
            return;
        }

        // try to convince the db that `idOffset` number of nodes have already been allocated
        var idGeneratorFactory = GraphDatabaseApiProxy.resolveDependency(db, IdGeneratorFactory.class);
        TestSupport.fullAccessTransaction(db).accept((tx, ktx) -> Neo4jProxy.reserveNeo4jIds(idGeneratorFactory, 42, ktx.cursorContext()));
    }

    private void injectFields(ExtensionContext context, GraphDatabaseService db, Map idMap) {
        NodeFunction nodeFunction = idMap::get;
        IdFunction idFunction = variable -> nodeFunction.of(variable).getId();
        context.getRequiredTestInstances().getAllInstances().forEach(testInstance -> {
            injectInstance(testInstance, nodeFunction, NodeFunction.class);
            injectInstance(testInstance, idFunction, IdFunction.class);
            injectInstance(testInstance, db, GraphDatabaseService.class);
        });
    }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy