// com.expleague.ml.loss.DSSumFuncComposite (Maven / Gradle / Ivy artifact source)
package com.expleague.ml.loss;
import com.expleague.commons.func.Computable;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.ml.BlockedTargetFunc;
import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.TransC1;
import com.expleague.ml.data.set.DataSet;
/**
* User: solar
* Date: 01.06.15
* Time: 13:23
*/
public class DSSumFuncComposite- extends DSSumFuncC1
- implements FuncC1 {
private final BlockedTargetFunc tgt;
private final Computable
- decisionFactory;
private final int dim;
public DSSumFuncComposite(DataSet
- ds, BlockedTargetFunc tgt, Computable
- decisionFactory) {
super(ds);
this.tgt = tgt;
this.decisionFactory = decisionFactory;
dim = decisionFactory.compute(ds.at(0)).xdim();
}
@Override
public CompositeFunc component(int index) {
return new CompositeFunc((FuncC1)tgt.block(index), decisionFactory.compute(ds.at(index)));
}
public Decision decision(final Vec x) {
return new Decision(x);
}
@Override
public Vec gradient(Vec x) {
final Vec result = new ArrayVec(dim());
final int length = length();
for (int i = 0; i < length; i++){
VecTools.append(result, component(i).gradient(x));
}
return result;
}
public int dim() {
return dim;
}
public class Decision implements Computable
- {
public final Vec x;
public Decision(Vec x) {
this.x = x;
}
@Override
public Vec compute(Item argument) {
final Trans compute = decisionFactory.compute(argument);
return compute.trans(x);
}
}
}