// org.deeplearning4j.nn.layers.training.CenterLossOutputLayer Maven / Gradle / Ivy
/*-
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.layers.training;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.params.CenterLossParamInitializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
/**
* Center loss is similar to triplet loss except that it enforces
* intraclass consistency and doesn't require feed forward of multiple
* examples. Center loss typically converges faster for training
* ImageNet-based convolutional networks.
*
* "If example x is in class Y, ensure that embedding(x) is close to
* average(embedding(y)) for all examples y in Y"
*
* @author Justin Long (@crockpotveggies)
*/
public class CenterLossOutputLayer extends BaseOutputLayer {
private double fullNetworkL1;
private double fullNetworkL2;
public CenterLossOutputLayer(NeuralNetConfiguration conf) {
super(conf);
}
public CenterLossOutputLayer(NeuralNetConfiguration conf, INDArray input) {
super(conf, input);
}
/** Compute score after labels and input have been set.
* @param fullNetworkL1 L1 regularization term for the entire network
* @param fullNetworkL2 L2 regularization term for the entire network
* @param training whether score should be calculated at train or test time (this affects things like application of
* dropout, etc)
* @return score (loss function)
*/
@Override
public double computeScore(double fullNetworkL1, double fullNetworkL2, boolean training) {
if (input == null || labels == null)
throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
this.fullNetworkL1 = fullNetworkL1;
this.fullNetworkL2 = fullNetworkL2;
INDArray preOut = preOutput2d(training);
// center loss has two components
// the first enforces inter-class dissimilarity, the second intra-class dissimilarity (squared l2 norm of differences)
ILossFunction interClassLoss = layerConf().getLossFn();
// calculate the intra-class score component
INDArray centers = params.get(CenterLossParamInitializer.CENTER_KEY);
INDArray centersForExamples = labels.mmul(centers);
// double intraClassScore = intraClassLoss.computeScore(centersForExamples, input, Activation.IDENTITY.getActivationFunction(), maskArray, false);
INDArray norm2DifferenceSquared = input.sub(centersForExamples).norm2(1);
norm2DifferenceSquared.muli(norm2DifferenceSquared);
double sum = norm2DifferenceSquared.sumNumber().doubleValue();
double lambda = layerConf().getLambda();
double intraClassScore = lambda / 2.0 * sum;
// intraClassScore = intraClassScore * layerConf().getLambda() / 2;
if (System.getenv("PRINT_CENTERLOSS") != null) {
System.out.println("Center loss is " + intraClassScore);
}
// now calculate the inter-class score component
double interClassScore = interClassLoss.computeScore(getLabels2d(), preOut, layerConf().getActivationFn(),
maskArray, false);
double score = interClassScore + intraClassScore;
score += fullNetworkL1 + fullNetworkL2;
score /= getInputMiniBatchSize();
this.score = score;
return score;
}
/**Compute the score for each example individually, after labels and input have been set.
*
* @param fullNetworkL1 L1 regularization term for the entire network (or, 0.0 to not include regularization)
* @param fullNetworkL2 L2 regularization term for the entire network (or, 0.0 to not include regularization)
* @return A column INDArray of shape [numExamples,1], where entry i is the score of the ith example
*/
@Override
public INDArray computeScoreForExamples(double fullNetworkL1, double fullNetworkL2) {
if (input == null || labels == null)
throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
INDArray preOut = preOutput2d(false);
// calculate the intra-class score component
INDArray centers = params.get(CenterLossParamInitializer.CENTER_KEY);
INDArray centersForExamples = labels.mmul(centers);
INDArray intraClassScoreArray = input.sub(centersForExamples);
// calculate the inter-class score component
ILossFunction interClassLoss = layerConf().getLossFn();
INDArray scoreArray = interClassLoss.computeScoreArray(getLabels2d(), preOut, layerConf().getActivationFn(),
maskArray);
scoreArray.addi(intraClassScoreArray.muli(layerConf().getLambda() / 2));
double l1l2 = fullNetworkL1 + fullNetworkL2;
if (l1l2 != 0.0) {
scoreArray.addi(l1l2);
}
return scoreArray;
}
@Override
public void computeGradientAndScore() {
if (input == null || labels == null)
return;
INDArray preOut = preOutput2d(true);
Pair pair = getGradientsAndDelta(preOut);
this.gradient = pair.getFirst();
score = computeScore(fullNetworkL1, fullNetworkL2, true);
}
@Override
protected void setScoreWithZ(INDArray z) {
throw new RuntimeException("Not supported " + layerId());
}
@Override
public Pair gradientAndScore() {
return new Pair<>(gradient(), score());
}
@Override
public Pair backpropGradient(INDArray epsilon) {
Pair pair = getGradientsAndDelta(preOutput2d(true)); //Returns Gradient and delta^(this), not Gradient and epsilon^(this-1)
INDArray delta = pair.getSecond();
// centers
INDArray centers = params.get(CenterLossParamInitializer.CENTER_KEY);
INDArray centersForExamples = labels.mmul(centers);
INDArray dLcdai = input.sub(centersForExamples);
INDArray epsilonNext = params.get(CenterLossParamInitializer.WEIGHT_KEY).mmul(delta.transpose()).transpose();
double lambda = layerConf().getLambda();
epsilonNext.addi(dLcdai.muli(lambda)); // add center loss here
return new Pair<>(pair.getFirst(), epsilonNext);
}
/**
* Gets the gradient from one training iteration
* @return the gradient (bias and weight matrix)
*/
@Override
public Gradient gradient() {
return gradient;
}
/** Returns tuple: {Gradient,Delta,Output} given preOut */
private Pair getGradientsAndDelta(INDArray preOut) {
ILossFunction lossFunction = layerConf().getLossFn();
INDArray labels2d = getLabels2d();
if (labels2d.size(1) != preOut.size(1)) {
throw new DL4JInvalidInputException(
"Labels array numColumns (size(1) = " + labels2d.size(1) + ") does not match output layer"
+ " number of outputs (nOut = " + preOut.size(1) + ") " + layerId());
}
INDArray delta = lossFunction.computeGradient(labels2d, preOut, layerConf().getActivationFn(), maskArray);
Gradient gradient = new DefaultGradient();
INDArray weightGradView = gradientViews.get(CenterLossParamInitializer.WEIGHT_KEY);
INDArray biasGradView = gradientViews.get(CenterLossParamInitializer.BIAS_KEY);
INDArray centersGradView = gradientViews.get(CenterLossParamInitializer.CENTER_KEY);
// centers delta
double alpha = layerConf().getAlpha();
INDArray centers = params.get(CenterLossParamInitializer.CENTER_KEY);
INDArray centersForExamples = labels.mmul(centers);
INDArray diff = centersForExamples.sub(input).muli(alpha);
INDArray numerator = labels.transpose().mmul(diff);
INDArray denominator = labels.sum(0).addi(1.0).transpose();
INDArray deltaC;
if (layerConf().getGradientCheck()) {
double lambda = layerConf().getLambda();
//For gradient checks: need to multiply dLc/dcj by lambda to get dL/dcj
deltaC = numerator.muli(lambda);
} else {
deltaC = numerator.diviColumnVector(denominator);
}
centersGradView.assign(deltaC);
// other standard calculations
Nd4j.gemm(input, delta, weightGradView, true, false, 1.0, 0.0); //Equivalent to: weightGradView.assign(input.transpose().mmul(delta));
delta.sum(biasGradView, 0); //biasGradView is initialized/zeroed first in sum op
gradient.gradientForVariable().put(CenterLossParamInitializer.WEIGHT_KEY, weightGradView);
gradient.gradientForVariable().put(CenterLossParamInitializer.BIAS_KEY, biasGradView);
gradient.gradientForVariable().put(CenterLossParamInitializer.CENTER_KEY, centersGradView);
return new Pair<>(gradient, delta);
}
}
// (web page footer, not part of the source file) © 2015 - 2024 Weber Informatics LLC | Privacy Policy