All downloads are FREE. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.StochasticGradientDescent Maven / Gradle / Ivy

package com.expleague.ml.methods;

import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.impl.vectors.ROVec;
import com.expleague.ml.loss.DSSumFuncComposite;
import com.expleague.commons.func.impl.WeakListenerHolderImpl;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.util.ThreadTools;
import com.expleague.ml.data.set.DataSet;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;

/**
 * User: solar
 * Date: 01.06.15
 * Time: 12:49
 */
public class StochasticGradientDescent extends WeakListenerHolderImpl implements Optimization, DataSet,Item> {
  private final Executor executor;

  private final FastRandom rng;
  private final int couple;
  private final int T;
  private final double step;

  public StochasticGradientDescent(FastRandom rng, int couple, int T, double step) {
    this.rng = rng;
    this.couple = couple;
    this.T = T;
    this.step = step;
    executor = ThreadTools.createBGExecutor(StochasticGradientDescent.class.getName(), this.couple);
  }
  @Override
  public DSSumFuncComposite.Decision fit(DataSet learn, final DSSumFuncComposite target) {
    final Vec cursor = new ArrayVec(target.dim());
    init(cursor);
    final Vec[] coupleVec = new Vec[couple];
    final Vec gradient = new ArrayVec(target.dim());
    for (int t = 0; t < T; t++) {
      VecTools.fill(gradient, 0.);
      final CountDownLatch latch = new CountDownLatch(couple);
      for (int i = 0; i < couple; i++) {
//        final int nextItem = rng.nextInt(1000);
        final int nextItem = rng.nextInt(learn.length());
//        System.out.println("sample :" + learn.meta().owner().feature(0, nextItem) + " target: " + target.component(nextItem));
        final int finalI = i;
        executor.execute(new Runnable() {
          @Override
          public void run() {
            final Vec currentGrad = target.component(nextItem).gradient(new ROVec(cursor));
            coupleVec[finalI] = currentGrad;
            synchronized (gradient) {
              VecTools.append(gradient, currentGrad);
            }
            latch.countDown();
          }
        });
      }
      try {
        latch.await();
      } catch (InterruptedException e) {
        throw new RuntimeException(e);
      }
      VecTools.scale(gradient, 1. / couple);
      {
//        double meanCos = 0;
//        for (int i = 0; i < couple; i++) {
//          for (int j = 0; j < couple; j++) {
//            meanCos += VecTools.cosine(coupleVec[i], coupleVec[j]) / couple / couple;
//          }
//        }
////        System.out.println(gradient);
//        System.out.println(meanCos + " " + VecTools.norm(gradient));
      }
      normalizeGradient(gradient);
//      VecTools.scale(gradient, step * 100. / Math.sqrt(10000. + t));
      VecTools.scale(gradient, step);
      VecTools.append(cursor, gradient);
      invoke(new ROVec(cursor));
    }
    return target.decision(cursor);
  }

  public void init(Vec cursor) {
    VecTools.fillUniform(cursor, rng);
  }

  public void normalizeGradient(Vec grad) {
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy