All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.gds.embeddings.graphsage.MaxPoolingAggregator Maven / Gradle / Ivy

/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.embeddings.graphsage;

import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.ElementWiseMax;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.MatrixSum;
import org.neo4j.gds.ml.core.functions.MatrixVectorSum;
import org.neo4j.gds.ml.core.functions.Slice;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;

import java.util.List;

public class MaxPoolingAggregator implements Aggregator {

    private final Weights poolWeights;
    private final Weights selfWeights;
    private final Weights neighborsWeights;
    private final Weights bias;
    private final ActivationFunction activationFunction;
    private final ActivationFunctionType activationFunctionType;

    public MaxPoolingAggregator(
        Weights poolWeights,
        Weights selfWeights,
        Weights neighborsWeights,
        Weights bias,
        ActivationFunctionWrapper activationFunctionWrapper
    ) {
        this.poolWeights = poolWeights;
        this.selfWeights = selfWeights;
        this.neighborsWeights = neighborsWeights;
        this.bias = bias;

        this.activationFunction = activationFunctionWrapper.activationFunction();
        this.activationFunctionType = activationFunctionWrapper.activationFunctionType();
    }

    @Override
    public Variable aggregate(
        Variable previousLayerRepresentations,
        SubGraph subGraph
    ) {
        Variable weightedPreviousLayer = MatrixMultiplyWithTransposedSecondOperand.of(
            previousLayerRepresentations,
            poolWeights
        );
        Variable biasedWeightedPreviousLayer = new MatrixVectorSum(weightedPreviousLayer, bias);
        Variable neighborhoodActivations = activationFunction.apply(biasedWeightedPreviousLayer);

        Variable elementwiseMax = new ElementWiseMax(neighborhoodActivations, subGraph);


        Variable selfPreviousLayer = new Slice(previousLayerRepresentations, subGraph.batchIds());
        Variable self = MatrixMultiplyWithTransposedSecondOperand.of(selfPreviousLayer, selfWeights);
        Variable neighbors = MatrixMultiplyWithTransposedSecondOperand.of(elementwiseMax, neighborsWeights);
        Variable sum = new MatrixSum(List.of(self, neighbors));

        return activationFunction.apply(sum);
    }

    @Override
    public List>> weights() {
        return List.of(
            poolWeights,
            selfWeights,
            neighborsWeights,
            bias
        );
    }

    @Override
    public List>> weightsWithoutBias() {
        return List.of(poolWeights, selfWeights, neighborsWeights);
    }

    @Override
    public AggregatorType type() {
        return AggregatorType.POOL;
    }

    @Override
    public ActivationFunctionType activationFunctionType() {
        return activationFunctionType;
    }

    public Matrix poolWeights() {
        return poolWeights.data();
    }

    public Matrix selfWeights() {
        return selfWeights.data();
    }

    public Matrix neighborsWeights() {
        return neighborsWeights.data();
    }

    public Vector bias() {
        return bias.data();
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy