All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.data.tools.MCTools Maven / Gradle / Ivy

package com.expleague.ml.data.tools;

import com.expleague.commons.func.Computable;
import com.expleague.commons.math.Func;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.metrics.Metric;
import com.expleague.commons.math.metrics.impl.CosineDVectorMetric;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.seq.IntSeq;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.func.FuncEnsemble;
import com.expleague.ml.func.FuncJoin;
import com.expleague.ml.loss.blockwise.BlockwiseMLLLogit;
import com.expleague.ml.loss.multiclass.MCMicroF1Score;
import com.expleague.ml.meta.items.QURLItem;
import com.expleague.ml.models.multiclass.JoinedBinClassModel;
import com.expleague.ml.models.multiclass.MCModel;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.text.StringUtils;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.util.Pair;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.loss.L2;
import com.expleague.ml.loss.multiclass.util.ConfusionMatrix;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.list.linked.TIntLinkedList;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TIntObjectHashMap;

import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import static java.lang.Math.max;

/**
 * User: qdeee
 * Date: 04.06.14
 */
public class MCTools {
  public static int countClasses(final IntSeq target) {
    int classesCount = 0;
    for (int i = 0; i < target.length(); i++) {
      classesCount = max(target.at(i) + 1, classesCount);
    }
    return classesCount;
  }

  /**
   * calculate classes entries counts
   * @param target normalized(!) target with class labels from {0,...,K-1}
   * @return array with counts
   */
  public static int[] classEntriesCounts(final IntSeq target) {
    final int[] counts = new int[countClasses(target)];
    for (int i = 0; i < target.length(); i++) {
      counts[target.arr[i]]++;
    }
    return counts;
  }

  public static int classEntriesCount(final IntSeq target, final int classNo) {
    int result = 0;
    for (int i = 0; i < target.length(); i++) {
      if (target.at(i) == classNo)
        result++;
    }
    return result;
  }

  public static Vec extractClassForBinary(final IntSeq target, final int classNo) {
    final Vec result = new ArrayVec(target.length());
    for (int i = 0; i < target.length(); i++)
      result.set(i, (target.at(i) == classNo) ? 1. : -1.);
    return result;
  }

  /**
   *
   * @param target MC target with any classes labels
   * @return classes labels corresponding their order (uniq)
   */
  public static int[] getClassesLabels(final IntSeq target) {
    final TIntList labels = new TIntArrayList();
    for (int i = 0; i < target.length(); i++) {
      final int label = target.at(i);
      if (!labels.contains(label)) {
        labels.add(label);
      }
    }
    return labels.toArray();
  }

  public static int[] getClassLabels(final Vec target) {
    final TIntList labels = new TIntArrayList();
    for (int i = 0; i < target.length(); i++) {
      final int label = target.at(i).intValue();
      if (!labels.contains(label)) {
        labels.add(label);
      }
    }
    return labels.toArray();
  }

  /**
   * Normalization of multiclass target. Target may contain any labels. Notice that error class (-1) will be mapped to the class K.
   * Example: if target contains {10, 10, 6, 4, -1, -1} then result is {2, 2, 1, 0, 3, 3} and map will be filled {(4->0), (6->1), (10->2), (-1->3)}
   * @param target Target vec with any class labels.
   * @param labelsMap Empty map which will be filled here by pairs (label, normalizedLabel).
   * @return new target with classes labels from {0..K}.
   */
  public static IntSeq normalizeTarget(final IntSeq target, final TIntIntMap labelsMap) {
    for (int i = 0; i < target.length(); i++) {
      labelsMap.putIfAbsent(target.arr[i], 0);
    }
    labelsMap.remove(-1);

    final int[] labels = labelsMap.keys();
    Arrays.sort(labels);
    for (int i = 0; i < labels.length; i++) {
      labelsMap.put(labels[i], i);
    }
    labelsMap.put(-1, labels.length);

    final int[] newTarget = new int[target.length()];
    for (int i = 0; i < target.length(); i++) {
      newTarget[i] = labelsMap.get(target.arr[i]);
    }
    return new IntSeq(newTarget);
  }

  public static TIntObjectMap splitClassesIdxs(final IntSeq target) {
    final TIntObjectMap indexes = new TIntObjectHashMap();
    for (int i = 0; i < target.length(); i++) {
      final int label = target.at(i);
      if (indexes.containsKey(label)) {
        indexes.get(label).add(i);
      }
      else {
        final TIntList newClassIdxs = new TIntLinkedList();
        newClassIdxs.add(i);
        indexes.put(label, newClassIdxs);
      }
    }
    return indexes;
  }

  private static double normalizeRelevance(final double y) {
    if (y <= 0.0)
      return 0.;
//    else if (y < 0.14)
//      return 1.;
//    else if (y < 0.41)
//      return 2.;
//    else if (y < 0.61)
//      return 3.;
//    else
//      return 4.;
    return 1.;
  }


  public static IntSeq transformRegressionToMC(final Vec regressionTarget, final int classCount, final TDoubleList borders) throws IOException {
    final double[] target = regressionTarget.toArray();
    final int[] idxs = ArrayTools.sequence(0, target.length);
    ArrayTools.parallelSort(target, idxs);

    if (borders.size() == 0) {
      final double min = target[0];
      final double max = target[target.length - 1];
      final double delta = (max - min) / classCount;
      for (int i = 0; i < classCount; i++) {
        borders.add(delta * (i + 1));
      }
    }

    final int[] resultTarget = new int[target.length];
    int targetCursor = 0;
    for (int borderCursor = 0; borderCursor < borders.size(); borderCursor++){
      while (targetCursor < target.length && target[targetCursor] <= borders.get(borderCursor)) {
        resultTarget[idxs[targetCursor]] = borderCursor;
        targetCursor++;
      }
    }
    return new IntSeq(resultTarget);
  }

  public static Pair loadRegressionAsMC(final String file, final int classCount, final TDoubleList borders)  throws IOException{
    final Pool pool = DataTools.loadFromFeaturesTxt(file);
    return Pair.create(pool.vecData(), transformRegressionToMC(pool.target(L2.class).target, classCount, borders));
  }

  public static Mx createSimilarityMatrixParallels(final VecDataSet learn, final IntSeq target, final Metric metric) {
    final ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
    final TIntObjectMap indexes = splitClassesIdxs(target);
    final int k = indexes.keys().length;
    final Mx S = new VecBasedMx(k, k);
    for (int i = 0; i < k; i++) {
      final TIntList classIdxsI = indexes.get(i);

      for (int j = i; j < k; j++) {
        final TIntList classIdxsJ = indexes.get(j);
        final int iCopy = i;
        final int jCopy = j;

        executor.submit(new Runnable() {
          @Override
          public void run() {
            double value = 0.;

            for (final TIntIterator iterI = classIdxsI.iterator(); iterI.hasNext(); ) {
              final int i1 = iterI.next();
              for (final TIntIterator iterJ = classIdxsJ.iterator(); iterJ.hasNext(); ) {
                final int i2 = iterJ.next();
                value += 1 - metric.distance(learn.data().row(i1), learn.data().row(i2));
              }
            }
            value /= classIdxsI.size() * classIdxsJ.size();
            S.set(iCopy, jCopy, value);
            S.set(jCopy, iCopy, value);
          }
        });
      }
    }

    executor.shutdown();
    try {
      executor.awaitTermination(1000, TimeUnit.HOURS);
    } catch (InterruptedException e) {
      throw new RuntimeException(e);
    }
    return S;
  }

  public static Mx createSimilarityMatrix(final VecDataSet learn, final IntSeq target) {
    final TIntObjectMap indexes = splitClassesIdxs(target);
    final Metric metric = new CosineDVectorMetric();
    final int k = indexes.keys().length;
    final Mx S = new VecBasedMx(k, k);
    for (int i = 0; i < k; i++) {
      final TIntList classIdxsI = indexes.get(i);
      for (int j = i; j < k; j++) {
        final TIntList classIdxsJ = indexes.get(j);
        double value = 0.;
        for (final TIntIterator iterI = classIdxsI.iterator(); iterI.hasNext(); ) {
          final int i1 = iterI.next();
          for (final TIntIterator iterJ = classIdxsJ.iterator(); iterJ.hasNext(); ) {
            final int i2 = iterJ.next();
            value += 1 - metric.distance(learn.data().row(i1), learn.data().row(i2));
          }
        }
        value /= classIdxsI.size() * classIdxsJ.size();
        S.set(i, j, value);
        S.set(j, i, value);
      }
      System.out.println("class " + i + " is finished!");
    }
    return S;
  }

  public static String evalModel(final MCModel model, final Pool ds, final String prefixComment, final boolean oneLine) {
    final Vec predict = model.bestClassAll(ds.vecData().data());
    final TIntIntMap labelsMap = new TIntIntHashMap();
    final ConfusionMatrix confusionMatrix = new ConfusionMatrix(
        normalizeTarget(ds.target(BlockwiseMLLLogit.class).labels(), labelsMap),
        mapTarget(VecTools.toIntSeq(predict), labelsMap)
    );
    if (oneLine) {
      return prefixComment + confusionMatrix.debug();
    } else {
      return "\n==========" + prefixComment +
          StringUtils.repeatWithDelimeter("", "=", 100 - 10 - prefixComment.length()) + "\n" +
          confusionMatrix.toSummaryString() + "\n" +
          confusionMatrix.toClassDetailsString() +
          StringUtils.repeatWithDelimeter("", "=", 100);
    }
  }

  //only for FuncJoin models
  public static FuncJoin joinBoostingResult(final Ensemble ensemble) {
    if (ensemble.last() instanceof FuncJoin) {
      final int modelsCount = ensemble.ydim();
      final Func[] joinedModels = new Func[modelsCount];
      final Func[][] transpose = new Func[modelsCount][ensemble.size()];
      for (int iter = 0; iter < ensemble.size(); iter++) {
        final FuncJoin model = (FuncJoin) ensemble.models[iter];
        final Func[] sourceFunctions = model.dirs();
        for (int c = 0; c < modelsCount; c++) {
          transpose[c][iter] = sourceFunctions[c];
        }
      }
      for (int i = 0; i < joinedModels.length; i++) {
        joinedModels[i] = new FuncEnsemble(transpose[i], ensemble.weights);
      }
      return new FuncJoin(joinedModels);
    }
    else
      throw new ClassCastException("Ensemble object does not contain FuncJoin objects");
  }

  public static IntSeq mapTarget(final IntSeq intTarget, final TIntIntMap mapping) {
    final int[] mapped = new int[intTarget.length()];
    for (int i = 0; i < intTarget.length(); i++) {
      mapped[i] = mapping.get(intTarget.at(i));
    }
    return new IntSeq(mapped);
  }

  public static void makeOneVsRestReport(final Pool learn, final Pool test, final JoinedBinClassModel joinedBinClassModel, final int period) {
    if (!(joinedBinClassModel.getInternModel().dirs[0] instanceof FuncEnsemble)) {
      throw new IllegalArgumentException("Provided model must contain array of FuncEnsemble objects");
    }

    final IntSeq learnLabels = learn.target(MCMicroF1Score.class).labels();
    final IntSeq testLabels = test.target(MCMicroF1Score.class).labels();

    final FuncEnsemble[] perClassModels = ArrayTools.map(joinedBinClassModel.getInternModel().dirs, FuncEnsemble.class, new Computable() {
      @Override
      public FuncEnsemble compute(final Trans argument) {
        return (FuncEnsemble) argument;
      }
    });
    final int ensembleSize = perClassModels[0].size();
    final int classesCount = perClassModels.length;

    final Mx learnCache = new VecBasedMx(learn.size(), classesCount);
    final Mx testCache = new VecBasedMx(test.size(), classesCount);

    for (int t = 0; t < ensembleSize; t += period) {
      final Func[] functions = new Func[classesCount];
      for (int c = 0; c < classesCount; c++) {
        functions[c] = new FuncEnsemble<>(
            Arrays.copyOfRange(perClassModels[c].models, t, Math.min(t + period, ensembleSize), Func[].class),
            perClassModels[c].weights.sub(0, t)
        );
      }
      final FuncJoin deltaFuncJoin = new FuncJoin(functions);

      VecTools.append(learnCache, deltaFuncJoin.transAll(learn.vecData().data()));
      VecTools.append(testCache, deltaFuncJoin.transAll(test.vecData().data()));

      final IntSeq learnPredict = convertTransOneVsRestResults(learnCache);
      final IntSeq testPredict = convertTransOneVsRestResults(testCache);

      final ConfusionMatrix learnConfusionMatrix = new ConfusionMatrix(learnLabels, learnPredict);
      final ConfusionMatrix testConfusionMatrix = new ConfusionMatrix(testLabels, testPredict);
      System.out.print("\t" + learnConfusionMatrix.oneLineReport());
      System.out.print("\t" + testConfusionMatrix.oneLineReport());
      System.out.println();
    }
  }

  private static IntSeq convertTransOneVsRestResults(final Mx trans) {
    final int[] result = new int[trans.rows()];
    for (int i = 0; i < trans.rows(); i++) {
      final Vec row = trans.row(i);
      result[i] = VecTools.argmax(row);
    }
    return new IntSeq(result);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy