All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.nn.conf.layers.BaseOutputLayer Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.nn.conf.layers;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.ToString;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.lossfunctions.impl.*;

@Data @NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public abstract class BaseOutputLayer extends FeedForwardLayer {
    protected ILossFunction lossFn;

    protected BaseOutputLayer(Builder builder) {
    	super(builder);
        this.lossFn = builder.lossFn;
    }

    /**
     *
     * @deprecated As of 0.6.0. Use {@link #getLossFn()} instead
     */
    @Deprecated
    public LossFunction getLossFunction() {
        //To maintain backward compatibility only (as much as possible)
        if (lossFn instanceof LossNegativeLogLikelihood) {
            return LossFunction.NEGATIVELOGLIKELIHOOD;
        } else if (lossFn instanceof LossMCXENT) {
            return LossFunction.MCXENT;
        } else if (lossFn instanceof LossMSE) {
            return LossFunction.MSE;
        } else if (lossFn instanceof LossBinaryXENT) {
            return LossFunction.XENT;
        } else {
            //TODO: are there any others??
            return null;
        }
    }


    public static abstract class Builder> extends FeedForwardLayer.Builder {
        protected ILossFunction lossFn = new LossMCXENT();

        public Builder() {}

        public Builder(LossFunction lossFunction) {
            lossFunction(lossFunction);
        }

        public Builder(ILossFunction lossFunction) {
            this.lossFn = lossFunction;
        }

        public T lossFunction(LossFunction lossFunction) {
            return lossFunction(lossFunction.getILossFunction());
        }

        public T lossFunction(ILossFunction lossFunction) {
            this.lossFn = lossFunction;
            return (T)this;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy