All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.seq.BootstrapSeqOptimization Maven / Gradle / Ivy

package com.expleague.ml.methods.seq;

import com.expleague.commons.func.impl.WeakListenerHolderImpl;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.seq.Seq;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.loss.L2;
import com.expleague.ml.loss.WeightedL2;
import com.expleague.ml.methods.SeqOptimization;

import java.util.Objects;
import java.util.function.Function;

public class BootstrapSeqOptimization extends WeakListenerHolderImpl implements SeqOptimization {
  protected final FastRandom rnd;
  private final SeqOptimization weak;

  public BootstrapSeqOptimization(final SeqOptimization weak, final FastRandom rnd) {
    this.weak = weak;
    this.rnd = rnd;
  }

  @Override
  public Function, Vec> fit(final DataSet> learn, final Loss globalLoss) {
    return weak.fit(learn, bootstrap(globalLoss, rnd));
  }

  private WeightedL2 bootstrap(Loss loss, FastRandom rnd) {
    final double[] poissonWeights = new double[loss.xdim()];
    for (int i = 0; i < loss.xdim(); i++) {
      poissonWeights[i] = rnd.nextPoisson(1.);
    }
    return new WeightedL2(loss.target, loss.owner(), new ArrayVec(poissonWeights));
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy