All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.factorization.impl.StochasticALS Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.factorization.impl;

import com.expleague.commons.func.impl.WeakListenerHolderImpl;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.util.Pair;
import com.expleague.commons.util.logging.Interval;
import com.expleague.ml.factorization.Factorization;

import java.util.function.BiConsumer;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;

/**
 * Experts League
 * Created by solar on 04.05.17.
 */

public class StochasticALS extends WeakListenerHolderImpl> implements Factorization {
  private static final Logger log = Logger.getLogger(StochasticALS.class.getName());
  private final FastRandom rng;
  private final double gamma;
  private final int maxIterations;
  private final Cache cache;

  public StochasticALS(FastRandom rng, double gamma, int maxIterations, Cache cache) {
    this.rng = rng;
    this.gamma = gamma;
    this.maxIterations = maxIterations;
    this.cache = cache;
  }

  public StochasticALS(FastRandom rng, double gamma, int maxIterations) {
    this(rng, gamma, maxIterations, null);
  }

  public StochasticALS(FastRandom rng, double gamma, Cache cache) {
    this(rng, gamma, 100000, cache);
  }

  public StochasticALS(FastRandom rng, double gamma) {
    this(rng, gamma, 100000);
  }

  public Pair factorize(final Mx X, final BiConsumer scaledWeakLearner) {
    final int m = X.rows();
    final int n = X.columns();
    if (m < n * 10) {
      log.log(Level.WARNING, "This algorithm is intended to be used for matrices with rows >> columns");
    }

    final Vec v = new ArrayVec(n);
    VecTools.fillGaussian(v, rng);
    VecTools.scale(v, 1. / VecTools.norm(v));
    final double gamma = this.gamma / (2 + X.rows());
    int iteration = 0;

    if (cache != null) {
      Interval.start();
      cache.update(X, scaledWeakLearner);
      Interval.stopAndPrint("Cache update");
    }

    // call cursor update
    X.row(0);

    Interval.start();
    {
      double a;
      double u_hat;

      do {
        iteration++;
        final int i = rng.nextInt(m);
        final double denominator = Math.log(1 + iteration);
        final Vec row = cache != null ? cache.getAnyCached() : X.row(i);
        u_hat = VecTools.multiply(row, v);
        if (Double.isNaN(u_hat)) {
          break;
        }
        a = 0;
        for (int j = 0; j < n; j++) {
          final double val = gamma * (u_hat * (v.get(j) * u_hat - row.get(j))) / denominator;
          v.adjust(j, -val);
          if (a < Math.abs(val))
            a = Math.abs(val);
        }
        VecTools.scale(v, 1 / VecTools.norm(v)); //TODO
      }
      while (/*a > 0.2 * gamma * u_hat * u_hat &&*/ iteration < maxIterations);
    }
    Interval.stopAndPrint(iteration + " SALS iterations");

    Interval.start();
    VecTools.scale(v, 1 / VecTools.norm(v));
    final Vec u = new ArrayVec(m);
    IntStream.range(0, m).parallel().forEach(i -> {
      final Vec row = cache != null && cache.containsRow(i) ? cache.getRowFromCache(i) : X.row(i);
      double u_val = VecTools.multiply(row, v);
      u.set(i, Double.isNaN(u_val) ? 0 : u_val);
    });
    Interval.stopAndPrint("Decomposition calculation");

    return Pair.create(u, v);
  }

  @Override
  public Pair factorize(final Mx X) {
    return factorize(X, null);
  }

  public static class Cache {
    private final int cacheSizeMB;
    private final double updateRate;
    private final FastRandom rng;

    private int cacheSize;
    private Vec[] cachedRows = null;
    private int[] mtcIds = null;
    private int[] ctmIds = null;


    public Cache(final int cacheSizeMB, final double updateRate, final FastRandom rng) {
      this.cacheSizeMB = cacheSizeMB;
      this.updateRate = updateRate;
      this.rng = rng;
    }

    private void initializeCache(final Mx X) {
      cacheSize = Math.min(X.rows(), cacheSizeMB * 1024 * 1024 / (X.columns() * Double.BYTES));
      System.out.println("Initialized cache of size " + cacheSize);

      cachedRows = new Vec[cacheSize];
      mtcIds = new int[X.rows()];
      ctmIds = new int[X.rows()];

      for (int i = 0; i < cacheSize; i++) {
        cachedRows[i] = X.row(i);
        ctmIds[i] = i;
        mtcIds[i] = i;
      }
      for (int i = cacheSize; i < X.rows(); i++) {
        ctmIds[i] = i;
        mtcIds[i] = -1;
      }
    }

    private void swap(final int i, final int j) {
      if (i == j) {
        return;
      }

      if (i < cacheSize && j < cacheSize) {
        final Vec tmp_r = cachedRows[i];
        cachedRows[i] = cachedRows[j];
        cachedRows[j] = tmp_r;
      }

      final int tmp_i = ctmIds[i];
      ctmIds[i] = ctmIds[j];
      ctmIds[j] = tmp_i;

      mtcIds[ctmIds[i]] = i < cacheSize ? i : -1;
      mtcIds[ctmIds[j]] = j < cacheSize ? j : -1;
    }

    private void updateCache(final Mx X, final BiConsumer scaledWeakLearner) {
      final int toUpdate = (int) (updateRate * cacheSize);

      for (int i = cacheSize - toUpdate; i < cacheSize; i++) {
        final int j = rng.nextInt(i + 1);
        swap(i, j);
      }

      IntStream.range(0, cacheSize - toUpdate).parallel().forEach(i -> {
        scaledWeakLearner.accept(ctmIds[i], cachedRows[i]);
      });

      for (int i = cacheSize - toUpdate; i < cacheSize; i++) {
        final int j = i + rng.nextInt(X.rows() - i);
        swap(i, j);
        if (j < cacheSize) {
          scaledWeakLearner.accept(ctmIds[i], cachedRows[i]);
        } else {
          cachedRows[i] = X.row(ctmIds[i]);
        }
      }
    }

    public void update(final Mx X, final BiConsumer scaledWeakLearner) {
      if (cachedRows == null) {
        initializeCache(X);
      } else {
        updateCache(X, scaledWeakLearner);
      }
    }

    public Vec getAnyCached() {
      return cachedRows[rng.nextInt(cacheSize)];
    }

    public boolean containsRow(int i) {
      return mtcIds[i] != -1;
    }

    public Vec getRowFromCache(int i) {
      return cachedRows[mtcIds[i]];
    }

  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy