All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.multiclass.hierarchical.HierarchicalClassification Maven / Gradle / Ivy

package com.expleague.ml.methods.multiclass.hierarchical;

import com.expleague.commons.math.Trans;
import com.expleague.commons.seq.IntSeq;
import com.expleague.commons.util.Pair;
import com.expleague.commons.util.tree.IntTree;
import com.expleague.commons.util.tree.IntTreeVisitor;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.loss.blockwise.BlockwiseMLLLogit;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.MultiClassModel;
import com.expleague.ml.models.multiclass.HierarchicalModel;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.list.linked.TIntLinkedList;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Stack;

/**
 * Hierarchical multiclass classification over a class tree.
 *
 * <p>For every node of {@link #tree} that has children, starting from {@code tree.ROOT}, a model
 * produced by {@link #weakMultiClass} is trained to choose among that node's direct children.
 * The models of children that themselves have children are attached to their parent's
 * {@link HierarchicalModel}, so {@link #fit} returns the root of a tree of models that mirrors
 * {@link #tree}.
 *
 * <p>User: qdeee
 * Date: 06.02.14
 */
public class HierarchicalClassification extends VecOptimization.Stub {
  /** Optimizer used to train the multiclass model attached to each internal node. */
  protected final VecOptimization weakMultiClass;
  /** Class hierarchy; its node ids are compared directly with the labels reported by the loss. */
  protected final IntTree tree;

  /**
   * @param weakMultiClass optimizer whose fitted result is cast to {@link MultiClassModel}
   * @param tree class hierarchy to train along
   */
  public HierarchicalClassification(final VecOptimization weakMultiClass, final IntTree tree) {
    this.weakMultiClass = weakMultiClass;
    this.tree = tree;
  }

  /**
   * Trains one model per internal node of {@link #tree}, recursing from {@code tree.ROOT}.
   *
   * <p>At a node, sample {@code i} takes part in training if its label equals one of the node's
   * children, or {@code tree.isDescendant(label, child)} holds for one of them. Its local target
   * is the position of the first such child. Samples that match no child are left out of that
   * node's training set.
   *
   * @param learn full training set
   * @param globalLoss loss whose {@code label(i)} supplies the class-tree label of sample {@code i}
   * @return the {@link HierarchicalModel} built for the root node
   */
  @Override
  public Trans fit(final VecDataSet learn, final BlockwiseMLLLogit globalLoss) {
    // Scratch buffer of per-sample local targets, shared by every node. Only the entries selected
    // for the node being visited are meaningful. The subset is extracted and the model fitted
    // before any recursive visit overwrites the buffer.
    final int[] localClasses = new int[learn.length()];
    // Hands each child's model back to its parent across the visitor recursion.
    final Deque<HierarchicalModel> modelsStack = new ArrayDeque<>();

    final IntTreeVisitor visitor = new IntTreeVisitor() {
      @Override
      public void visit(final int node) {
        // Enumerate the children once per node instead of once per sample.
        final TIntArrayList children = childrenOf(node);

        final TIntArrayList dsIdxs = new TIntArrayList();
        for (int i = 0; i < learn.length(); i++) {
          final int localClass = localClassOf(globalLoss.label(i), children);
          if (localClass >= 0) {
            dsIdxs.add(i);
            localClasses[i] = localClass;
          }
        }

        final Pair<VecDataSet, IntSeq> pair =
            DataTools.createSubset(learn, new IntSeq(localClasses), dsIdxs.toArray());
        // NOTE(review): the loss is owned by the full `learn` set although the model is fitted on
        // the subset `pair.first`; kept as originally written -- confirm this is intended.
        final MultiClassModel model =
            (MultiClassModel) weakMultiClass.fit(pair.first, new BlockwiseMLLLogit(pair.second, learn));

        // The model's output positions correspond to `children` in enumeration order.
        final HierarchicalModel hierarchicalModel = new HierarchicalModel(model, children);

        // This node's model stays on the stack for the caller to pop. tree.accept(this, child) is
        // expected to run visit(child), which leaves the child's model on top for us to pop.
        modelsStack.push(hierarchicalModel);
        for (int j = 0; j < children.size(); j++) {
          final int child = children.get(j);
          if (tree.hasChildren(child)) {
            tree.accept(this, child);
            hierarchicalModel.addChild(modelsStack.pop(), child);
          }
        }
      }
    };
    tree.accept(visitor, tree.ROOT);
    return modelsStack.pop();
  }

  /** Returns the direct children of {@code node}, in the order {@code tree.getChildren} yields them. */
  private TIntArrayList childrenOf(final int node) {
    final TIntArrayList children = new TIntArrayList();
    for (final TIntIterator it = tree.getChildren(node); it.hasNext(); ) {
      children.add(it.next());
    }
    return children;
  }

  /**
   * Returns the position in {@code children} of the first child that equals {@code label}, or
   * for which {@code tree.isDescendant(label, child)} holds. Returns {@code -1} if no child
   * matches.
   */
  private int localClassOf(final int label, final TIntArrayList children) {
    for (int j = 0; j < children.size(); j++) {
      final int child = children.get(j);
      if (label == child || tree.isDescendant(label, child)) {
        return j;
      }
    }
    return -1;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy