All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.ui.weights.ModelAndGradient Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.ui.weights;


import org.deeplearning4j.nn.api.Model;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * @author Adam Gibson
 */

public class ModelAndGradient implements Serializable {
    private long lastUpdateTime = -1L;
    private Map parameters;
    private Map gradients;
    private double score;
    private List scores = new ArrayList<>();
    private List>> updateMagnitudes = new ArrayList<>();
    private List>> paramMagnitudes = new ArrayList<>();
    private List layerNames = new ArrayList<>();
    private String path;


    public ModelAndGradient() {
        parameters = new HashMap<>();
        gradients = new HashMap<>();
    }

    public ModelAndGradient(Model model) {
        model.computeGradientAndScore();
        this.gradients = model.gradient().gradientForVariable();
        this.parameters = model.paramTable();
        this.score = model.score();
    }


    public void setLastUpdateTime(long lastUpdateTime) {
        this.lastUpdateTime = lastUpdateTime;
    }

    public long getLastUpdateTime() {
        return lastUpdateTime;
    }

    public double getScore() {
        return score;
    }

    public void setScore(double score) {
        this.score = score;
    }


    public Map getParameters() {
        return parameters;
    }

    public void setParameters(Map parameters) {
        this.parameters = parameters;
    }


    public Map getGradients() {
        return gradients;
    }

    public void setGradients(Map gradients) {
        this.gradients = gradients;
    }

    public void setScores(List scores) {
        this.scores = scores;
    }

    public void setPath(String path) {
        this.path = path;
    }

    public String getPath() {
        return path;
    }

    public List getScores() {
        return scores;
    }

    public void setUpdateMagnitudes(List>> updateMagnitudes) {
        this.updateMagnitudes = updateMagnitudes;
    }

    public List>> getUpdateMagnitudes() {
        return updateMagnitudes;
    }

    public void setParamMagnitudes(List>> paramMagnitudes) {
        this.paramMagnitudes = paramMagnitudes;
    }

    public List>> getParamMagnitudes() {
        return paramMagnitudes;
    }

    public void setLayerNames(List layerNames) {
        this.layerNames = layerNames;
    }

    public List getLayerNames() {
        return layerNames;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (o == null || getClass() != o.getClass())
            return false;

        ModelAndGradient that = (ModelAndGradient) o;

        if (Double.compare(that.score, score) != 0)
            return false;
        if (parameters != null ? !parameters.equals(that.parameters) : that.parameters != null)
            return false;
        return !(gradients != null ? !gradients.equals(that.gradients) : that.gradients != null);
    }

    @Override
    public int hashCode() {
        int result;
        long temp;
        result = parameters != null ? parameters.hashCode() : 0;
        result = 31 * result + (gradients != null ? gradients.hashCode() : 0);
        temp = Double.doubleToLongBits(score);
        result = 31 * result + (int) (temp ^ (temp >>> 32));
        return result;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy