com.expleague.ml.optimization.impl.AdamDescent
package com.expleague.ml.optimization.impl;

import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.ml.func.FuncEnsemble;
import com.expleague.ml.func.RegularizerFunc;
import com.expleague.ml.optimization.Optimize;

import java.util.*;
import java.util.function.Consumer;
import java.util.stream.IntStream;

public class AdamDescent implements Optimize<FuncEnsemble<? extends FuncC1>> {
  private final double step;
  private final double beta1;
  private final double beta2;
  private final double eps;
  private final Random random;
  private final int epochCount;
  private final int batchSize;
  private Consumer<Vec> listener;

  public AdamDescent(Random random, int epochCount, int batchSize) {
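    // Defaults follow Kingma & Ba (2015): step 0.001, beta1 0.9, beta2 0.999, eps 1e-8.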
    this(random, epochCount, batchSize, 0.001, 0.9, 0.999, 1e-8);
  }

  public AdamDescent(Random random, int epochCount, int batchSize, double step) {
    this(random, epochCount, batchSize, step, 0.9, 0.999, 1e-8);
  }

  public AdamDescent(Random random, int epochCount, int batchSize, double step, double beta1, double beta2, double eps) {
    this.random = random;
    this.epochCount = epochCount;
    this.batchSize = batchSize;
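    // The effective step is scaled by sqrt(batchSize), a common heuristic for
    // adjusting the learning rate to the mini-batch size.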
    this.step = step * Math.sqrt(batchSize);
    this.beta1 = beta1;
    this.beta2 = beta2;
    this.eps = eps;
  }

  @Override
  public Vec optimize(FuncEnsemble<? extends FuncC1> sumFuncs) {
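    // Start from a random standard-Gaussian point.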
    final Vec x = new ArrayVec(sumFuncs.dim());
    for (int i = 0; i < sumFuncs.dim(); i++) {
      x.set(i, random.nextGaussian());
    }
    return optimize(sumFuncs, x);
  }

  @Override
  public Vec optimize(FuncEnsemble<? extends FuncC1> sumFuncs, RegularizerFunc reg, Vec x0) {
    final long startTime = System.nanoTime();

    Vec x = VecTools.copy(x0);
    final Vec v = new ArrayVec(x.dim());
    final Vec c = new ArrayVec(x.dim());
    double error = getLoss(sumFuncs, x);

    final List<Integer> permutation = new ArrayList<>(sumFuncs.size());
    for (int i = 0; i < sumFuncs.size(); i++) {
      permutation.add(i);
    }

    long timeToGrad = 0;
    long timeToSum = 0;
    for (int epoch = 0; epoch < epochCount; epoch++) {
      Collections.shuffle(permutation, random);
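      // Sweep over complete mini-batches; a trailing partial batch is skipped.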
      for (int i = 0; i + batchSize <= sumFuncs.size(); i += batchSize) {
        IntStream stream;
        Vec finalX = x;
        if (batchSize > 1) {
          stream = IntStream.range(i, i + batchSize).parallel();
        } else {
          stream = IntStream.range(i, i + batchSize);
        }
        long start = System.nanoTime();
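        // Sum the per-sample gradients of the mini-batch; VecTools::append
        // accumulates the right-hand vector into the left-hand one.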
        final Vec gradVec = stream
            .mapToObj(j -> sumFuncs.models[permutation.get(j)].gradient(finalX))
            .reduce(VecTools::append).get();
        final double[] grad = gradVec.toArray();

        //todo is converting sparse vector to array a good idea?
        timeToGrad += System.nanoTime() - start;
        start = System.nanoTime();

//        final int threadCount = Runtime.getRuntime().availableProcessors();
        final int threadCount = batchSize;
        final int blockSize = (x.dim() + threadCount - 1) / threadCount;
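        // Per-coordinate update over disjoint index blocks. Note this deviates
        // from canonical Adam (Kingma & Ba, 2015): the first-moment estimate v
        // decays with beta2 and the second-moment estimate c with beta1 (roles
        // swapped), the (1 - beta1) factor is squared together with the
        // gradient, eps sits inside the square root, and there is no bias
        // correction of the moment estimates.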
        IntStream.range(0, threadCount).parallel().forEach(blockNum -> {
          for (int index = blockNum * blockSize; index < Math.min(finalX.dim(), (blockNum + 1) * blockSize); index++) {
            final double gradAt = grad[index] / batchSize;
            v.set(index, v.get(index) * beta2 + gradAt * (1 - beta2));
            c.set(index, c.get(index) * beta1 + MathTools.sqr(gradAt * (1 - beta1)));
            finalX.adjust(index, -step * v.get(index) / (Math.sqrt(c.get(index) + eps)));
          }
        });
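        // Apply the regularizer's projection (a proximal-style step) after each
        // mini-batch update.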
        x = reg.project(x);
        timeToSum += System.nanoTime() - start;
      }
      if (listener != null) {
        listener.accept(x);
      }
      if ((epoch + 1) % 5 == 0) {
        final double curError = getLoss(sumFuncs, x);
        System.out.printf("ADAM descent epoch %d: new=%.6f old=%.6f\n", epoch, curError, error);
//        if (curError > error) {
//          System.out.printf("ADAM descent finished after %d epochs\n", epoch);
//          break;
//        }
//        System.out.println(x);
        System.out.println("|x|=" + VecTools.norm(x));
        error = curError;
      } else if (epoch == epochCount - 1) {
        final double curError = getLoss(sumFuncs, x);
        System.out.printf("ADAM descent epoch %d: new=%.6f old=%.6f\n", epoch, curError, error);
      }
    }
    System.out.printf("Time to grad: %.3f, time to sum: %.3f\n", timeToGrad / 1e9, timeToSum / 1e9);
    System.out.printf("Adam Descent finished in %.2f seconds\n", (System.nanoTime() - startTime) / 1e9);
    return x;
  }

  public void setListener(Consumer<Vec> listener) {
    this.listener = listener;
  }

  private double getLoss(FuncEnsemble<? extends FuncC1> sumFuncs, Vec x) {
    return Arrays.stream(sumFuncs.models).parallel().mapToDouble(func -> func.value(x)).sum() / sumFuncs.size();
  }
}
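
Below is a minimal usage sketch; it is not part of the published artifact. It assumes a FuncEnsemble<? extends FuncC1> has been constructed elsewhere (how to build one depends on the surrounding com.expleague.ml code) and only exercises the constructor, listener, and one-argument optimize entry points defined above.

import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.ml.func.FuncEnsemble;
import com.expleague.ml.optimization.impl.AdamDescent;

import java.util.Random;

public class AdamDescentExample {
  // 'ensemble' is assumed to be built by the caller; its construction is library-specific.
  public static Vec fit(FuncEnsemble<? extends FuncC1> ensemble) {
    // Seeded Random gives reproducible initialization and epoch shuffling;
    // 20 epochs over mini-batches of 16 samples, base step 0.001.
    final AdamDescent adam = new AdamDescent(new Random(1), 20, 16, 0.001);
    // Optional: inspect the current point after every epoch.
    adam.setListener(x -> System.out.println("|x| = " + VecTools.norm(x)));
    // The one-argument overload starts from a random standard-Gaussian point.
    return adam.optimize(ensemble);
  }
}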