
org.nd4j.autodiff.listeners.Listener
package org.nd4j.autodiff.listeners;

import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;

/**
 * A {@link SameDiff} listener interface that is called during every iteration of training or inference
 *
 * @author Alex Black
 * @see BaseListener BaseListener, for extending only the required methods (all others are no-op)
 * @see BaseEvaluationListener BaseEvaluationListener, for extending if you want to use evaluations
 */
public interface Listener {


    /**
     * Required variables for this listener.
     * <p>
     * Used to ensure these variables end up in the minimum required subgraph calculated by
     * {@link org.nd4j.autodiff.samediff.internal.InferenceSession}. Otherwise, if the variables weren't
     * required by a loss variable, they would not be calculated.
     * <p>
     * Any variables in here are guaranteed to have
     * {@link Listener#activationAvailable(SameDiff, At, MultiDataSet, SameDiffOp, String, INDArray)}
     * called for them, regardless of whether they would normally be calculated or not.
     */
    ListenerVariables requiredVariables(SameDiff sd);

    /**
     * Returns whether this listener is active during the given operation. If this returns false for the
     * given operation, the listener's methods will not be called.
     */
    boolean isActive(Operation operation);

    /**
     * Called at the start of every epoch, when fitting from an iterator
     *
     * @param sd The SameDiff instance
     * @param at Current iteration/epoch etc
     */
    void epochStart(SameDiff sd, At at);

    /**
     * Called at the end of every epoch, when fitting from an iterator
     *
     * @param sd              The SameDiff instance
     * @param at              Current iteration/epoch etc
     * @param lossCurve       The losses so far
     * @param epochTimeMillis How long this epoch took
     * @return ListenerResponse.STOP to stop training, CONTINUE or null to continue
     */
    ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis);

    /**
     * Called after the end of every epoch, once validation evaluation is done, when training
     *
     * @param sd                   The SameDiff instance
     * @param at                   Current iteration/epoch etc
     * @param validationTimeMillis How long validation took for this epoch
     * @return ListenerResponse.STOP to stop training, CONTINUE or null to continue
     */
    ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis);

    /**
     * Called at the start of every iteration (minibatch), before any operations have been executed
     *
     * @param sd        The SameDiff instance
     * @param at        Current iteration/epoch etc
     * @param data      The current minibatch of data
     * @param etlTimeMs The time taken (in milliseconds) to load the current minibatch
     */
    void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlTimeMs);

    /**
     * Called at the end of every iteration, after all operations (including updating parameters) have been completed
     *
     * @param sd      The SameDiff instance
     * @param at      Current iteration/epoch etc
     * @param dataSet The current dataset (minibatch) used for training
     * @param loss    The loss value for the current minibatch. Will be null except during training
     */
    void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss);

    /**
     * Called at the start of an operation, e.g. training or validation
     *
     * @param sd The SameDiff instance
     * @param op The operation being started
     */
    void operationStart(SameDiff sd, Operation op);

    /**
     * Called at the end of an operation, e.g. training or validation
     *
     * @param sd The SameDiff instance
     * @param op The operation that has ended
     */
    void operationEnd(SameDiff sd, Operation op);

    /**
     * Called just before each operation is executed (native code called, etc) - after all inputs etc have been set
     *
     * @param sd The SameDiff instance
     * @param at Current iteration/epoch etc
     * @param op Operation that is about to be executed
     */
    void preOpExecution(SameDiff sd, At at, SameDiffOp op);

    /**
     * Called at the end of each operation execution
     * <p>
     * Note: Outputs will most likely be freed later; use detach() if you need to save them.
     *
     * @param sd      The SameDiff instance
     * @param at      Current iteration/epoch etc
     * @param batch   The batch's input data. May be null if not called with a batch
     * @param op      Operation that has just been executed
     * @param outputs The output arrays for the just-executed operation
     */
    void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs);

    /**
     * Called when any activation becomes available.
     * <p>
     * The activation will most likely be freed later; use detach() if you need to save it.
     * <p>
     * Note that this method will be called when any activation becomes available, not just those from
     * {@link #requiredVariables(SameDiff)}. It is, however, guaranteed to be called for variables from
     * requiredVariables().
     * <p>
     * Note that the activations here overlap with
     * {@link #opExecution(SameDiff, At, MultiDataSet, SameDiffOp, INDArray[])} - both contain the same
     * information/arrays.
     *
     * @param sd         The SameDiff instance
     * @param at         Current iteration/epoch etc
     * @param batch      The batch's input data. May be null if not called with a batch
     * @param op         Operation that has just been executed
     * @param varName    The name of the variable
     * @param activation The variable's activation
     */
    void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, INDArray activation);

    /**
     * Called just before each parameter is to be updated - i.e., just before each parameter is modified
     *
     * @param sd     The SameDiff instance
     * @param at     The current iteration/epoch etc
     * @param v      The variable about to be updated during backprop
     * @param update The array representing the update (i.e., the gradient after applying learning rate, momentum, etc)
     */
    void preUpdate(SameDiff sd, At at, Variable v, INDArray update);
}
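
For reference, a minimal usage sketch (not part of the source file above): a listener that stops training once the mean epoch loss stops improving. It extends BaseListener, mentioned in the Javadoc, so only the needed callbacks are overridden. The class name EarlyStoppingListener and the patience logic are illustrative assumptions; the sketch also assumes Loss.totalLoss() and the no-op defaults BaseListener provides for the remaining methods.

import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.dataset.api.MultiDataSet;

/**
 * Illustrative sketch (hypothetical class, not part of ND4J): stops training
 * once the mean epoch loss has not improved for a fixed number of epochs.
 */
public class EarlyStoppingListener extends BaseListener {

    private final int patience;          // epochs to wait without improvement
    private double bestLoss = Double.MAX_VALUE;
    private int epochsWithoutImprovement = 0;

    private double epochLossSum = 0.0;   // accumulated over the current epoch
    private int iterationsThisEpoch = 0;

    public EarlyStoppingListener(int patience) {
        this.patience = patience;
    }

    @Override
    public boolean isActive(Operation operation) {
        return operation == Operation.TRAINING;   // only listen during training
    }

    @Override
    public void epochStart(SameDiff sd, At at) {
        epochLossSum = 0.0;
        iterationsThisEpoch = 0;
    }

    @Override
    public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
        if (loss != null) {                        // loss is null outside training
            epochLossSum += loss.totalLoss();
            iterationsThisEpoch++;
        }
    }

    @Override
    public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
        double meanLoss = iterationsThisEpoch > 0 ? epochLossSum / iterationsThisEpoch : Double.MAX_VALUE;
        if (meanLoss < bestLoss) {
            bestLoss = meanLoss;
            epochsWithoutImprovement = 0;
        } else {
            epochsWithoutImprovement++;
        }
        // Returning STOP halts fitting; CONTINUE (or null) keeps going
        return epochsWithoutImprovement >= patience ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
    }
}

Such a listener would typically be registered with SameDiff.setListeners(...) before calling one of the fit(...) methods.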

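A second sketch, illustrating the requiredVariables() / activationAvailable() contract described in the Javadoc: requesting a variable forces it into the computed subgraph, and detach() keeps the activation alive after the callback returns. The class name ActivationCaptureListener is hypothetical, and the ListenerVariables.builder().trainingVariables(...) call and the fluent At.iteration() accessor are assumptions that may differ between ND4J versions.

import java.util.HashMap;
import java.util.Map;

import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerVariables;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;

/**
 * Illustrative sketch (hypothetical class): records a copy of one named
 * activation each iteration. requiredVariables(...) forces the variable into
 * the minimum required subgraph even if no loss variable depends on it.
 */
public class ActivationCaptureListener extends BaseListener {

    private final String variableName;
    private final Map<Integer, INDArray> capturedByIteration = new HashMap<>();

    public ActivationCaptureListener(String variableName) {
        this.variableName = variableName;
    }

    @Override
    public boolean isActive(Operation operation) {
        return true;   // capture during all operations
    }

    @Override
    public ListenerVariables requiredVariables(SameDiff sd) {
        // Guarantees activationAvailable(...) is called for this variable during training
        // (builder method name is an assumption - check your ND4J version)
        return ListenerVariables.builder()
                .trainingVariables(variableName)
                .build();
    }

    @Override
    public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op,
                                    String varName, INDArray activation) {
        if (variableName.equals(varName)) {
            // The array may be freed/reused after this call - detach() to keep it
            capturedByIteration.put(at.iteration(), activation.detach());
        }
    }

    public Map<Integer, INDArray> captured() {
        return capturedByIteration;
    }
}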