
org.nd4j.linalg.lossfunctions.impl.LossMultiLabel

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.lossfunctions.impl;

import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonInclude;

/**
 * Multi-label loss function, perhaps more commonly known as BPMLL.
 * <p>
 * This loss function requires the labels to be given as a multi-hot encoded vector. It doesn't require any special
 * activation function, i.e. the network output doesn't have to be in any specific range.
 * <p>
 * The loss is calculated from the classification difference between the labels that an example has and those that it
 * doesn't have. Assume that each example has a set of labels: these labels form the positive set, while the labels
 * that do not belong to the example form the negative set. This loss function trains the network to produce higher
 * values for labels in the positive set than for those in the negative set.
 * <p>
 * Note that in order to learn anything at all, this loss function requires that the labels of an example are not all
 * 0 or all 1; in those cases the loss gradient will be 0. If you have to work with such examples, consider using a
 * ComputationGraph with two LossLayers, one using LossMultiLabel and the other using LossBinaryXENT.
 * <p>
 * For a more detailed explanation and the actual formulas, see the original paper by Zhang and Zhou. The
 * implementation of scoreArray is based on equation 3, while computeGradient is based on equation 11. The main
 * difference is that -(c_k - c_l) = (c_l - c_k) was used to simplify the calculations.
 * <p>
 * Min-Ling Zhang and Zhi-Hua Zhou, "Multilabel Neural Networks with Applications to Functional Genomics and Text
 * Categorization," in IEEE Transactions on Knowledge and Data Engineering, vol. 18, no. 10, pp. 1338-1351, Oct. 2006.
 * doi: 10.1109/TKDE.2006.162
 *
 * @author Paul Dubs
 */
@EqualsAndHashCode
@JsonInclude(JsonInclude.Include.NON_NULL)
@Getter
public class LossMultiLabel implements ILossFunction {

    public LossMultiLabel() {
    }

    private void calculate(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
                    INDArray scoreOutput, INDArray gradientOutput) {
        if (scoreOutput == null && gradientOutput == null) {
            throw new IllegalArgumentException("You have to provide at least one of scoreOutput or gradientOutput!");
        }
        if (labels.size(1) != preOutput.size(1)) {
            throw new IllegalArgumentException(
                            "Labels array numColumns (size(1) = " + labels.size(1) + ") does not match output layer"
                                            + " number of outputs (nOut = " + preOutput.size(1) + ") ");
        }
        labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype

        final INDArray postOutput = activationFn.getActivation(preOutput.dup(), true);

        final INDArray positive = labels;
        final INDArray negative = labels.eq(0.0).castTo(Nd4j.defaultFloatingPointType());
        final INDArray normFactor = negative.sum(true, 1).castTo(Nd4j.defaultFloatingPointType())
                        .muli(positive.sum(true, 1));

        long examples = positive.size(0);
        for (int i = 0; i < examples; i++) {
            final INDArray locCfn = postOutput.getRow(i, true);
            final long[] shape = locCfn.shape();

            final INDArray locPositive = positive.getRow(i, true);
            final INDArray locNegative = negative.getRow(i, true);
            final Double locNormFactor = normFactor.getDouble(i);

            final int outSetSize = locNegative.sumNumber().intValue();
            if (outSetSize == 0 || outSetSize == locNegative.columns()) {
                if (scoreOutput != null) {
                    scoreOutput.getRow(i, true).assign(0);
                }

                if (gradientOutput != null) {
                    gradientOutput.getRow(i, true).assign(0);
                }
            } else {
                final INDArray operandA = Nd4j.ones(shape[1], shape[0]).mmul(locCfn);
                final INDArray operandB = operandA.transpose();

                final INDArray pairwiseSub = Transforms.exp(operandA.sub(operandB));

                final INDArray selection = locPositive.transpose().mmul(locNegative);

                final INDArray classificationDifferences = pairwiseSub.muli(selection).divi(locNormFactor);

                if (scoreOutput != null) {
                    if (mask != null) {
                        final INDArray perLabel = classificationDifferences.sum(0);
                        LossUtil.applyMask(perLabel, mask.getRow(i, true));
                        perLabel.sum(scoreOutput.getRow(i, true), 0);
                    } else {
                        classificationDifferences.sum(scoreOutput.getRow(i, true), 0, 1);
                    }
                }

                if (gradientOutput != null) {
                    gradientOutput.getRow(i, true).assign(classificationDifferences.sum(true, 0).addi(
                                    classificationDifferences.sum(true, 1).transposei().negi()));
                }
            }
        }

        if (gradientOutput != null) {
            gradientOutput.assign(activationFn.backprop(preOutput.dup(), gradientOutput).getFirst());
            //multiply with masks, always
            if (mask != null) {
                LossUtil.applyMask(gradientOutput, mask);
            }
        }
    }

    public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        final INDArray scoreArr = Nd4j.create(labels.size(0), 1);
        calculate(labels, preOutput, activationFn, mask, scoreArr, null);
        return scoreArr;
    }

    @Override
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
                    boolean average) {
        INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask);

        double score = scoreArr.sumNumber().doubleValue();

        if (average)
            score /= scoreArr.size(0);

        return score;
    }

    @Override
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask);
        return scoreArr.sum(true, 1);
    }

    @Override
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype
        if (labels.size(1) != preOutput.size(1)) {
            throw new IllegalArgumentException(
                            "Labels array numColumns (size(1) = " + labels.size(1) + ") does not match output layer"
                                            + " number of outputs (nOut = " + preOutput.size(1) + ") ");
        }
        final INDArray grad = Nd4j.ones(labels.shape());
        calculate(labels, preOutput, activationFn, mask, null, grad);
        return grad;
    }

    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput,
                    IActivation activationFn, INDArray mask, boolean average) {
        final INDArray scoreArr = Nd4j.create(labels.size(0), 1);
        final INDArray grad = Nd4j.ones(labels.shape());

        calculate(labels, preOutput, activationFn, mask, scoreArr, grad);

        double score = scoreArr.sumNumber().doubleValue();

        if (average)
            score /= scoreArr.size(0);

        return new Pair<>(score, grad);
    }

    @Override
    public String name() {
        return toString();
    }

    @Override
    public String toString() {
        return "LossMultiLabel";
    }
}
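
Below is a minimal usage sketch, not part of the original source, showing how the loss might be evaluated directly against raw network output. The toy 2x4 label and pre-output arrays, the identity activation, and the null mask are illustrative assumptions. As the javadoc notes, each example's score is built from exp-weighted pairwise differences between its positive and negative labels (equation 3 of the Zhang and Zhou paper), normalised by the product of the two set sizes.

// Illustrative sketch only; array values, the identity activation, and the
// absence of a mask are assumptions, not values from the original file.
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossMultiLabel;

public class LossMultiLabelExample {
    public static void main(String[] args) {
        // Multi-hot labels: each row should contain at least one 0 and one 1,
        // otherwise that example's loss and gradient are all zeros (see javadoc).
        INDArray labels = Nd4j.create(new double[][] {
                        {1, 0, 1, 0},
                        {0, 1, 0, 0}});
        // Raw pre-activation network output; no specific range is required.
        INDArray preOutput = Nd4j.create(new double[][] {
                        {0.9, -0.3, 0.4, -0.8},
                        {-0.2, 1.1, 0.3, -0.5}});

        LossMultiLabel loss = new LossMultiLabel();
        ActivationIdentity identity = new ActivationIdentity();

        // Average score over the two examples (mask = null).
        double avgScore = loss.computeScore(labels, preOutput, identity, null, true);
        // Gradient with respect to preOutput, same shape as the labels array.
        INDArray gradient = loss.computeGradient(labels, preOutput, identity, null);

        System.out.println("average loss: " + avgScore);
        System.out.println("gradient shape: " + java.util.Arrays.toString(gradient.shape()));
    }
}

A well-trained network drives the score toward zero by ranking every positive label above every negative one; rows that are all 0 or all 1 contribute nothing, which is why the javadoc suggests pairing this loss with LossBinaryXENT in a ComputationGraph when such rows occur.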




