org.nd4j.linalg.lossfunctions.LossFunctions Maven / Gradle / Ivy
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*
*/
package org.nd4j.linalg.lossfunctions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.*;
import java.util.Arrays;
import static org.nd4j.linalg.ops.transforms.Transforms.*;
/**
* Central class for loss functions
* @author Adam Gibson
*/
public class LossFunctions {

    /**
     * Generic scoring function.
     * <p>
     * All scoring work is delegated to {@link LossCalculation}. Custom loss functions
     * ({@link LossFunction#CUSTOM}) are not supported by this method and are documented
     * to result in an {@link IllegalArgumentException}. An alternative mechanism for
     * scoring (preferably with a function name and the op factory) should be used instead.
     *
     * @param labels            the labels to score
     * @param lossFunction      the loss function to use
     * @param z                 the output to score against {@code labels}
     * @param l2                the l2 regularization term (0.5 * l2Coeff * sum w^2)
     * @param l1                the l1 regularization term (l1Coeff * sum |w|)
     * @param useRegularization whether to apply the given l1/l2 regularization terms
     * @return the score for the given parameters
     */
    public static double score(INDArray labels, LossFunction lossFunction, INDArray z, double l2, double l1,
                    boolean useRegularization) {
        return LossCalculation.builder()
                        .l1(l1).lossFunction(lossFunction)
                        .l2(l2).labels(labels)
                        .z(z)
                        .useRegularization(useRegularization)
                        .build().score();
    }

    /**
     * Built-in loss functions. Use {@link #getILossFunction()} to obtain an implementation.
     * <ul>
     * <li>MSE: Mean Squared Error: Linear Regression</li>
     * <li>L1: L1 loss</li>
     * <li>EXPLL: Exponential log likelihood: Poisson Regression (deprecated; same implementation as POISSON)</li>
     * <li>XENT: Cross Entropy: Binary Classification</li>
     * <li>MCXENT: Multiclass Cross Entropy</li>
     * <li>RMSE_XENT: RMSE Cross Entropy (deprecated; no implementation available)</li>
     * <li>SQUARED_LOSS: Squared Loss (same implementation as MSE)</li>
     * <li>RECONSTRUCTION_CROSSENTROPY: Reconstruction Cross Entropy (same implementation as KL_DIVERGENCE)</li>
     * <li>NEGATIVELOGLIKELIHOOD: Negative Log Likelihood</li>
     * <li>CUSTOM: custom loss (deprecated; no implementation available)</li>
     * <li>COSINE_PROXIMITY: Cosine proximity</li>
     * <li>HINGE: Hinge loss</li>
     * <li>SQUARED_HINGE: Squared hinge loss</li>
     * <li>KL_DIVERGENCE: Kullback-Leibler divergence</li>
     * <li>MEAN_ABSOLUTE_ERROR: Mean absolute error</li>
     * <li>L2: L2 loss</li>
     * <li>MEAN_ABSOLUTE_PERCENTAGE_ERROR: Mean absolute percentage error</li>
     * <li>MEAN_SQUARED_LOGARITHMIC_ERROR: Mean squared logarithmic error</li>
     * <li>POISSON: Poisson loss</li>
     * </ul>
     */
    public enum LossFunction {
        MSE,
        L1,
        /**
         * @deprecated Use {@link #POISSON}, which maps to the same implementation in
         *             {@link #getILossFunction()}.
         */
        @Deprecated
        EXPLL,
        XENT,
        MCXENT,
        /**
         * @deprecated No implementation is available; {@link #getILossFunction()} throws
         *             {@link UnsupportedOperationException} for this constant.
         */
        @Deprecated
        RMSE_XENT,
        SQUARED_LOSS,
        RECONSTRUCTION_CROSSENTROPY,
        NEGATIVELOGLIKELIHOOD,
        /**
         * @deprecated Custom loss functions are not supported by {@link #getILossFunction()},
         *             which throws {@link UnsupportedOperationException} for this constant.
         */
        @Deprecated
        CUSTOM,
        COSINE_PROXIMITY,
        HINGE,
        SQUARED_HINGE,
        KL_DIVERGENCE,
        MEAN_ABSOLUTE_ERROR,
        L2,
        MEAN_ABSOLUTE_PERCENTAGE_ERROR,
        MEAN_SQUARED_LOGARITHMIC_ERROR,
        POISSON;

        /**
         * Creates a new {@link ILossFunction} instance implementing this loss function.
         * Several constants are aliases that map to the same implementation
         * (MSE/SQUARED_LOSS, KL_DIVERGENCE/RECONSTRUCTION_CROSSENTROPY, POISSON/EXPLL).
         *
         * @return a newly created loss function implementation for this constant
         * @throws UnsupportedOperationException for {@link #CUSTOM}, {@link #RMSE_XENT},
         *                                       or any constant without an implementation
         */
        public ILossFunction getILossFunction() {
            switch (this) {
                case MSE:
                case SQUARED_LOSS:
                    return new LossMSE();
                case L1:
                    return new LossL1();
                case XENT:
                    return new LossBinaryXENT();
                case MCXENT:
                    return new LossMCXENT();
                case KL_DIVERGENCE:
                case RECONSTRUCTION_CROSSENTROPY:
                    return new LossKLD();
                case NEGATIVELOGLIKELIHOOD:
                    return new LossNegativeLogLikelihood();
                case COSINE_PROXIMITY:
                    return new LossCosineProximity();
                case HINGE:
                    return new LossHinge();
                case SQUARED_HINGE:
                    return new LossSquaredHinge();
                case MEAN_ABSOLUTE_ERROR:
                    return new LossMAE();
                case L2:
                    return new LossL2();
                case MEAN_ABSOLUTE_PERCENTAGE_ERROR:
                    return new LossMAPE();
                case MEAN_SQUARED_LOGARITHMIC_ERROR:
                    return new LossMSLE();
                case POISSON:
                case EXPLL:
                    return new LossPoisson();
                case CUSTOM:
                case RMSE_XENT:
                default:
                    // Deprecated constants without an implementation, plus a safety net for
                    // any constant added to the enum without a corresponding case above.
                    throw new UnsupportedOperationException("Unknown or not supported loss function: " + this);
            }
        }
    }
}