All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.seq.GradientSeqBoosting Maven / Gradle / Ivy

package com.expleague.ml.methods.seq;

import com.expleague.commons.func.Computable;
import com.expleague.commons.func.impl.WeakListenerHolderImpl;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.SingleValueVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.seq.Seq;
import com.expleague.ml.TargetFunc;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.loss.L2;
import com.expleague.ml.methods.SeqOptimization;

import java.util.ArrayList;
import java.util.List;


/**
 * Gradient boosting over sequence inputs: repeatedly fits a weak learner to the
 * gradient of a global loss and combines the weak models into an additive
 * ensemble scaled by {@code -step}.
 *
 * <p>After every iteration the ensemble built so far is passed to
 * {@code invoke(...)}, which this class does not define itself.
 * NOTE(review): presumably inherited from {@code WeakListenerHolderImpl} and
 * used to notify registered listeners -- confirm against that class.
 *
 * <p>NOTE(review): the generic type arguments in this listing look truncated
 * (e.g. {@code WeakListenerHolderImpl, Vec>>}, {@code Computable, Vec>},
 * {@code DataSet>}); the type parameters of the class and of
 * {@code SeqOptimization}/{@code Computable} should be restored from the
 * upstream source before this file is compiled.
 */
public class GradientSeqBoosting extends WeakListenerHolderImpl, Vec>> implements SeqOptimization {
  /** Weak learner that is fitted to the local (gradient) target on every iteration. */
  protected final SeqOptimization weak;
  /** Loss class handed to {@code DataTools.newTarget} to build the per-iteration local loss. */
  private final Class factory;
  /** Number of boosting iterations performed by {@link #fit}; also the initial capacity of the model list. */
  int iterationsCount;

  /**
   * Step size. Every weak model's output is multiplied by {@code -step}, both when
   * the cursor is updated during fitting and when the final ensemble is evaluated.
   */
  double step;

  /**
   * Creates a booster that uses {@code L2.class} as the local loss factory.
   *
   * @param weak            weak learner fitted on each iteration
   * @param iterationsCount number of boosting iterations
   * @param step            step size applied (negated) to each weak model's output
   */
  public GradientSeqBoosting(final SeqOptimization weak, final int iterationsCount, final double step) {
    this(weak, L2.class, iterationsCount, step);
  }

  /**
   * Creates a booster with an explicit local loss class.
   *
   * @param weak            weak learner fitted on each iteration
   * @param factory         loss class passed to {@code DataTools.newTarget} on each iteration
   * @param iterationsCount number of boosting iterations
   * @param step            step size applied (negated) to each weak model's output
   */
  public GradientSeqBoosting(final SeqOptimization weak, final Class factory, final int iterationsCount, final double step) {
    this.weak = weak;
    this.factory = factory;
    this.iterationsCount = iterationsCount;
    this.step = step;
  }

  /**
   * Runs {@code iterationsCount} boosting iterations and returns the resulting ensemble.
   *
   * <p>Each iteration evaluates the global loss gradient at the current cursor, builds a
   * local loss from that gradient, fits the weak learner to it, publishes the ensemble
   * built so far through {@code invoke}, and then moves the cursor by the new weak
   * model's first output component times {@code -step} for every item of {@code learn}.
   *
   * @param learn      training data set; items are fed to the weak models one by one
   * @param globalLoss loss whose {@code gradient()} drives the boosting and whose
   *                   {@code xdim()} sizes the cursor
   * @return a function that sums the first output component of every weak model, each
   *         scaled by {@code -step}, wrapped in a {@code SingleValueVec}
   */
  @Override
  public Computable, Vec> fit(final DataSet> learn, final GlobalLoss globalLoss) {
    // Running state of the optimisation, one component per index of the loss domain.
    // NOTE(review): the update loop below indexes it by data-set position, so this
    // assumes globalLoss.xdim() >= learn.length() -- confirm.
    final Vec cursor = new ArrayVec(globalLoss.xdim());
    final List, Vec>> weakModels = new ArrayList<>(iterationsCount);
    final Trans gradient = globalLoss.gradient();
    for (int t = 0; t < iterationsCount; t++) {
      final Vec gradientValueAtCursor = gradient.trans(cursor);
      // The weak learner is trained against the current gradient values as its target.
      final L2 localLoss = DataTools.newTarget(factory, gradientValueAtCursor, learn);
      // Progress is written straight to standard output on every iteration.
      System.out.println("Iteration " + t + ". Gradient norm: " + VecTools.norm(localLoss.target));
      final Computable, Vec> weakModel = weak.fit(learn, localLoss);
      weakModels.add(weakModel);
      // Copy the list so the published intermediate result is not affected by
      // weak models appended in later iterations.
      final Computable, Vec> curRes = getResult(new ArrayList<>(weakModels));
      invoke(curRes);
      // Move the cursor using only component 0 of the new weak model's prediction.
      for (int i = 0; i < learn.length(); i++) {
        cursor.adjust(i, weakModel.compute(learn.at(i)).get(0) * -step);
      }
    }
    return getResult(weakModels);
  }

  /**
   * Builds the additive ensemble over the given weak models.
   *
   * <p>The returned function keeps a reference to the passed list (no copy is made
   * here) and reads the {@code step} field each time it is evaluated.
   *
   * @param weakModels weak models to sum over
   * @return function computing the sum of {@code model.compute(argument).get(0) * -step}
   *         over all models, as a {@code SingleValueVec}
   */
  private Computable, Vec> getResult(final List, Vec>> weakModels) {
    return argument -> {
      double result = 0;
      for (Computable, Vec> model: weakModels) {
        result += model.compute(argument).get(0) * -step;
      }
      return new SingleValueVec(result);
    };
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy