com.expleague.ml.methods.multiclass.hierarchical.HierarchicalRefinedClassification Maven / Gradle / Ivy
package com.expleague.ml.methods.multiclass.hierarchical;
import com.expleague.commons.math.Func;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.seq.IntSeq;
import com.expleague.commons.util.tree.IntTreeVisitor;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.loss.blockwise.BlockwiseMLLLogit;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.multiclass.HierarchicalModel;
import com.expleague.ml.models.multiclass.JoinedBinClassModel;
import com.expleague.ml.models.multiclass.MCModel;
import com.expleague.commons.util.Pair;
import com.expleague.commons.util.tree.IntTree;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.data.tools.MCTools;
import com.expleague.ml.loss.LLLogit;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.list.linked.TIntLinkedList;
import java.util.Arrays;
import java.util.Stack;
/**
* User: qdeee
* Date: 10.04.14
*/
public class HierarchicalRefinedClassification extends VecOptimization.Stub {
protected final VecOptimization weakBinClass;
protected final VecOptimization weakMultiClass;
protected final IntTree tree;
public HierarchicalRefinedClassification(final VecOptimization weakBinClass,
final VecOptimization weakMultiClass,
final IntTree tree) {
this.weakBinClass = weakBinClass;
this.weakMultiClass = weakMultiClass;
this.tree = tree;
}
@Override
public Trans fit(final VecDataSet learn, final BlockwiseMLLLogit globalLoss) {
final SpecialHierModel hierJoinedBinClassModel = firstTraverse(learn, globalLoss);
final HierarchicalModel hierarchicalModel = secondTraverse(learn, globalLoss, hierJoinedBinClassModel);
return hierarchicalModel;
}
private SpecialHierModel firstTraverse(final VecDataSet learn, final BlockwiseMLLLogit globalLoss) {
final int[] localClasses = new int[learn.length()];
final Stack modelsStack = new Stack<>();
final IntTreeVisitor visitor = new IntTreeVisitor() {
@Override
public void visit(final int node) {
final TIntList uniqClasses = new TIntLinkedList();
{
final TIntIterator children = tree.getChildren(node);
while (children.hasNext()) {
uniqClasses.add(children.next());
}
}
for (int i = 0; i < learn.length(); i++) {
final int dsClassLabel = globalLoss.label(i);
localClasses[i] = -1;
final TIntIterator children = tree.getChildren(node);
for (int j = 0; children.hasNext(); j++) {
final int child = children.next();
if (dsClassLabel == child || tree.isDescendant(dsClassLabel, child)) {
localClasses[i] = j;
break;
}
}
}
final Func[] models = new Func[uniqClasses.size()];
for (int j = 0; j < uniqClasses.size(); j++) {
final VecDataSet dsForLearn;
final Vec targetForLearn;
final Vec oneVsRestTarget = MCTools.extractClassForBinary(new IntSeq(localClasses), j);
if (node != tree.ROOT) {
final TIntList dsIdxs = new TIntLinkedList();
for (int i = 0; i < learn.length(); i++) {
if (localClasses[i] == -1 || localClasses[i] == j)
dsIdxs.add(i); //everyone exclude siblings
}
final Pair pair = DataTools.createSubset(learn, oneVsRestTarget, dsIdxs.toArray());
dsForLearn = pair.first;
targetForLearn = pair.second;
}
else {
dsForLearn = learn;
targetForLearn = oneVsRestTarget;
}
models[j] = (Func) weakBinClass.fit(dsForLearn, new LLLogit(targetForLearn, learn));
}
final SpecialHierModel nodeModel = new SpecialHierModel(new JoinedBinClassModel(models), uniqClasses);
modelsStack.push(nodeModel);
final TIntIterator children = tree.getChildren(node);
while (children.hasNext()) {
final int child = children.next();
if (tree.hasChildren(child)) {
tree.accept(this, child);
nodeModel.addChild(modelsStack.pop(), child);
}
}
}
};
tree.accept(visitor, tree.ROOT);
return modelsStack.pop();
}
private HierarchicalModel secondTraverse(final VecDataSet learn, final BlockwiseMLLLogit globalLoss, final SpecialHierModel firstModel) {
final int[] localClasses = new int[learn.length()];
final Stack errorsStack = new Stack<>();
final Stack cleanModelsStack = new Stack<>();
final Stack refinedModelsStack = new Stack<>();
cleanModelsStack.push(firstModel);
final IntTreeVisitor visitor = new IntTreeVisitor() {
@Override
public void visit(final int node) {
final int[] errors = errorsStack.size() > 0 ? Arrays.copyOf(errorsStack.peek(), learn.length()) : new int[learn.length()];
final TIntList uniqClasses = new TIntLinkedList();
{
final TIntIterator children = tree.getChildren(node);
while (children.hasNext()) {
uniqClasses.add(children.next());
}
}
final SpecialHierModel cleanModel = cleanModelsStack.pop();
final TIntList dsIdxs = new TIntLinkedList();
for (int i = 0; i < learn.length(); i++) {
localClasses[i] = -1;
if (errors[i] == 1) {
continue; //skip errors from top
}
final int dsClassLabel = globalLoss.label(i);
final TIntIterator children = tree.getChildren(node);
for (int j = 0; children.hasNext(); j++) {
final int child = children.next();
if (dsClassLabel == child || tree.isDescendant(dsClassLabel, child)) {
dsIdxs.add(i);
localClasses[i] = j;
break;
}
}
if (dsIdxs.get(dsIdxs.size() - 1) == i) {
final int predictedLocalClass = cleanModel.bestClass(learn.at(i));
if (localClasses[i] != predictedLocalClass) {
localClasses[i] = uniqClasses.size();
errors[i] = 1; //for next levels
}
}
}
if (MCTools.classEntriesCount(new IntSeq(localClasses), uniqClasses.size()) > 0) {
uniqClasses.add(-1);
}
final Pair pair = DataTools.createSubset(learn, new IntSeq(localClasses), dsIdxs.toArray());
final MCModel model = (MCModel) weakMultiClass.fit(pair.first, new BlockwiseMLLLogit(pair.second, learn));
final HierarchicalModel refinedModel = new HierarchicalModel(model, new TIntArrayList(uniqClasses));
refinedModelsStack.push(refinedModel); //for top levels
errorsStack.push(errors); //for bottom levels
final TIntIterator children = tree.getChildren(node);
while (children.hasNext()) {
final int child = children.next();
if (tree.hasChildren(child)) {
cleanModelsStack.push((SpecialHierModel) cleanModel.getChild(child));
tree.accept(this, child);
refinedModel.addChild(refinedModelsStack.pop(), child);
}
}
errorsStack.pop();
}
};
tree.accept(visitor, tree.ROOT);
return refinedModelsStack.pop();
}
private static class SpecialHierModel extends HierarchicalModel {
public SpecialHierModel(final JoinedBinClassModel basedOn, final TIntList classLabels) {
super(basedOn, classLabels);
}
//we need to accumulate signals from bottom levels
private Vec deepTrans(final Vec x) {
final Vec trans = ((JoinedBinClassModel) basedOn).getInternModel().trans(x);
for (int i = 0; i < classLabels.size(); i++) {
final int label = classLabels.get(i);
final SpecialHierModel model = (SpecialHierModel) label2childModel.get(label);
if (model != null) {
final double val = VecTools.sum(model.deepTrans(x));
trans.adjust(i, val);
}
}
return trans;
}
@Override
public int bestClass(final Vec x) {
final Vec deepTrans = deepTrans(x);
return VecTools.argmax(deepTrans); //return local best class index (not label)
}
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy