All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.listeners.impl.ScoreListener Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.autodiff.listeners.impl;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.dataset.api.MultiDataSet;

import java.text.DecimalFormat;

@Slf4j
public class ScoreListener extends BaseListener {

    private final int frequency;
    private final boolean reportEpochs;
    private final boolean reportIterPerformance;

    private long epochExampleCount;
    private int epochBatchCount;
    private long etlTotalTimeEpoch;

    private long lastIterTime;
    private long etlTimeSumSinceLastReport;
    private long iterTimeSumSinceLastReport;
    private int examplesSinceLastReportIter;
    private long lastReportTime = -1;

    /**
     * Create a ScoreListener reporting every 10 iterations, and at the end of each epoch
     */
    public ScoreListener() {
        this(10, true);
    }

    /**
     * Create a ScoreListener reporting every N iterations, and at the end of each epoch
     */
    public ScoreListener(int frequency) {
        this(frequency, true);
    }

    /**
     * Create a ScoreListener reporting every N iterations, and optionally at the end of each epoch
     */
    public ScoreListener(int frequency, boolean reportEpochs) {
        this(frequency, reportEpochs, true);
    }

    public ScoreListener(int frequency, boolean reportEpochs, boolean reportIterPerformance) {
        Preconditions.checkArgument(frequency > 0, "ScoreListener frequency must be > 0, got %s", frequency);
        this.frequency = frequency;
        this.reportEpochs = reportEpochs;
        this.reportIterPerformance = reportIterPerformance;
    }


    @Override
    public boolean isActive(Operation operation) {
        return operation == Operation.TRAINING;
    }

    @Override
    public void epochStart(SameDiff sd, At at) {
        if (reportEpochs) {
            epochExampleCount = 0;
            epochBatchCount = 0;
            etlTotalTimeEpoch = 0;
        }
        lastReportTime = -1;
        examplesSinceLastReportIter = 0;
    }

    @Override
    public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
        if (reportEpochs) {
            double batchesPerSec = epochBatchCount / (epochTimeMillis / 1000.0);
            double examplesPerSec = epochExampleCount / (epochTimeMillis / 1000.0);
            double pcEtl = 100.0 * etlTotalTimeEpoch / (double) epochTimeMillis;
            String etl = formatDurationMs(etlTotalTimeEpoch) + " ETL time" + (etlTotalTimeEpoch > 0 ? "(" + format2dp(pcEtl) + " %)" : "");
            log.info("Epoch {} complete on iteration {} - {} batches ({} examples) in {} - {} batches/sec, {} examples/sec, {}",
                    at.epoch(), at.iteration(), epochBatchCount, epochExampleCount, formatDurationMs(epochTimeMillis),
                    format2dp(batchesPerSec), format2dp(examplesPerSec), etl);
        }

        return ListenerResponse.CONTINUE;
    }

    @Override
    public void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlMs) {
        lastIterTime = System.currentTimeMillis();
        etlTimeSumSinceLastReport += etlMs;
        etlTotalTimeEpoch += etlMs;
    }

    @Override
    public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
        iterTimeSumSinceLastReport += System.currentTimeMillis() - lastIterTime;
        epochBatchCount++;
        if (dataSet.numFeatureArrays() > 0 && dataSet.getFeatures(0) != null) {
            int n = (int) dataSet.getFeatures(0).size(0);
            examplesSinceLastReportIter += n;
            epochExampleCount += n;
        }

        if (at.iteration() > 0 && at.iteration() % frequency == 0) {
            double l = loss.totalLoss();
            String etl = "";
            if (etlTimeSumSinceLastReport > 0) {
                etl = "(" + formatDurationMs(etlTimeSumSinceLastReport) + " ETL";
                if (frequency == 1) {
                    etl += ")";
                } else {
                    etl += " in " + frequency + " iter)";
                }
            }

            if(!reportIterPerformance) {
                log.info("Loss at epoch {}, iteration {}: {}{}", at.epoch(), at.iteration(), format5dp(l), etl);
            } else {
                long time = System.currentTimeMillis();
                if(lastReportTime > 0){
                    double batchPerSec = 1000 * frequency / (double)(time - lastReportTime);
                    double exPerSec = 1000 * examplesSinceLastReportIter / (double)(time - lastReportTime);
                    log.info("Loss at epoch {}, iteration {}: {}{}, batches/sec: {}, examples/sec: {}", at.epoch(), at.iteration(), format5dp(l),
                            etl, format5dp(batchPerSec), format5dp(exPerSec));
                } else {
                    log.info("Loss at epoch {}, iteration {}: {}{}", at.epoch(), at.iteration(), format5dp(l), etl);
                }

                lastReportTime = time;
            }

            iterTimeSumSinceLastReport = 0;
            etlTimeSumSinceLastReport = 0;
            examplesSinceLastReportIter = 0;
        }
    }

    protected String formatDurationMs(long ms) {
        if (ms <= 100) {
            return ms + " ms";
        } else if (ms <= 60000L) {
            double sec = ms / 1000.0;
            return format2dp(sec) + " sec";
        } else if (ms <= 60 * 60000L) {
            double min = ms / 60_000.0;
            return format2dp(min) + " min";
        } else {
            double hr = ms / 360_000.0;
            return format2dp(hr) + " hr";
        }
    }

    protected static final ThreadLocal DF_2DP = new ThreadLocal<>();
    protected static final ThreadLocal DF_2DP_SCI = new ThreadLocal<>();

    protected String format2dp(double d) {
        if (d < 0.01) {
            DecimalFormat f = DF_2DP_SCI.get();
            if (f == null) {
                f = new DecimalFormat("0.00E0");
                DF_2DP.set(f);
            }
            return f.format(d);
        } else {
            DecimalFormat f = DF_2DP.get();
            if (f == null) {
                f = new DecimalFormat("#.00");
                DF_2DP.set(f);
            }
            return f.format(d);
        }
    }

    protected static final ThreadLocal DF_5DP = new ThreadLocal<>();
    protected static final ThreadLocal DF_5DP_SCI = new ThreadLocal<>();

    protected String format5dp(double d) {

        if (d < 1e-4 || d > 1e4) {
            //Use scientific
            DecimalFormat f = DF_5DP_SCI.get();
            if (f == null) {
                f = new DecimalFormat("0.00000E0");
                DF_5DP_SCI.set(f);
            }
            return f.format(d);
        } else {
            DecimalFormat f = DF_5DP.get();
            if (f == null) {
                f = new DecimalFormat("0.00000");
                DF_5DP.set(f);
            }
            return f.format(d);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy