All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.greedyRegion.RegionForest Maven / Gradle / Ivy

package com.expleague.ml.methods.greedyRegion;

import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.util.ThreadTools;
import com.expleague.ml.BFGrid;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.methods.MTA;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.Region;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;


/**
 * Bagging-style ensemble of {@link GreedyTDRegion} weak learners.
 *
 * <p>Each call to {@link #fit} trains {@code weakCount} weak learners concurrently on this
 * instance's thread pool. Every learner receives its own sample drawn by
 * {@code DataTools.bootstrap(globalLoss, rnd)}. The resulting models are combined into an
 * {@link Ensemble} in which each model has weight {@code 1.0 / weakCount}.
 *
 * <p>The pool is created once in the constructor and is never shut down by this class.
 */
public class RegionForest extends VecOptimization.Stub {
  /**
   * How the value of each region is obtained.
   *
   * <ul>
   *   <li>{@code Naive}: every weak learner is fit independently with {@code GreedyTDRegion.fit}.
   *   <li>{@code Stein}, {@code MTAConst}, {@code MTAMinMax}: regions are located first, then their
   *       values are estimated jointly with {@link MTA} ({@code stein()}, {@code mtaConst()} and
   *       {@code mtaMiniMax()} respectively).
   * </ul>
   */
  public enum MeanMethod {
    MTAConst, Stein, Naive, MTAMinMax
  }

  /** Randomness for the bootstrap samples; shared by all concurrently running tasks. */
  protected final FastRandom rnd;
  /** Weak learner invoked (concurrently) for every ensemble member. */
  private final GreedyTDRegion weak;
  /** Number of weak learners per ensemble; validated to be positive. */
  private final int weakCount;
  /** Executor running the weak learners; created with {@code weakCount} as its size argument. */
  private final ThreadPoolExecutor pool;
  /** Strategy used by {@link #fit} to obtain region values. */
  private final MeanMethod meanMethod;


  /**
   * Creates a forest using {@link MeanMethod#Naive}, {@code alpha = 0.02}, {@code beta = 0.05} and
   * {@code maxFailed = 1}.
   */
  public RegionForest(final BFGrid grid, final FastRandom rnd, final int weakCount) {
    this(grid, rnd, weakCount, MeanMethod.Naive);
  }

  /** Creates a forest using {@link MeanMethod#Naive} and {@code maxFailed = 1}. */
  public RegionForest(final BFGrid grid, final FastRandom rnd, final int weakCount, final double alpha, final double beta) {
    this(grid, rnd, weakCount, MeanMethod.Naive, alpha, beta, 1);
  }

  /** Creates a forest using {@code alpha = 0.02}, {@code beta = 0.05} and {@code maxFailed = 1}. */
  public RegionForest(final BFGrid grid, final FastRandom rnd, final int weakCount, final MeanMethod meanMethod) {
    this(grid, rnd, weakCount, meanMethod, 0.02, 0.05, 1);
  }

  /**
   * Creates a forest.
   *
   * @param grid grid passed to the {@link GreedyTDRegion} weak learner
   * @param rnd randomness used to draw every bootstrap sample
   * @param weakCount number of weak learners per ensemble; must be positive
   * @param meanMethod how region values are obtained, see {@link MeanMethod}
   * @param alpha forwarded to the {@link GreedyTDRegion} constructor
   * @param beta forwarded to the {@link GreedyTDRegion} constructor
   * @param maxFailed forwarded to the {@link GreedyTDRegion} constructor
   * @throws IllegalArgumentException if {@code weakCount} is not positive
   */
  public RegionForest(final BFGrid grid, final FastRandom rnd, final int weakCount, final MeanMethod meanMethod, final double alpha, final double beta, final int maxFailed) {
    if (weakCount <= 0) {
      // An empty ensemble would be weighted by 1.0 / 0 (= Infinity) and could never be evaluated.
      throw new IllegalArgumentException("weakCount must be positive, got " + weakCount);
    }
    this.weak = new GreedyTDRegion<>(grid, alpha, beta, maxFailed);
    this.weakCount = weakCount;
    this.rnd = rnd;
    pool = ThreadTools.createBGExecutor("RF pool", weakCount);
    this.meanMethod = meanMethod;
  }

  /** Creates a forest using {@link MeanMethod#Naive}. */
  public RegionForest(final BFGrid grid, final FastRandom rnd, final int weakCount, final double alpha, final double beta, final int maxFailed) {
    this(grid, rnd, weakCount, MeanMethod.Naive, alpha, beta, maxFailed);
  }


  /**
   * Trains the ensemble. Delegates to {@link #fitNaive} for {@link MeanMethod#Naive} and to
   * {@link #fitMTA} for every other mean method.
   */
  @Override
  public Trans fit(final VecDataSet learn, final Loss globalLoss) {
    switch (meanMethod) {
      case Naive:
        return fitNaive(learn, globalLoss);
      default:
        return fitMTA(learn, globalLoss);
    }
  }

  /**
   * Finds {@code weakCount} regions in parallel, each with {@code weak.findRegion} on its own
   * bootstrap of {@code globalLoss}. The {@code inside} values of every region are passed to
   * {@link MTA}, the region values are estimated according to {@code meanMethod}, and one
   * {@link Region} is built per region.
   *
   * @return an {@link Ensemble} of the regions, each weighted by {@code 1.0 / weakCount}
   * @throws IllegalStateException if the calling thread is interrupted while waiting (the interrupt
   *     flag is restored) or a weak learner fails with a checked exception
   */
  public Trans fitMTA(final VecDataSet learn, final Loss globalLoss) {
    final List<GreedyTDRegion.RegionStats> regions = runWeakLearners(new Callable<GreedyTDRegion.RegionStats>() {
      @Override
      public GreedyTDRegion.RegionStats call() {
        return weak.findRegion(learn, DataTools.bootstrap(globalLoss, rnd));
      }
    });

    final double[][] samples = new double[weakCount][];
    for (int i = 0; i < weakCount; ++i) {
      samples[i] = regions.get(i).inside.toArray();
    }
    final double[] means = estimateMeans(new MTA(samples));

    final Trans[] weakModels = new Trans[weakCount];
    for (int i = 0; i < weakCount; ++i) {
      final GreedyTDRegion.RegionStats stats = regions.get(i);
      weakModels[i] = new Region(stats.conditions, stats.mask, means[i], 0, 0, 0, stats.maxFailed);
    }
    return uniformEnsemble(weakModels);
  }

  /**
   * Fits {@code weakCount} weak models in parallel, each with {@code weak.fit} on its own bootstrap
   * of {@code globalLoss}.
   *
   * @return an {@link Ensemble} of the models, each weighted by {@code 1.0 / weakCount}
   * @throws IllegalStateException if the calling thread is interrupted while waiting (the interrupt
   *     flag is restored) or a weak learner fails with a checked exception
   */
  public Trans fitNaive(final VecDataSet learn, final Loss globalLoss) {
    final List<Trans> models = runWeakLearners(new Callable<Trans>() {
      @Override
      public Trans call() {
        return weak.fit(learn, DataTools.bootstrap(globalLoss, rnd));
      }
    });
    return uniformEnsemble(models.toArray(new Trans[weakCount]));
  }

  /**
   * Computes one value per region from {@code mta} according to {@link #meanMethod}. Any method
   * other than Stein, MTAConst or MTAMinMax falls back to {@code mta.classic()}.
   */
  private double[] estimateMeans(final MTA mta) {
    switch (meanMethod) {
      case Stein:
        return mta.stein();
      case MTAConst:
        return mta.mtaConst();
      case MTAMinMax:
        return mta.mtaMiniMax();
      default:
        return mta.classic();
    }
  }

  /** Wraps {@code weakModels} into an {@link Ensemble} with equal weights {@code 1.0 / weakCount}. */
  private Trans uniformEnsemble(final Trans[] weakModels) {
    return new Ensemble(weakModels, VecTools.fill(new ArrayVec(weakModels.length), 1.0 / weakCount));
  }

  /**
   * Submits {@code weakCount} invocations of {@code job} to {@link #pool} and waits for all of them.
   *
   * <p>NOTE(review): every invocation shares {@code weak} and {@code rnd}. This assumes both tolerate
   * concurrent use. Because tasks draw from the shared {@code rnd} in scheduling order, results are
   * presumably not reproducible from a seed alone. TODO confirm.
   *
   * @return the job results, in submission order
   * @throws IllegalStateException if the calling thread is interrupted while waiting (the interrupt
   *     flag is restored) or a job fails with a checked exception
   */
  private <T> List<T> runWeakLearners(final Callable<T> job) {
    final List<Future<T>> futures = new ArrayList<>(weakCount);
    for (int i = 0; i < weakCount; ++i) {
      futures.add(pool.submit(job));
    }

    final List<T> results = new ArrayList<>(weakCount);
    try {
      for (final Future<T> future : futures) {
        results.add(future.get());
      }
    } catch (InterruptedException e) {
      cancelAll(futures);
      // Restore the flag so callers further up can still observe the interruption.
      Thread.currentThread().interrupt();
      throw new IllegalStateException("Interrupted while waiting for weak learners", e);
    } catch (ExecutionException e) {
      // One failed member invalidates the whole ensemble, so stop the remaining work.
      cancelAll(futures);
      final Throwable cause = e.getCause();
      if (cause instanceof RuntimeException) {
        throw (RuntimeException) cause;
      }
      if (cause instanceof Error) {
        throw (Error) cause;
      }
      throw new IllegalStateException("Weak learner failed", cause);
    }
    return results;
  }

  /** Best-effort cancellation of every task in {@code futures}; already finished tasks are unaffected. */
  private static void cancelAll(final List<? extends Future<?>> futures) {
    for (final Future<?> future : futures) {
      future.cancel(true);
    }
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy