com.expleague.ml.optimization.impl.AdamDescent
package com.expleague.ml.optimization.impl;
import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.ml.func.FuncEnsemble;
import com.expleague.ml.optimization.Optimize;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.stream.IntStream;
public class AdamDescent implements Optimize<FuncEnsemble<? extends FuncC1>> {
private final double step;
private final double beta1;
private final double beta2;
private final double eps;
private final Random random;
private final int epochCount;
private final int batchSize;
public AdamDescent(Random random, int epochCount, int batchSize) {
this(random, epochCount, batchSize, 0.002, 0.9, 0.999, 1e-8);
}
public AdamDescent(Random random, int epochCount, int batchSize, double step) {
this(random, epochCount, batchSize, step, 0.9, 0.999, 1e-8);
}
public AdamDescent(Random random, int epochCount, int batchSize, double step, double beta1, double beta2, double eps) {
this.random = random;
this.epochCount = epochCount;
this.batchSize = batchSize;
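// The effective step grows with the square root of the mini-batch size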
this.step = step * Math.sqrt(batchSize);
this.beta1 = beta1;
this.beta2 = beta2;
this.eps = eps;
}
@Override
public Vec optimize(FuncEnsemble<? extends FuncC1> sumFuncs) {
final Vec x = new ArrayVec(sumFuncs.dim());
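// No starting point supplied: initialise every coordinate with standard Gaussian noise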
for (int i = 0; i < sumFuncs.dim(); i++) {
x.set(i, random.nextGaussian());
}
return optimize(sumFuncs, x);
}
@Override
public Vec optimize(FuncEnsemble<? extends FuncC1> sumFuncs, Vec x0) {
final long startTime = System.nanoTime();
final Vec x = VecTools.copy(x0);
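// v: moving average of the gradient; c: moving average of the squared gradient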
final Vec v = new ArrayVec(x.dim());
final Vec c = new ArrayVec(x.dim());
Vec grad = new ArrayVec(x.dim());
double error = sumFuncs.value(x) / sumFuncs.size();
final List<Integer> permutation = new ArrayList<>(sumFuncs.size());
for (int i = 0; i < sumFuncs.size(); i++) {
permutation.add(i);
}
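// Repeat for epochCount epochs, reshuffling the visiting order of the ensemble components each time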
for (int epoch = 0; epoch < epochCount; epoch++) {
Collections.shuffle(permutation, random);
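// Walk over the shuffled indices in consecutive mini-batches of batchSize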
for (int i = 0; i + batchSize < sumFuncs.size(); i += batchSize) {
VecTools.fill(grad, 0);
IntStream stream;
if (batchSize > 1) {
stream = IntStream.range(i, i + batchSize).parallel();
} else {
stream = IntStream.range(i, i + batchSize);
}
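// grad = average of the per-sample gradients over the mini-batch (computed in parallel when batchSize > 1)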
grad = stream
.mapToObj(j -> VecTools.scale(sumFuncs.models[permutation.get(j)].gradient(x), 1.0 / batchSize))
.reduce((vec1, vec2) -> VecTools.append(vec1, vec2)).get();
// for (int j = i; j < i + batchSize; j++) {
// VecTools.append(grad, sumFuncs.models[permutation.get(j)].gradient(x));
// }
// VecTools.scale(grad, 1.0 / batchSize);
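// Exponential moving averages: v <- beta2*v + (1-beta2)*grad, c <- beta1*c + (1-beta1)*grad^2
// (beta1 and beta2 are applied in the opposite roles to the usual Adam notation)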
VecTools.scale(v, beta2);
VecTools.incscale(v, grad, 1 - beta2);
VecTools.scale(c, beta1);
VecTools.scale(grad, grad);
VecTools.incscale(c, grad, 1 - beta1);
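// Parameter update: x_j -= step * v_j / sqrt(c_j + eps); eps is added inside the square root here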
for (int j = 0; j < x.dim(); j++) {
x.adjust(j, -step * v.get(j) / (Math.sqrt(c.get(j) + eps)));
}
}
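// Every fifth epoch, evaluate the mean loss and stop early if it has increased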
if ((epoch + 1) % 5 == 0) {
final double curError = sumFuncs.value(x) / sumFuncs.size();
System.out.printf("ADAM descent epoch %d: new=%.6f old=%.6f\n", epoch, curError, error);
if (curError > error) {
System.out.printf("ADAM descent finished after %d epochs\n", epoch);
break;
}
error = curError;
} else if (epoch == epochCount - 1) {
final double curError = sumFuncs.value(x) / sumFuncs.size();
System.out.printf("ADAM descent epoch %d: new=%.6f old=%.6f\n", epoch, curError, error);
}
}
System.out.printf("Adam Descent finished in %.2f seconds\n", (System.nanoTime() - startTime) / 1e9);
return x;
}
}
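A minimal usage sketch, assuming a FuncEnsemble<? extends FuncC1> of differentiable components has already been built elsewhere in the caller's code; the seed, epoch count, batch size, and step below are illustrative values, not defaults taken from the library.

// Usage sketch (illustrative, not part of the library source).
import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.ml.func.FuncEnsemble;
import com.expleague.ml.optimization.impl.AdamDescent;
import java.util.Random;

class AdamDescentUsage {
  // 'ensemble' is assumed to come from the surrounding project; its construction is not shown.
  static Vec fit(FuncEnsemble<? extends FuncC1> ensemble) {
    // 100 epochs over mini-batches of 16 items; base step 0.001 (scaled by sqrt(16) inside the constructor)
    final AdamDescent adam = new AdamDescent(new Random(0), 100, 16, 0.001);
    // Starts from a random Gaussian point and returns the optimised parameter vector
    return adam.optimize(ensemble);
  }
}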