All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.spark.impl.common.updater.UpdaterElementCombinerCG Maven / Gradle / Ivy

The newest version!
package org.deeplearning4j.spark.impl.common.updater;

import org.apache.spark.api.java.function.Function2;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;

/** Simple function to add an Updater to an UpdaterAggregator */
public class UpdaterElementCombinerCG implements Function2 {
    @Override
    public ComputationGraphUpdater.Aggregator call(ComputationGraphUpdater.Aggregator updaterAggregator, ComputationGraphUpdater updater) throws Exception {
        if(updaterAggregator == null && updater == null) return null;

        if(updaterAggregator == null){
            //updater is not null, but updaterAggregator is
            return updater.getAggregator(true);
        }
        if(updater == null){
            //updater is null, but aggregator is not -> no op
            return updaterAggregator;
        }

        //both are non-null
        updaterAggregator.aggregate(updater);
        return updaterAggregator;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy