All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.greedyRegion.cherry.GreedyTDCherryRegion Maven / Gradle / Ivy

package com.expleague.ml.methods.greedyRegion.cherry;

import com.expleague.commons.math.Func;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.util.ArrayTools;
import com.expleague.ml.BFGrid;
import com.expleague.ml.Binarize;
import com.expleague.ml.data.cherry.CherryLoss;
import com.expleague.ml.data.cherry.CherryPick;
import com.expleague.ml.data.cherry.CherrySubset;
import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.CNF;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * Region learner that grows a {@link CNF} one clause at a time.
 *
 * <p>On each step a {@link CherryPick} proposes the next clause for a {@code CherryLoss} built over
 * the binarized learn set. The clause is kept while the loss score keeps improving by more than
 * {@link #MIN_SCORE_GAIN}. The accepted clauses, together with the inside/outside increments
 * recorded at the last accepted step, form the resulting model.
 *
 * User: solar
 * Date: 15.11.12
 * Time: 15:19
 */
public class GreedyTDCherryRegion<Loss extends StatBasedLoss> extends VecOptimization.Stub<Loss> {
  /** A new clause is accepted only if it raises the score by strictly more than this. */
  private static final double MIN_SCORE_GAIN = 1e-9;

  /** Binarization grid used both for the learn set and for the produced {@link CNF}. */
  public final BFGrid grid;
  private final CherryPick pick = new CherryPick();

  public GreedyTDCherryRegion(final BFGrid grid) {
    this.grid = grid;
  }

  /**
   * Selects the point indices to learn on.
   *
   * @return the loss's own point subset when it is a {@link WeightedLoss}; otherwise
   *     {@code ArrayTools.sequence(0, ds.length())} (presumably every index of {@code ds})
   */
  private int[] learnPoints(Loss loss, VecDataSet ds) {
    if (loss instanceof WeightedLoss) {
      return ((WeightedLoss) loss).points();
    }
    return ArrayTools.sequence(0, ds.length());
  }

  /**
   * Greedily collects clauses until the cherry-loss score stops improving.
   *
   * @param learn training set; binarized with {@link #grid} through the dataset cache
   * @param loss target loss; supplies the statistics factory and, if weighted, the point subset
   * @return a CNF of all accepted clauses with the increments from the last accepted step
   */
  @Override
  public CNF fit(final VecDataSet learn, final Loss loss) {
    final List<CNF.Clause> conditions = new ArrayList<>(100);
    final BinarizedDataSet bds =
        learn.cache().cache(Binarize.class, VecDataSet.class).binarize(grid);
    final int[] points = learnPoints(loss, learn);
    final CherryLoss localLoss =
        new OutLoss3<>(new CherrySubset(bds, loss.statsFactory(), points), loss);

    double currentScore = Double.NEGATIVE_INFINITY;
    double bestIncInside = 0;
    double bestIncOutside = 0;
    while (true) {
      // The score is re-read after every pick, so pick.fit is expected to update localLoss.
      final CNF.Clause clause = pick.fit(localLoss);
      final double score = localLoss.score();
      // Negated '>' instead of '<=': a NaN score must also stop the search. With '<=' NaN
      // compares false, currentScore would become NaN and the loop could never exit.
      if (!(score > currentScore + MIN_SCORE_GAIN)) {
        // NOTE(review): pick.fit may already have applied this rejected clause to localLoss;
        // only the increments recorded at the last accepted step are used below.
        break;
      }

      System.out.println("\nAdded clause " + clause);
      currentScore = score;
      bestIncInside = localLoss.insideIncrement();
      bestIncOutside = localLoss.outsideIncrement();
      conditions.add(clause);
    }
    return new CNF(conditions.toArray(new CNF.Clause[0]), bestIncInside, bestIncOutside, grid);
  }
}

/**
 * Optimization that delegates each {@code fit} call to one of several learners.
 *
 * <p>The learner is selected by {@code random.nextInt(learners.length)}. Its model is exposed
 * through a {@link Func} adapter.
 */
class MultiMethodOptimization<Loss extends StatBasedLoss> extends VecOptimization.Stub<Loss> {
  private final VecOptimization<Loss>[] learners;
  private final FastRandom random;

  /**
   * @param learners candidate learners; must be non-null and non-empty (the array is copied)
   * @param random source used to pick a learner on every {@code fit} call
   * @throws NullPointerException if {@code learners} or {@code random} is null
   * @throws IllegalArgumentException if {@code learners} is empty
   */
  public MultiMethodOptimization(VecOptimization<Loss>[] learners, FastRandom random) {
    Objects.requireNonNull(learners, "learners");
    if (learners.length == 0) {
      // Fail fast: otherwise fit() would only fail later, inside random.nextInt(0).
      throw new IllegalArgumentException("learners must not be empty");
    }
    this.learners = learners.clone();
    this.random = Objects.requireNonNull(random, "random");
  }

  /**
   * {@link Func} adapter that forwards {@code value} and {@code dim} to a wrapped model.
   *
   * <p>Static so that returned models do not keep a hidden reference to the optimizer and its
   * learners.
   */
  static class FuncHolder extends Func.Stub {
    final Func inside;

    FuncHolder(Func inside) {
      this.inside = inside;
    }

    @Override
    public double value(Vec x) {
      return inside.value(x);
    }

    @Override
    public int dim() {
      return inside.dim();
    }
  }

  /**
   * Fits a randomly selected learner and wraps its model.
   *
   * @throws ClassCastException if the selected learner's model is not a {@link Func}
   */
  @Override
  public Trans fit(VecDataSet learn, Loss loss) {
    final VecOptimization<Loss> learner = learners[random.nextInt(learners.length)];
    return new FuncHolder((Func) learner.fit(learn, loss));
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy