All downloads are FREE. Search and download functionality uses the official Maven repository.

org.neo4j.gds.hits.HitsComputation Maven / Gradle / Ivy

/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.hits;

import org.neo4j.gds.api.nodeproperties.ValueType;
import org.neo4j.gds.beta.pregel.BidirectionalPregelComputation;
import org.neo4j.gds.beta.pregel.Messages;
import org.neo4j.gds.beta.pregel.Pregel;
import org.neo4j.gds.beta.pregel.PregelSchema;
import org.neo4j.gds.beta.pregel.context.ComputeContext.BidirectionalComputeContext;
import org.neo4j.gds.beta.pregel.context.InitContext.BidirectionalInitContext;
import org.neo4j.gds.beta.pregel.context.MasterComputeContext;
import org.neo4j.gds.mem.MemoryEstimateDefinition;

import java.util.Map;
import java.util.concurrent.atomic.DoubleAdder;

import static org.neo4j.gds.hits.HitsState.CALCULATE_AUTHS;
import static org.neo4j.gds.hits.HitsState.CALCULATE_HUBS;
import static org.neo4j.gds.hits.HitsState.INIT;
import static org.neo4j.gds.hits.HitsState.NORMALIZE_AUTHS;
import static org.neo4j.gds.hits.HitsState.NORMALIZE_HUBS;


/**
 * Bidirectional Pregel computation that maintains an authority ("auth") and a hub score
 * for every node, stored under the property names supplied by {@link HitsConfig}.
 *
 * <p>Each superstep executes exactly one phase, selected by {@link #state}:
 * <ul>
 *   <li>{@code INIT} seeds the authority value of a node with its incoming degree;</li>
 *   <li>{@code CALCULATE_AUTHS} / {@code CALCULATE_HUBS} sum the received messages into
 *       the respective property;</li>
 *   <li>{@code NORMALIZE_AUTHS} / {@code NORMALIZE_HUBS} divide the respective property by
 *       the global norm and forward the normalized value to neighbouring nodes.</li>
 * </ul>
 * The phase order is determined by {@code HitsState.advance()}, which is defined outside
 * of this class.
 *
 * <p>The global norm is accumulated in a shared {@link DoubleAdder}: value-producing phases
 * add the square of every node value, and {@link #masterCompute} then replaces that sum
 * with its square root so that the following normalization phase can read it.
 */
public class HitsComputation implements BidirectionalPregelComputation<HitsConfig> {

    // Global norm aggregator shared by all workers
    private final DoubleAdder globalNorm = new DoubleAdder();
    // Phase executed by the current superstep; advanced in masterCompute.
    private HitsState state = INIT;

    /**
     * Declares the two double-valued node properties used by this computation.
     *
     * @param config supplies the names of the authority and hub properties
     * @return a schema containing both properties as {@code DOUBLE}
     */
    @Override
    public PregelSchema schema(HitsConfig config) {
        return new PregelSchema.Builder()
            .add(config.authProperty(), ValueType.DOUBLE)
            .add(config.hubProperty(), ValueType.DOUBLE)
            .build();
    }

    /**
     * Provides the memory estimation for two double-valued node properties.
     *
     * @param isAsynchronous not consulted; both boolean flags handed to
     *                       {@code Pregel.memoryEstimation} are fixed to {@code false}
     * @return a definition delegating to {@code Pregel.memoryEstimation}
     */
    @Override
    public MemoryEstimateDefinition estimateDefinition(boolean isAsynchronous) {
        // No config is available here, so fixed "auth"/"hub" keys stand in for the property names.
        return () -> Pregel.memoryEstimation(
            Map.of(
                "auth", ValueType.DOUBLE,
                "hub", ValueType.DOUBLE
            ),
            false,
            false
        );
    }

    /**
     * Initializes both the hub and the authority value of the node to {@code 1}.
     *
     * @param context init context of the node being initialized
     */
    @Override
    public void init(BidirectionalInitContext<HitsConfig> context) {
        context.setNodeValue(context.config().hubProperty(), 1D);
        context.setNodeValue(context.config().authProperty(), 1D);
    }

    /**
     * Runs the phase selected by the current {@link #state} for a single node.
     *
     * @param context  compute context of the node
     * @param messages messages received by the node in this superstep
     */
    @Override
    public void compute(BidirectionalComputeContext<HitsConfig> context, Messages messages) {
        switch (state) {
            case INIT:
                // Seed the authority score with the number of incoming relationships.
                var auth = (double) context.incomingDegree();
                context.setNodeValue(context.config().authProperty(), auth);
                updateGlobalNorm(auth);
                break;
            case CALCULATE_AUTHS:
                calculateValue(context, messages, context.config().authProperty());
                break;
            case NORMALIZE_AUTHS:
                normalizeAuthValue(context);
                break;
            case CALCULATE_HUBS:
                calculateValue(context, messages, context.config().hubProperty());
                break;
            case NORMALIZE_HUBS:
                normalizeHubValue(context);
                break;
        }
    }

    /**
     * Finalizes the global norm of the phase that just ran and advances to the next phase.
     *
     * <p>After a value-producing phase the accumulated sum of squares is replaced by its
     * square root; after a normalization phase the accumulator is cleared.
     *
     * @param context master compute context (not used)
     * @return always {@code false}
     *         NOTE(review): presumably tells the framework that the computation has not
     *         converged and should keep iterating; confirm against the Pregel API.
     */
    @Override
    public boolean masterCompute(MasterComputeContext<HitsConfig> context) {
        if (state == INIT || state == CALCULATE_AUTHS || state == CALCULATE_HUBS) {
            // Turn the sum of squares into the Euclidean norm read by normalize().
            var norm = globalNorm.sumThenReset();
            globalNorm.add(Math.sqrt(norm));
        } else if (state == NORMALIZE_AUTHS || state == NORMALIZE_HUBS) {
            // The norm has been consumed; start the next phase from zero.
            globalNorm.reset();
        }
        state = state.advance();

        return false;
    }

    /**
     * Sums all received messages, stores the sum under {@code property} and adds its
     * square to the global norm.
     *
     * @param context  compute context of the node
     * @param messages messages received by the node in this superstep
     * @param property name of the node property to write (authority or hub)
     */
    private void calculateValue(
        BidirectionalComputeContext<HitsConfig> context,
        Messages messages,
        String property
    ) {
        var sum = 0D;
        for (double message : messages) {
            sum += message;
        }
        context.setNodeValue(property, sum);
        updateGlobalNorm(sum);
    }

    /** Normalizes the hub value of the node and sends it via {@code sendToNeighbors}. */
    private void normalizeHubValue(BidirectionalComputeContext<HitsConfig> context) {
        // normalise hub
        var normalizedValue = normalize(context, context.config().hubProperty());
        // send normalised hubs to outgoing neighbors
        context.sendToNeighbors(normalizedValue);
    }

    /** Normalizes the authority value of the node and sends it via {@code sendToIncomingNeighbors}. */
    private void normalizeAuthValue(BidirectionalComputeContext<HitsConfig> context) {
        // normalise auth
        var normalizedValue = normalize(context, context.config().authProperty());
        // send normalised auths to incoming neighbors
        context.sendToIncomingNeighbors(normalizedValue);
    }

    /** Adds the square of {@code value} to the shared norm accumulator. */
    private void updateGlobalNorm(double value) {
        globalNorm.add(Math.pow(value, 2));
    }

    /**
     * Divides the node value stored under {@code property} by the global norm and writes
     * the result back.
     *
     * <p>A norm of zero means that the squares of all values of the preceding phase summed
     * to zero (for example in a graph without relationships). Dividing would then produce
     * {@code NaN} for every node, so the normalized value is defined as {@code 0} instead.
     *
     * @param context  compute context of the node
     * @param property name of the node property to normalize
     * @return the normalized value that was stored
     */
    private double normalize(BidirectionalComputeContext<HitsConfig> context, String property) {
        var value = context.doubleNodeValue(property);
        var norm = globalNorm.sum();
        var normalizedValue = norm == 0D ? 0D : value / norm;
        context.setNodeValue(property, normalizedValue);
        return normalizedValue;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy