
org.nd4j.autodiff.listeners.Listener

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.autodiff.listeners;

import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.dataset.api.MultiDataSet;

public interface Listener {


    /**
     * Required variables for this listener.
     * <p>
     * Used to ensure these variables end up in the minimum required subgraph calculated by
     * {@link org.nd4j.autodiff.samediff.internal.InferenceSession}. Otherwise, if the variables
     * weren't required by a loss variable, they would not be calculated.
     * <p>
     * Any variables in here are guaranteed to have
     * {@link Listener#activationAvailable(SameDiff, At, MultiDataSet, SameDiffOp, String, INDArray)}
     * called for them, regardless of whether they would normally be calculated or not.
     */
    ListenerVariables requiredVariables(SameDiff sd);

    /**
     * Returns whether this listener is active during the given operation. If this returns false for
     * the given operation, the listener's methods will not be called for it.
     */
    boolean isActive(Operation operation);

    /**
     * Called at the start of every epoch, when fitting from an iterator.
     *
     * @param sd The SameDiff instance
     * @param at Current iteration/epoch etc
     */
    void epochStart(SameDiff sd, At at);

    /**
     * Called at the end of every epoch, when fitting from an iterator.
     *
     * @param sd              The SameDiff instance
     * @param at              Current iteration/epoch etc
     * @param lossCurve       The losses so far
     * @param epochTimeMillis How long this epoch took
     * @return ListenerResponse.STOP to stop training, CONTINUE or null to continue
     */
    ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis);

    /**
     * Called after the end of every epoch, once validation evaluation is done, when training.
     *
     * @param sd                   The SameDiff instance
     * @param at                   Current iteration/epoch etc
     * @param validationTimeMillis How long validation took for this epoch
     * @return ListenerResponse.STOP to stop training, CONTINUE or null to continue
     */
    ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis);

    /**
     * Called at the start of every iteration (minibatch), before any operations have been executed.
     *
     * @param sd        The SameDiff instance
     * @param at        Current iteration/epoch etc
     * @param data      The current minibatch used for training
     * @param etlTimeMs How long data loading (ETL) took for this minibatch
     */
    void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlTimeMs);

    /**
     * Called at the end of every iteration, after all operations (including updating parameters)
     * have been completed.
     *
     * @param sd      The SameDiff instance
     * @param at      Current iteration/epoch etc
     * @param dataSet The current dataset (minibatch) used for training
     * @param loss    The loss value for the current minibatch. Will be null except during training
     */
    void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss);

    /**
     * Called at the start of an operation, e.g. training or validation.
     *
     * @param sd The SameDiff instance
     * @param op The operation being started
     */
    void operationStart(SameDiff sd, Operation op);

    /**
     * Called at the end of an operation, e.g. training or validation.
     *
     * @param sd The SameDiff instance
     * @param op The operation being ended
     */
    void operationEnd(SameDiff sd, Operation op);

    /**
     * Called just before each operation is executed (native code called, etc) - after all inputs
     * etc have been set.
     *
     * @param sd        The SameDiff instance
     * @param at        Current iteration/epoch etc
     * @param op        Operation that is about to be executed
     * @param opContext The op's execution context
     */
    void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext);

    /**
     * Called at the end of each operation execution.
     * <p>
     * Note: the outputs will most likely be freed later; use detach() if you need to save them.
     *
     * @param sd        The SameDiff instance
     * @param at        Current iteration/epoch etc
     * @param batch     The batch's input data. May be null if not called with a batch
     * @param op        Operation that has just been executed
     * @param opContext The op's execution context
     * @param outputs   The output arrays for the just-executed operation
     */
    void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs);

    /**
     * Called when any activation becomes available.
     * <p>
     * The activation will most likely be freed later; use dup() if you need to save it.
     * <p>
     * Note that this method will be called when any activation becomes available, not just the ones
     * from {@link #requiredVariables(SameDiff)}; it is, however, guaranteed to be called for
     * variables from requiredVariables().
     * <p>
     * Note that the activations here overlap with
     * {@link #opExecution(SameDiff, At, MultiDataSet, SameDiffOp, OpContext, INDArray[])} - both
     * contain the same information/arrays.
     *
     * @param sd         The SameDiff instance
     * @param at         Current iteration/epoch etc
     * @param batch      The batch's input data. May be null if not called with a batch
     * @param op         Operation that has just been executed
     * @param varName    The name of the variable
     * @param activation The variable's activation
     */
    void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, INDArray activation);

    /**
     * Called just before each parameter is to be updated - i.e., just before each parameter is modified.
     *
     * @param sd     SameDiff instance
     * @param at     The current iteration/epoch etc
     * @param v      Variable about to be updated during backprop
     * @param update The array representing the update (i.e., the gradient after applying learning rate, momentum, etc)
     */
    void preUpdate(SameDiff sd, At at, Variable v, INDArray update);
}
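For orientation, below is a minimal sketch of a custom listener built against this interface. It assumes the no-op BaseListener adapter shipped in the same package (org.nd4j.autodiff.listeners.BaseListener), so only the callbacks of interest are overridden; the class name LossLoggingListener and the logging interval are illustrative, not part of the ND4J API.

import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.dataset.api.MultiDataSet;

/**
 * Illustrative sketch: logs the total loss every N training iterations.
 * Extends BaseListener so all other callbacks default to no-ops.
 */
public class LossLoggingListener extends BaseListener {

    private final int logEveryNIterations;

    public LossLoggingListener(int logEveryNIterations) {
        this.logEveryNIterations = logEveryNIterations;
    }

    @Override
    public boolean isActive(Operation operation) {
        // Only listen during training; skip inference/evaluation passes
        return operation == Operation.TRAINING;
    }

    @Override
    public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
        // Per the contract above, loss is non-null during training
        if (loss != null && at.iteration() % logEveryNIterations == 0) {
            System.out.println("Epoch " + at.epoch() + ", iteration " + at.iteration()
                    + ": total loss = " + loss.totalLoss());
        }
    }
}

A listener like this would typically be registered before training, e.g. sd.setListeners(new LossLoggingListener(10)) followed by sd.fit(...), assuming the usual setListeners hook on SameDiff in this version.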




