// org.deeplearning4j.nn.conf.layers.BaseOutputLayer Maven / Gradle / Ivy
package org.deeplearning4j.nn.conf.layers;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.ToString;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.lossfunctions.impl.*;
@Data @NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public abstract class BaseOutputLayer extends FeedForwardLayer {
protected ILossFunction lossFn;
protected BaseOutputLayer(Builder builder) {
super(builder);
this.lossFn = builder.lossFn;
}
/**
*
* @deprecated As of 0.6.0. Use {@link #getLossFn()} instead
*/
@Deprecated
public LossFunction getLossFunction() {
//To maintain backward compatibility only (as much as possible)
if (lossFn instanceof LossNegativeLogLikelihood) {
return LossFunction.NEGATIVELOGLIKELIHOOD;
} else if (lossFn instanceof LossMCXENT) {
return LossFunction.MCXENT;
} else if (lossFn instanceof LossMSE) {
return LossFunction.MSE;
} else if (lossFn instanceof LossBinaryXENT) {
return LossFunction.XENT;
} else {
//TODO: are there any others??
return null;
}
}
public static abstract class Builder> extends FeedForwardLayer.Builder {
protected ILossFunction lossFn = new LossMCXENT();
public Builder() {}
public Builder(LossFunction lossFunction) {
lossFunction(lossFunction);
}
public Builder(ILossFunction lossFunction) {
this.lossFn = lossFunction;
}
public T lossFunction(LossFunction lossFunction) {
return lossFunction(lossFunction.getILossFunction());
}
public T lossFunction(ILossFunction lossFunction) {
this.lossFn = lossFunction;
return (T)this;
}
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy