Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
org.nd4j.autodiff.listeners.impl.ScoreListener Maven / Gradle / Ivy
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.autodiff.listeners.impl;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import java.text.DecimalFormat;
@Slf4j
public class ScoreListener extends BaseListener {
private final int frequency;
private final boolean reportEpochs;
private final boolean reportIterPerformance;
private long epochExampleCount;
private int epochBatchCount;
private long etlTotalTimeEpoch;
private long lastIterTime;
private long etlTimeSumSinceLastReport;
private long iterTimeSumSinceLastReport;
private int examplesSinceLastReportIter;
private long lastReportTime = -1;
/**
* Create a ScoreListener reporting every 10 iterations, and at the end of each epoch
*/
public ScoreListener() {
this(10, true);
}
/**
* Create a ScoreListener reporting every N iterations, and at the end of each epoch
*/
public ScoreListener(int frequency) {
this(frequency, true);
}
/**
* Create a ScoreListener reporting every N iterations, and optionally at the end of each epoch
*/
public ScoreListener(int frequency, boolean reportEpochs) {
this(frequency, reportEpochs, true);
}
public ScoreListener(int frequency, boolean reportEpochs, boolean reportIterPerformance) {
Preconditions.checkArgument(frequency > 0, "ScoreListener frequency must be > 0, got %s", frequency);
this.frequency = frequency;
this.reportEpochs = reportEpochs;
this.reportIterPerformance = reportIterPerformance;
}
@Override
public boolean isActive(Operation operation) {
return operation == Operation.TRAINING;
}
@Override
public void epochStart(SameDiff sd, At at) {
if (reportEpochs) {
epochExampleCount = 0;
epochBatchCount = 0;
etlTotalTimeEpoch = 0;
}
lastReportTime = -1;
examplesSinceLastReportIter = 0;
}
@Override
public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
if (reportEpochs) {
double batchesPerSec = epochBatchCount / (epochTimeMillis / 1000.0);
double examplesPerSec = epochExampleCount / (epochTimeMillis / 1000.0);
double pcEtl = 100.0 * etlTotalTimeEpoch / (double) epochTimeMillis;
String etl = formatDurationMs(etlTotalTimeEpoch) + " ETL time" + (etlTotalTimeEpoch > 0 ? "(" + format2dp(pcEtl) + " %)" : "");
log.info("Epoch {} complete on iteration {} - {} batches ({} examples) in {} - {} batches/sec, {} examples/sec, {}",
at.epoch(), at.iteration(), epochBatchCount, epochExampleCount, formatDurationMs(epochTimeMillis),
format2dp(batchesPerSec), format2dp(examplesPerSec), etl);
}
return ListenerResponse.CONTINUE;
}
@Override
public void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlMs) {
lastIterTime = System.currentTimeMillis();
etlTimeSumSinceLastReport += etlMs;
etlTotalTimeEpoch += etlMs;
}
@Override
public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
iterTimeSumSinceLastReport += System.currentTimeMillis() - lastIterTime;
epochBatchCount++;
if (dataSet.numFeatureArrays() > 0 && dataSet.getFeatures(0) != null) {
int n = (int) dataSet.getFeatures(0).size(0);
examplesSinceLastReportIter += n;
epochExampleCount += n;
}
if (at.iteration() > 0 && at.iteration() % frequency == 0) {
double l = loss.totalLoss();
String etl = "";
if (etlTimeSumSinceLastReport > 0) {
etl = "(" + formatDurationMs(etlTimeSumSinceLastReport) + " ETL";
if (frequency == 1) {
etl += ")";
} else {
etl += " in " + frequency + " iter)";
}
}
if(!reportIterPerformance) {
log.info("Loss at epoch {}, iteration {}: {}{}", at.epoch(), at.iteration(), format5dp(l), etl);
} else {
long time = System.currentTimeMillis();
if(lastReportTime > 0){
double batchPerSec = 1000 * frequency / (double)(time - lastReportTime);
double exPerSec = 1000 * examplesSinceLastReportIter / (double)(time - lastReportTime);
log.info("Loss at epoch {}, iteration {}: {}{}, batches/sec: {}, examples/sec: {}", at.epoch(), at.iteration(), format5dp(l),
etl, format5dp(batchPerSec), format5dp(exPerSec));
} else {
log.info("Loss at epoch {}, iteration {}: {}{}", at.epoch(), at.iteration(), format5dp(l), etl);
}
lastReportTime = time;
}
iterTimeSumSinceLastReport = 0;
etlTimeSumSinceLastReport = 0;
examplesSinceLastReportIter = 0;
}
}
protected String formatDurationMs(long ms) {
if (ms <= 100) {
return ms + " ms";
} else if (ms <= 60000L) {
double sec = ms / 1000.0;
return format2dp(sec) + " sec";
} else if (ms <= 60 * 60000L) {
double min = ms / 60_000.0;
return format2dp(min) + " min";
} else {
double hr = ms / 360_000.0;
return format2dp(hr) + " hr";
}
}
protected static final ThreadLocal DF_2DP = new ThreadLocal<>();
protected static final ThreadLocal DF_2DP_SCI = new ThreadLocal<>();
protected String format2dp(double d) {
if (d < 0.01) {
DecimalFormat f = DF_2DP_SCI.get();
if (f == null) {
f = new DecimalFormat("0.00E0");
DF_2DP.set(f);
}
return f.format(d);
} else {
DecimalFormat f = DF_2DP.get();
if (f == null) {
f = new DecimalFormat("#.00");
DF_2DP.set(f);
}
return f.format(d);
}
}
protected static final ThreadLocal DF_5DP = new ThreadLocal<>();
protected static final ThreadLocal DF_5DP_SCI = new ThreadLocal<>();
protected String format5dp(double d) {
if (d < 1e-4 || d > 1e4) {
//Use scientific
DecimalFormat f = DF_5DP_SCI.get();
if (f == null) {
f = new DecimalFormat("0.00000E0");
DF_5DP_SCI.set(f);
}
return f.format(d);
} else {
DecimalFormat f = DF_5DP.get();
if (f == null) {
f = new DecimalFormat("0.00000");
DF_5DP.set(f);
}
return f.format(d);
}
}
}