All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.listeners.ListenerEvaluations Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.autodiff.listeners;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.Setter;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;

@Getter
public class ListenerEvaluations {
    private Map> trainEvaluations;
    private Map trainEvaluationLabels;

    private Map> validationEvaluations;
    private Map validationEvaluationLabels;

    public ListenerEvaluations(Map> trainEvaluations,
                               Map trainEvaluationLabels, Map> validationEvaluations,
                               Map validationEvaluationLabels) {
        this.trainEvaluations = trainEvaluations;
        this.trainEvaluationLabels = trainEvaluationLabels;
        this.validationEvaluations = validationEvaluations;
        this.validationEvaluationLabels = validationEvaluationLabels;

        Preconditions.checkArgument(trainEvaluations.keySet().equals(trainEvaluationLabels.keySet()),
                "Must specify a label index for each train evaluation.  Expected: %s, got: %s",
                trainEvaluations.keySet(), trainEvaluationLabels.keySet());

        Preconditions.checkArgument(validationEvaluations.keySet().equals(validationEvaluationLabels.keySet()),
                "Must specify a label index for each validation evaluation.  Expected: %s, got: %s",
                validationEvaluations.keySet(), validationEvaluationLabels.keySet());
    }

    private ListenerEvaluations() {

    }

    public static Builder builder() {
        return new Builder();
    }

    /**
     * Get the requested training evaluations
     */
    public Map> trainEvaluations() {
        return trainEvaluations;
    }

    /**
     * Get the label indices for the requested training evaluations
     */
    public Map trainEvaluationLabels() {
        return trainEvaluationLabels;
    }

    /**
     * Get the requested validation evaluations
     */
    public Map> validationEvaluations() {
        return validationEvaluations;
    }

    /**
     * Get the label indices for the requested validation evaluations
     */
    public Map validationEvaluationLabels() {
        return validationEvaluationLabels;
    }

    /**
     * Get the required variables for these evaluations
     */
    public ListenerVariables requiredVariables() {
        return new ListenerVariables(trainEvaluations.keySet(), validationEvaluations.keySet(),
                new HashSet(), new HashSet());
    }

    /**
     * @return true if there are no requested evaluations
     */
    public boolean isEmpty() {
        return trainEvaluations.isEmpty() && validationEvaluations.isEmpty();
    }

    @NoArgsConstructor
    @Getter
    @Setter
    public static class Builder {
        private Map> trainEvaluations = new HashMap<>();
        private Map trainEvaluationLabels = new HashMap<>();

        private Map> validationEvaluations = new HashMap<>();
        private Map validationEvaluationLabels = new HashMap<>();

        private void addEvaluations(boolean validation, @NonNull Map> evaluationMap, @NonNull Map labelMap,
                                    @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) {
            if (evaluationMap.containsKey(variableName) && labelMap.get(variableName) != labelIndex) {
                String s;

                if (validation) {
                    s = "This ListenerEvaluations.Builder already has validation evaluations for ";
                } else {
                    s = "This ListenerEvaluations.Builder already has train evaluations for ";
                }

                throw new IllegalArgumentException(s + "variable " +
                        variableName + " with label index " + labelIndex + ".  You can't add " +
                        " evaluations with a different label index.  Got label index " + labelIndex);
            }

            if (evaluationMap.containsKey(variableName)) {
                evaluationMap.get(variableName).addAll(Arrays.asList(evaluations));
            } else {
                evaluationMap.put(variableName, Arrays.asList(evaluations));
                labelMap.put(variableName, labelIndex);
            }
        }

        /**
         * Add requested training evaluations for a parm/variable
         *
         * @param variableName The variable to evaluate
         * @param labelIndex   The index of the label to evaluate against
         * @param evaluations  The evaluations to run
         */
        public Builder trainEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) {
            addEvaluations(false, this.trainEvaluations, this.trainEvaluationLabels, variableName,
                    labelIndex, evaluations);
            return this;
        }

        /**
         * Add requested training evaluations for a parm/variable
         *
         * @param variable    The variable to evaluate
         * @param labelIndex  The index of the label to evaluate against
         * @param evaluations The evaluations to run
         */
        public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) {
            return trainEvaluation(variable.name(), labelIndex, evaluations);
        }

        /**
         * Add requested validation evaluations for a parm/variable
         *
         * @param variableName The variable to evaluate
         * @param labelIndex   The index of the label to evaluate against
         * @param evaluations  The evaluations to run
         */
        public Builder validationEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) {
            addEvaluations(true, this.validationEvaluations, this.validationEvaluationLabels, variableName,
                    labelIndex, evaluations);
            return this;
        }

        /**
         * Add requested validation evaluations for a parm/variable
         *
         * @param variable    The variable to evaluate
         * @param labelIndex  The index of the label to evaluate against
         * @param evaluations The evaluations to run
         */
        public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) {
            return validationEvaluation(variable.name(), labelIndex, evaluations);
        }

        /**
         * Add requested evaluations for a parm/variable, for either training or validation
         *
         * @param validation   Whether to add these evaluations as validation or training
         * @param variableName The variable to evaluate
         * @param labelIndex   The index of the label to evaluate against
         * @param evaluations  The evaluations to run
         */
        public Builder addEvaluations(boolean validation, @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) {
            if (validation) {
                return validationEvaluation(variableName, labelIndex, evaluations);
            } else {
                return trainEvaluation(variableName, labelIndex, evaluations);
            }
        }

        public ListenerEvaluations build() {
            return new ListenerEvaluations(trainEvaluations, trainEvaluationLabels, validationEvaluations, validationEvaluationLabels);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy