All downloads are free. Search and download functionality uses the official Maven repository.

deepnetts.net.loss.CrossEntropyLoss Maven / Gradle / Ivy

/**  
 *  DeepNetts is pure Java Deep Learning Library with support for Backpropagation 
 *  based learning and image recognition.
 * 
 *  Copyright (C) 2017  Zoran Sevarac 
 *
 *  This file is part of DeepNetts.
 *
 *  DeepNetts is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */
    
package deepnetts.net.loss;

import deepnetts.net.NeuralNetwork;
import java.io.Serializable;

/**
 * Average cross-entropy loss function, intended for multi-class classification problems.
 * <p>
 * Usage is accumulate-then-read: call {@link #addPatternError(float[], float[])} once per
 * pattern, optionally add a penalty with {@link #addRegularizationSum(float)}, read the
 * average with {@link #getTotal()} and clear the accumulators with {@link #reset()}.
 * <p>
 * Instances keep mutable accumulators and a reused error buffer and perform no
 * synchronization of their own.
 *
 * @author Zoran Sevarac
 */
public class CrossEntropyLoss implements LossFunction, Serializable {

    private static final long serialVersionUID = 7810738324038602274L;

    /** Reused buffer holding the per-output error of the most recently added pattern. */
    private final float[] outputError;

    /**
     * Running sum of the (non-negative) per-pattern cross-entropy losses plus every
     * regularization term added since construction or the last {@link #reset()}.
     */
    private float totalError;

    /** Number of patterns accumulated since construction or the last {@link #reset()}. */
    private int patternCount = 0;

    /**
     * Creates a loss function whose error buffer is sized to the width of the given
     * network's output layer.
     *
     * @param neuralNet network whose output layer width determines the error vector length
     */
    public CrossEntropyLoss(NeuralNetwork neuralNet) {
        outputError = new float[neuralNet.getOutputLayer().getWidth()];
    }

    /**
     * Adds one pattern to the running loss and returns its error vector.
     * <p>
     * The pattern loss is {@code -sum(targetOutput[i] * log(actualOutput[i]))}. For a one-hot
     * target this is {@code -log(actualOutput[k])}, where {@code k} is the index of the 1.
     * Summing over every non-zero target (instead of remembering the index of the 1 between
     * calls) keeps the result correct for target vectors that contain no exact 1.
     * <p>
     * The returned array is this instance's internal buffer: it is overwritten by the next
     * call, so callers that need to keep the values must copy them.
     *
     * @param actualOutput actual output from the neural network
     * @param targetOutput target/desired output of the neural network
     * @return error vector ({@code actualOutput[i] - targetOutput[i]}) for the given pattern
     */
    @Override
    public float[] addPatternError(float[] actualOutput, float[] targetOutput) {
        patternCount++;

        float patternLoss = 0f;
        for (int i = 0; i < actualOutput.length; i++) {
            final float target = targetOutput[i];

            // Difference between actual and desired output; it is used as the delta of the
            // output neuron (the original note describes it as the derivative of the loss).
            outputError[i] = actualOutput[i] - target;

            // Zero targets contribute nothing; skipping them also avoids 0 * -Infinity = NaN
            // when the corresponding actual output is exactly 0.
            if (target != 0f) {
                patternLoss -= (float) (target * Math.log(actualOutput[i]));
            }
        }

        totalError += patternLoss;

        return outputError;
    }

    /**
     * Adds a regularization term to the running total.
     * <p>
     * Because the total is kept as a positive loss sum, a positive {@code reg} increases the
     * value reported by {@link #getTotal()}.
     *
     * @param reg regularization sum to add to the accumulated loss
     */
    @Override
    public void addRegularizationSum(final float reg) {
        totalError += reg;
    }

    /**
     * Returns the accumulated loss (including regularization terms) divided by the number of
     * patterns added since the last {@link #reset()}.
     *
     * @return average loss; {@code NaN} when no pattern has been added and the total is zero
     */
    @Override
    public float getTotal() {
        return totalError / patternCount;
    }

    /**
     * Clears the accumulated loss and the pattern counter so a new accumulation can start.
     */
    @Override
    public void reset() {
        totalError = 0;
        patternCount = 0;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy