
com.expleague.ml.methods.nn.NeuralTreesOptimization

package com.expleague.ml.methods.nn;

import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.ml.BFGrid;
import com.expleague.ml.BlockedTargetFunc;
import com.expleague.ml.GridTools;
import com.expleague.ml.ProgressHandler;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.set.impl.VecDataSetImpl;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.loss.L2;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.methods.BootstrapOptimization;
import com.expleague.ml.methods.GradientBoosting;
import com.expleague.ml.methods.Optimization;
import com.expleague.ml.methods.trees.GreedyObliviousTree;
import com.expleague.ml.models.nn.ConvNet;

import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

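/**
 * Joint optimization of a convolutional network and a gradient-boosted
 * ensemble of oblivious trees built over the network's output features.
 * Each of the {@code numIterations} rounds samples {@code nSampleBuildTree}
 * items, fits a boosted ensemble on their network outputs, and then performs
 * one SGD pass on the network weights using loss gradients taken at the
 * ensemble's predictions. The model returned by {@link #fit} chains the
 * network with the ensemble of the final round.
 */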
public class NeuralTreesOptimization<Loss extends BlockedTargetFunc> implements Optimization<Loss, VecDataSet, Vec> {
  private Vec x;
  private final int numIterations;
  private final int nSampleBuildTree;
  private final ConvNet nn;
  private final FastRandom rng;
  private Loss loss;
  private final TreesLoss curLoss = new TreesLoss();
  private final int[] sampleIdxs;

  public NeuralTreesOptimization(int numIterations, int nSampleBuildTree, ConvNet nn, FastRandom rng) {
    this.numIterations = numIterations;
    this.nSampleBuildTree = nSampleBuildTree;
    this.nn = nn;
    this.rng = rng;
    sampleIdxs = new int[nSampleBuildTree];
  }

  @Override
  public Function<Vec, Vec> fit(VecDataSet learn, Loss loss) {
    this.loss = loss;

    final Mx highFeatures = new VecBasedMx(nSampleBuildTree, nn.ydim());
    final Ensemble[] ensemble = {null};

    for (int iter = 0; iter < numIterations; iter++) {
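      // Push a random sample of the training set through the network; the rows of
      // highFeatures (dimension nn.ydim()) form the feature space for the trees.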
      for (int i = 0; i < nSampleBuildTree; i++) {
        sampleIdxs[i] = rng.nextInt(learn.length());
        final Vec result = nn.apply(learn.data().row(sampleIdxs[i]));
        VecTools.assign(highFeatures.row(i), result);
      }

      final VecDataSetImpl highLearn = new VecDataSetImpl(highFeatures, null);

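      // Build a 32-bin median grid over the network's output features and boost
      // depth-6 oblivious trees on top of it.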
      final BFGrid grid = GridTools.medianGrid(highLearn, 32);
      final GreedyObliviousTree<WeightedLoss<L2>> weak = new GreedyObliviousTree<>(grid, 6);

      final GradientBoosting<Loss> boosting = new GradientBoosting<>(
          new BootstrapOptimization<>(weak, rng), L2.class, 1000, 0.005);
      final Consumer<Trans> counter = new ProgressHandler() {
        int index = 0;

        @Override
        public void accept(Trans partial) {
          if (index % 100 == 0)
            System.out.println("boost [" + index + "]");
          index++;
        }
      };
      boosting.addListener(counter);

      final Ensemble curEnsemble = boosting.fit(highLearn, (Loss) curLoss);

      final Vec result = curEnsemble.transAll(highLearn.data()).vec(); // evaluate in the tree feature space, not on raw inputs
      final double value = curLoss.value(result);
      System.out.println("[" + iter + "], loss: " + value);

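      // One SGD pass over the sampled items: backpropagate the loss gradient,
      // taken at the tree ensemble's output, into the network weights.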
      final Vec grad = new ArrayVec(nn.wdim());
      x = nn.weights(); // store in the field: the model returned below reads it
      final double sgdStep = 1e-3;
      for (int i = 0; i < nSampleBuildTree; i++) {
        final Vec treeOut = curEnsemble.trans(highLearn.data().row(i));
        nn.gradientTo(learn.data().row(sampleIdxs[i]), x, new TargetByTreeOut(sampleIdxs[i], treeOut), grad);

        VecTools.append(x, VecTools.scale(grad, -sgdStep));
      }

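      // Keep the ensemble of the final round for the returned model.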
      if (iter == numIterations - 1) {
        ensemble[0] = curEnsemble;
      }
    }

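    // The resulting model chains the trained network with the final ensemble.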
    return argument -> {
      Vec result = nn.apply(argument, x);
      return ensemble[0].trans(result);
    };
  }

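  /**
   * Pseudo-target used to backpropagate through the network: its gradient is
   * the outer loss gradient evaluated at the tree ensemble's prediction for
   * the given block, regardless of the argument passed in.
   */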
  private class TargetByTreeOut extends FuncC1.Stub {
    private final int blockIdx;
    private final Vec treeOut;

    TargetByTreeOut(int blockIdx, Vec treeOut) {
      this.blockIdx = blockIdx;
      this.treeOut = treeOut;
    }

    @Override
    public double value(Vec x) {
      throw new UnsupportedOperationException();
    }

    @Override
    public Vec gradientTo(Vec x, Vec to) {
      return ((FuncC1) loss.block(blockIdx)).gradientTo(treeOut, to);
    }

    @Override
    public int dim() {
      return 0;
    }
  }

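  /**
   * View of the outer loss restricted to the currently sampled items:
   * block {@code i} delegates to {@code loss.block(sampleIdxs[i])}, and
   * {@link #value} averages the per-block values.
   */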
  private class TreesLoss extends Trans.Stub implements BlockedTargetFunc {
    @Override
    public FuncC1 block(int index) {
      return (FuncC1) loss.block(sampleIdxs[index]);
    }

    @Override
    public int blocksCount() {
      return nSampleBuildTree;
    }

    @Override
    public int xdim() {
      return nSampleBuildTree;
    }

    @Override
    public Trans gradient() {
      final List<FuncC1> collect = IntStream.range(0, sampleIdxs.length)
          .mapToObj(this::block)
          .collect(Collectors.toList());

      return new Trans.Stub() {
        @Override
        public Vec transTo(Vec argument, Vec to) {
          Vec next = new ArrayVec(to.dim());
          for (int i = 0; i < to.dim(); i++) {
            to.set(i, collect.get(i).gradientTo(argument, next).get(0));
          }
          return to;
        }

        @Override
        public int xdim() {
          return TreesLoss.this.xdim();
        }

        @Override
        public int ydim() {
          return TreesLoss.this.xdim();
        }
      };
    }

    @Override
    public DataSet owner() {
      throw new UnsupportedOperationException();
    }

    @Override
    public double value(Vec x) {
      return IntStream.range(0, sampleIdxs.length)
          .mapToObj(this::block)
          .map(func -> func.value(x))
          .collect(Collectors.averagingDouble(i -> i));
    }

    @Override
    public int dim() {
      throw new UnsupportedOperationException();
    }

    @Override
    public int ydim() {
      throw new UnsupportedOperationException();
    }
  }
}
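
Below is a minimal usage sketch, not part of the library source: only the NeuralTreesOptimization constructor and the fit(...) signature come from the class above, while buildNetwork, loadDataSet and loadTarget are hypothetical stubs standing in for task-specific ConvNet construction, data loading, and target definition.

import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.random.FastRandom;
import com.expleague.ml.BlockedTargetFunc;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.methods.nn.NeuralTreesOptimization;
import com.expleague.ml.models.nn.ConvNet;

import java.util.function.Function;

public class NeuralTreesExample {
  public static void main(String[] args) {
    final FastRandom rng = new FastRandom(42);
    final ConvNet nn = buildNetwork();             // hypothetical: architecture is task-specific
    final VecDataSet learn = loadDataSet();        // hypothetical: data loading is task-specific
    final BlockedTargetFunc target = loadTarget(); // hypothetical: per-item loss over the dataset

    // 10 outer rounds; 1000 items sampled per round to build the trees.
    final NeuralTreesOptimization<BlockedTargetFunc> optimization =
        new NeuralTreesOptimization<>(10, 1000, nn, rng);
    final Function<Vec, Vec> model = optimization.fit(learn, target);

    // The fitted model maps a raw input through the network and the tree ensemble.
    final Vec prediction = model.apply(learn.data().row(0));
    System.out.println(prediction);
  }

  private static ConvNet buildNetwork() { throw new UnsupportedOperationException(); }
  private static VecDataSet loadDataSet() { throw new UnsupportedOperationException(); }
  private static BlockedTargetFunc loadTarget() { throw new UnsupportedOperationException(); }
}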