All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.seq.GradientSeqBoosting Maven / Gradle / Ivy

package com.expleague.ml.methods.seq;

import com.expleague.commons.func.impl.WeakListenerHolderImpl;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.seq.Seq;
import com.expleague.ml.TargetFunc;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.loss.L2;
import com.expleague.ml.methods.SeqOptimization;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;


public class GradientSeqBoosting extends WeakListenerHolderImpl,Vec>> implements SeqOptimization {
  protected final SeqOptimization weak;
  private final Class factory;
  int iterationsCount;

  double step;

  public GradientSeqBoosting(final SeqOptimization weak, final int iterationsCount, final double step) {
    this(weak, L2.class, iterationsCount, step);
  }

  public GradientSeqBoosting(final SeqOptimization weak, final Class factory, final int iterationsCount, final double step) {
    this.weak = weak;
    this.factory = factory;
    this.iterationsCount = iterationsCount;
    this.step = step;
  }

  @Override
  public Function, Vec> fit(final DataSet> learn, final GlobalLoss globalLoss) {
    final Vec cursor = new ArrayVec(globalLoss.xdim());
    final List,Vec>> weakModels = new ArrayList<>(iterationsCount);
    final Trans gradient = globalLoss.gradient();
    for (int t = 0; t < iterationsCount; t++) {
      assert gradient != null;
      final Vec gradientValueAtCursor = gradient.trans(cursor);
      final L2 localLoss = DataTools.newTarget(factory, gradientValueAtCursor, learn);
      System.out.println("Iteration " + t + ". Gradient norm: " + VecTools.norm(localLoss.target));
      final Function, Vec> weakModel = weak.fit(learn, localLoss);
      weakModels.add(weakModel);
      final Function,Vec> curRes = new GradientSeqBoostingModel<>(new ArrayList<>(weakModels), step);
      invoke(curRes);
      for (int i = 0; i < learn.length(); i++) {
        final Vec val = weakModel.apply(learn.at(i));
        for (int j = 0; j < val.dim(); j++) {
          cursor.adjust(i * val.dim() + j, val.get(j) * -step);
        }
      }
    }
    return new GradientSeqBoostingModel<>(weakModels, step);
  }

  public static class GradientSeqBoostingModel implements Function, Vec> {
    private List, Vec>> models;
    private double step;

    GradientSeqBoostingModel(final List, Vec>> models, final double step) {
      this.models = new ArrayList<>(models);
      this.step = step;
    }

    @Override
    public Vec apply(Seq seq) {
      Vec result = null;
      for (Function, Vec> model: models) {
        if (result == null) {
          result = model.apply(seq);
        } else {
          VecTools.append(result, model.apply(seq));
        }
      }
      VecTools.scale(result, -step);
      return result;
    }

    public List, Vec>> getModels() {
      return models;
    }

    public void setModels(List, Vec>> models) {
      this.models = models;
    }

    public double getStep() {
      return step;
    }

    public void setStep(double step) {
      this.step = step;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy