// com.expleague.ml.loss.blockwise.BlockwiseWeightedLoss Maven / Gradle / Ivy
package com.expleague.ml.loss.blockwise;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.ml.BlockwiseFuncC1;
import com.expleague.ml.TargetFunc;
import com.expleague.ml.data.set.DataSet;
/**
* User: qdeee
* Date: 26.11.13
* Time: 9:54
*/
public class BlockwiseWeightedLoss extends BlockwiseFuncC1.Stub implements TargetFunc{
private final BasedOn metric;
private final int[] weights;
public BlockwiseWeightedLoss(final BasedOn metric, final int[] weights) {
if (metric.dim() / metric.blockSize() != weights.length)
throw new IllegalArgumentException("weights.length must be equal to blocks count");
this.metric = metric;
this.weights = weights;
}
@Override
public int dim() {
return metric.xdim();
}
public double weight(final int index) {
return weights[index];
}
public BasedOn base() {
return metric;
}
@Override
public void gradient(final Vec pointBlock, final Vec result, final int index) {
if (weights[index] > 0) {
metric.gradient(pointBlock, result, index);
VecTools.scale(result, weights[index]);
}
}
@Override
public double value(final Vec pointBlock, final int index) {
return weights[index] > 0 ? weights[index] * metric.value(pointBlock, index) : 0;
}
@Override
public double transformResultValue(final double value) {
return metric.transformResultValue(value);
}
@Override
public int blockSize() {
return metric.blockSize();
}
@Override
public DataSet> owner() {
return metric.owner();
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy