All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.listeners.ListenerEvaluations Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.autodiff.listeners;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.Setter;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;

/**
 * A class to allow Listeners to define what evaluations they need to run during training
*

* Usage example - does classification ({@link org.nd4j.evaluation.classification.Evaluation}) evaluation on * the training set (as training proceeds) and also Evaluation/ROCMultiClass evaluation on the test/validation set. * Assumes that the output predictions are called "softmax" and the first DataSet/MultiDataSet labels are those corresponding * to the "softmax" node *

{@code
 * ListenerEvaluations.builder()
 *     //trainEvaluations: on the training data (in-line, as training proceeds through the epoch)
 *     .trainEvaluation("softmax", 0, new Evaluation(), new ROCMultiClass())
 *     //validationEvaluation: on the test/validation data, at the end of each epoch
 *     .validationEvaluation("softmax", 0, new Evaluation(), new ROCMultiClass())
 *     .build();
 * }
*/ @Getter public class ListenerEvaluations { private Map> trainEvaluations; private Map trainEvaluationLabels; private Map> validationEvaluations; private Map validationEvaluationLabels; public ListenerEvaluations(Map> trainEvaluations, Map trainEvaluationLabels, Map> validationEvaluations, Map validationEvaluationLabels) { this.trainEvaluations = trainEvaluations; this.trainEvaluationLabels = trainEvaluationLabels; this.validationEvaluations = validationEvaluations; this.validationEvaluationLabels = validationEvaluationLabels; Preconditions.checkArgument(trainEvaluations.keySet().equals(trainEvaluationLabels.keySet()), "Must specify a label index for each train evaluation. Expected: %s, got: %s", trainEvaluations.keySet(), trainEvaluationLabels.keySet()); Preconditions.checkArgument(validationEvaluations.keySet().equals(validationEvaluationLabels.keySet()), "Must specify a label index for each validation evaluation. Expected: %s, got: %s", validationEvaluations.keySet(), validationEvaluationLabels.keySet()); } private ListenerEvaluations() { } public static Builder builder() { return new Builder(); } /** * Get the requested training evaluations */ public Map> trainEvaluations() { return trainEvaluations; } /** * Get the label indices for the requested training evaluations */ public Map trainEvaluationLabels() { return trainEvaluationLabels; } /** * Get the requested validation evaluations */ public Map> validationEvaluations() { return validationEvaluations; } /** * Get the label indices for the requested validation evaluations */ public Map validationEvaluationLabels() { return validationEvaluationLabels; } /** * Get the required variables for these evaluations */ public ListenerVariables requiredVariables() { return new ListenerVariables(trainEvaluations.keySet(), validationEvaluations.keySet(), new HashSet(), new HashSet()); } /** * @return true if there are no requested evaluations */ public boolean isEmpty() { return trainEvaluations.isEmpty() && 
validationEvaluations.isEmpty(); } @NoArgsConstructor @Getter @Setter public static class Builder { private Map> trainEvaluations = new HashMap<>(); private Map trainEvaluationLabels = new HashMap<>(); private Map> validationEvaluations = new HashMap<>(); private Map validationEvaluationLabels = new HashMap<>(); private void addEvaluations(boolean validation, @NonNull Map> evaluationMap, @NonNull Map labelMap, @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) { if (evaluationMap.containsKey(variableName) && labelMap.get(variableName) != labelIndex) { String s; if (validation) { s = "This ListenerEvaluations.Builder already has validation evaluations for "; } else { s = "This ListenerEvaluations.Builder already has train evaluations for "; } throw new IllegalArgumentException(s + "variable " + variableName + " with label index " + labelIndex + ". You can't add " + " evaluations with a different label index. Got label index " + labelIndex); } if (evaluationMap.containsKey(variableName)) { evaluationMap.get(variableName).addAll(Arrays.asList(evaluations)); } else { evaluationMap.put(variableName, Arrays.asList(evaluations)); labelMap.put(variableName, labelIndex); } } /** * Add requested training evaluations for a parm/variable * * @param variableName The variable to evaluate * @param labelIndex The index of the label to evaluate against * @param evaluations The evaluations to run */ public Builder trainEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) { addEvaluations(false, this.trainEvaluations, this.trainEvaluationLabels, variableName, labelIndex, evaluations); return this; } /** * Add requested training evaluations for a parm/variable * * @param variable The variable to evaluate * @param labelIndex The index of the label to evaluate against * @param evaluations The evaluations to run */ public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... 
evaluations) { return trainEvaluation(variable.getVarName(), labelIndex, evaluations); } /** * Add requested validation evaluations for a parm/variable * * @param variableName The variable to evaluate * @param labelIndex The index of the label to evaluate against * @param evaluations The evaluations to run */ public Builder validationEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) { addEvaluations(true, this.validationEvaluations, this.validationEvaluationLabels, variableName, labelIndex, evaluations); return this; } /** * Add requested validation evaluations for a parm/variable * * @param variable The variable to evaluate * @param labelIndex The index of the label to evaluate against * @param evaluations The evaluations to run */ public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) { return validationEvaluation(variable.getVarName(), labelIndex, evaluations); } /** * Add requested evaluations for a parm/variable, for either training or validation * * @param validation Whether to add these evaluations as validation or training * @param variableName The variable to evaluate * @param labelIndex The index of the label to evaluate against * @param evaluations The evaluations to run */ public Builder addEvaluations(boolean validation, @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) { if (validation) { return validationEvaluation(variableName, labelIndex, evaluations); } else { return trainEvaluation(variableName, labelIndex, evaluations); } } public ListenerEvaluations build() { return new ListenerEvaluations(trainEvaluations, trainEvaluationLabels, validationEvaluations, validationEvaluationLabels); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy