org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer
Neo4j Graph Data Science :: Algorithms
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.embeddings.graphsage;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.model.Model.CustomInfo;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainParameters;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.ConstantScale;
import org.neo4j.gds.ml.core.functions.ElementSum;
import org.neo4j.gds.ml.core.functions.L2NormSquared;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.optimizer.AdamOptimizer;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.termination.TerminationFlag;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.neo4j.gds.embeddings.graphsage.GraphSageHelper.embeddingsComputationGraph;
import static org.neo4j.gds.ml.core.tensor.TensorFunctions.averageTensors;
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
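/**
 * Trains the GraphSage layer weights using mini-batch gradient descent with the Adam optimizer.
 *
 * <p>A minimal usage sketch (the caller-side objects {@code parameters}, {@code executor},
 * {@code progressTracker}, {@code terminationFlag}, {@code graph} and {@code features} are
 * placeholders for the usual GDS setup and not defined in this file):
 * <pre>{@code
 * var trainer = new GraphSageModelTrainer(
 *     parameters, featureDimension, executor, progressTracker, terminationFlag
 * );
 * ModelTrainResult result = trainer.train(graph, features);
 * Layer[] trainedLayers = result.layers();
 * GraphSageTrainMetrics metrics = result.metrics();
 * }</pre>
 */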
public class GraphSageModelTrainer {
private final long randomSeed;
private final GraphSageTrainParameters parameters;
private final FeatureFunction featureFunction;
private final Collection<Weights<Matrix>> labelProjectionWeights;
private final ExecutorService executor;
private final ProgressTracker progressTracker;
private final Layer[] layers;
private final TerminationFlag terminationFlag;
public GraphSageModelTrainer(GraphSageTrainParameters parameters, int featureDimension, ExecutorService executor, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
this(parameters, executor, progressTracker, terminationFlag, new SingleLabelFeatureFunction(), Collections.emptyList(), featureDimension);
}
public GraphSageModelTrainer(
GraphSageTrainParameters parameters,
ExecutorService executor,
ProgressTracker progressTracker,
TerminationFlag terminationFlag,
FeatureFunction featureFunction,
Collection<Weights<Matrix>> labelProjectionWeights,
int featureDimension
) {
this.parameters = parameters;
this.featureFunction = featureFunction;
this.labelProjectionWeights = labelProjectionWeights;
this.executor = executor;
this.progressTracker = progressTracker;
this.terminationFlag = terminationFlag;
this.randomSeed = parameters.randomSeed().orElseGet(() -> ThreadLocalRandom.current().nextLong());
this.layers = parameters.layerConfigs(featureDimension)
.stream()
.map(LayerFactory::createLayer)
.toArray(Layer[]::new);
}
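/**
 * Builds the progress task tree reported during training: a "Prepare batches" leaf, followed by
 * a "Train model" task that repeats an "Epoch" task up to {@code epochs} times, where each epoch
 * runs up to {@code maxIterations} "Iteration" leaves of {@code batchesPerIteration} batches each.
 */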
public static List<Task> progressTasks(long numberOfBatches, int batchesPerIteration, int maxIterations, int epochs) {
return List.of(
Tasks.leaf("Prepare batches", numberOfBatches),
Tasks.iterativeDynamic(
"Train model",
() -> List.of(Tasks.iterativeDynamic(
"Epoch",
() -> List.of(Tasks.leaf("Iteration", batchesPerIteration)),
maxIterations
)),
epochs
)
);
}
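/**
 * Runs the training loop: the extended node batches are prepared once, then for every epoch a set
 * of batch tasks is sampled per iteration and executed concurrently, and the averaged gradients
 * are applied via Adam until the change in average loss per node drops below the configured
 * tolerance or the maximum number of iterations is reached.
 */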
public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
var weights = new ArrayList<Weights<? extends Tensor<?>>>(labelProjectionWeights);
for (Layer layer : layers) {
weights.addAll(layer.weights());
}
progressTracker.beginSubTask("Prepare batches");
var batchSampler = new BatchSampler(graph, progressTracker, terminationFlag);
List<long[]> extendedBatches = batchSampler
.extendedBatches(parameters.batchSize(), parameters.searchDepth(), randomSeed);
var random = new SplittableRandom(randomSeed);
progressTracker.endSubTask("Prepare batches");
progressTracker.beginSubTask("Train model");
boolean converged = false;
var iterationLossesPerEpoch = new ArrayList<List<Double>>();
var prevEpochLoss = Double.NaN;
int epochs = parameters.epochs();
// if each batch is used more than once, we cache the tasks, otherwise we compute them lazily
boolean createBatchTasksEagerly = parameters.batchesPerIteration(graph.nodeCount()) * parameters.maxIterations() > extendedBatches.size();
for (int epoch = 1; epoch <= epochs && !converged; epoch++) {
progressTracker.beginSubTask("Epoch");
terminationFlag.assertRunning();
// we also tried random.nextLong(), but it somehow gave worse quality
long epochLocalSeed = epoch + randomSeed;
Supplier<List<BatchTask>> batchTaskSampler;
if (createBatchTasksEagerly) {
List<BatchTask> tasksForEpoch = extendedBatches
.stream()
.map(extendedBatch -> createBatchTask(
extendedBatch,
graph,
features,
layers,
weights,
epochLocalSeed
))
.collect(Collectors.toList());
batchTaskSampler = () -> IntStream
.range(0, parameters.batchesPerIteration(graph.nodeCount()))
.mapToObj(__ -> tasksForEpoch.get(random.nextInt(tasksForEpoch.size())))
.collect(Collectors.toList());
} else {
batchTaskSampler = () -> IntStream
.range(0, parameters.batchesPerIteration(graph.nodeCount()))
.mapToObj(__ -> createBatchTask(
extendedBatches.get(random.nextInt(extendedBatches.size())),
graph,
features,
layers,
weights,
epochLocalSeed
))
.collect(Collectors.toList());
}
var epochResult = trainEpoch(batchTaskSampler, weights, prevEpochLoss);
List<Double> epochLosses = epochResult.losses();
iterationLossesPerEpoch.add(epochLosses);
prevEpochLoss = epochLosses.get(epochLosses.size() - 1);
converged = epochResult.converged();
progressTracker.endSubTask("Epoch");
}
progressTracker.endSubTask("Train model");
return ModelTrainResult.of(iterationLossesPerEpoch, converged, layers);
}
/**
 * Samples the neighbor subgraphs for each layer and constructs the loss function for one batch.
 */
private BatchTask createBatchTask(
long[] extendedBatch,
Graph graph,
HugeObjectArray<double[]> features,
Layer[] layers,
ArrayList<Weights<? extends Tensor<?>>> weights,
long localSeed
) {
// as we pass a reference for the relationshipWeights, we need a local copy
var localGraph = graph.concurrentCopy();
List<SubGraph> subGraphs = GraphSageHelper.subGraphsPerLayer(localGraph, extendedBatch, layers, localSeed);
Variable<Matrix> batchedFeaturesExtractor = featureFunction.apply(
localGraph,
subGraphs.get(subGraphs.size() - 1).originalNodeIds(),
features
);
Variable<Matrix> embeddingVariable = embeddingsComputationGraph(subGraphs, layers, batchedFeaturesExtractor);
Variable<Scalar> lossWithoutPenalty = new GraphSageLoss(
SubGraph.relationshipWeightFunction(localGraph),
embeddingVariable,
extendedBatch,
parameters.negativeSampleWeight()
);
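// the extended batch holds three node ids per original batch node (the node itself, a sampled
// neighbor and a negative sample), so dividing by three recovers the original batch size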
long originalBatchSize = extendedBatch.length / 3;
Variable<Scalar> loss;
if (parameters.penaltyL2() > 0) {
List<Variable<Scalar>> l2penalty = Arrays
.stream(layers)
.map(layer -> layer.aggregator().weightsWithoutBias())
.flatMap(layerWeights -> layerWeights.stream().map(L2NormSquared::new))
.collect(Collectors.toList());
loss = new ElementSum(List.of(
lossWithoutPenalty,
new ConstantScale<>(
new ElementSum(l2penalty),
// we scale the penalty to achieve the same impact on the last (smaller) batch as on every other batch
parameters.penaltyL2() * originalBatchSize / graph.nodeCount()
)
));
} else {
loss = lossWithoutPenalty;
}
return new BatchTask(loss, weights, progressTracker);
}
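/**
 * Runs the iterations of a single epoch: each iteration samples a set of batch tasks, executes
 * their forward/backward passes concurrently, records the average loss per node and, unless the
 * loss change already falls below the tolerance, averages the per-batch weight gradients and
 * applies one Adam update.
 */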
private EpochResult trainEpoch(
Supplier<List<BatchTask>> sampledBatchTaskSupplier,
List<Weights<? extends Tensor<?>>> weights,
double prevEpochLoss
) {
var updater = new AdamOptimizer(weights, parameters.learningRate());
int iteration = 1;
var iterationLosses = new ArrayList<Double>();
double prevLoss = prevEpochLoss;
var converged = false;
int maxIterations = parameters.maxIterations();
for (; iteration <= maxIterations; iteration++) {
progressTracker.beginSubTask("Iteration");
terminationFlag.assertRunning();
var sampledBatchTasks = sampledBatchTaskSupplier.get();
// run the forward and backward pass for each sampled batch task
RunWithConcurrency.builder()
.concurrency(parameters.concurrency())
.tasks(sampledBatchTasks)
.executor(executor)
.run();
var avgLossPerNode = sampledBatchTasks.stream().mapToDouble(BatchTask::loss).sum() / sampledBatchTasks.size();
iterationLosses.add(avgLossPerNode);
progressTracker.logInfo(formatWithLocale("Average loss per node: %.10f", avgLossPerNode));
if (Math.abs(prevLoss - avgLossPerNode) < parameters.tolerance()) {
converged = true;
progressTracker.endSubTask("Iteration");
break;
}
prevLoss = avgLossPerNode;
var batchedGradients = sampledBatchTasks
.stream()
.map(BatchTask::weightGradients)
.collect(Collectors.toList());
var meanGradients = averageTensors(batchedGradients);
updater.update(meanGradients);
progressTracker.endSubTask("Iteration");
}
return ImmutableEpochResult.of(converged, iterationLosses);
}
@ValueClass
interface EpochResult {
boolean converged();
List<Double> losses();
}
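/**
 * One unit of work per sampled batch: a forward pass through the loss computation graph to obtain
 * the batch loss, followed by a backward pass that records the gradients of all trainable weights.
 */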
static class BatchTask implements Runnable {
private final Variable<Scalar> lossFunction;
private final List<Weights<? extends Tensor<?>>> weightVariables;
private List<? extends Tensor<?>> weightGradients;
private final ProgressTracker progressTracker;
private double loss;
BatchTask(
Variable<Scalar> lossFunction,
List<Weights<? extends Tensor<?>>> weightVariables,
ProgressTracker progressTracker
) {
this.lossFunction = lossFunction;
this.weightVariables = weightVariables;
this.progressTracker = progressTracker;
}
@Override
public void run() {
var localCtx = new ComputationContext();
loss = localCtx.forward(lossFunction).value();
localCtx.backward(lossFunction);
weightGradients = weightVariables.stream().map(localCtx::gradient).collect(Collectors.toList());
progressTracker.logProgress();
}
public double loss() {
return loss;
}
List<? extends Tensor<?>> weightGradients() {
return weightGradients;
}
}
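/**
 * Training metrics surfaced to the model catalog: the per-iteration losses for every epoch, the
 * derived loss of the last iteration per epoch, whether training converged, and the number of
 * epochs and iterations that actually ran.
 */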
@ValueClass
public interface GraphSageTrainMetrics extends CustomInfo {
static GraphSageTrainMetrics empty() {
return ImmutableGraphSageTrainMetrics.of(List.of(), false);
}
@Value.Derived
default List<Double> epochLosses() {
return iterationLossPerEpoch().stream()
.map(iterationLosses -> iterationLosses.get(iterationLosses.size() - 1))
.collect(Collectors.toList());
}
List<List<Double>> iterationLossPerEpoch();
boolean didConverge();
@Value.Derived
default int ranEpochs() {
return iterationLossPerEpoch().isEmpty()
? 0
: iterationLossPerEpoch().size();
}
@Value.Derived
default List<Integer> ranIterationsPerEpoch() {
return iterationLossPerEpoch().stream().map(List::size).collect(Collectors.toList());
}
@Override
@Value.Auxiliary
@Value.Derived
default Map<String, Object> toMap() {
return Map.of(
"metrics", Map.of(
"epochLosses", epochLosses(),
"iterationLossesPerEpoch", iterationLossPerEpoch(),
"didConverge", didConverge(),
"ranEpochs", ranEpochs(),
"ranIterationsPerEpoch", ranIterationsPerEpoch()
));
}
}
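/**
 * The result of {@link #train}: the trained layers together with the collected training metrics.
 */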
@ValueClass
public interface ModelTrainResult {
GraphSageTrainMetrics metrics();
Layer[] layers();
static ModelTrainResult of(
List<List<Double>> iterationLossesPerEpoch,
boolean converged,
Layer[] layers
) {
return ImmutableModelTrainResult.builder()
.layers(layers)
.metrics(ImmutableGraphSageTrainMetrics.of(iterationLossesPerEpoch, converged))
.build();
}
}
}