All Downloads are FREE. Search and download functionalities are using the official Maven repository.
Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
com.expleague.ml.embedding.decomp.DecompBuilder Maven / Gradle / Ivy
package com.expleague.ml.embedding.decomp;
import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.seq.CharSeq;
import com.expleague.commons.util.logging.Interval;
import com.expleague.ml.embedding.Embedding;
import com.expleague.ml.embedding.impl.EmbeddingBuilderBase;
import com.expleague.ml.embedding.impl.EmbeddingImpl;
import gnu.trove.list.array.TIntArrayList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;
public class DecompBuilder extends EmbeddingBuilderBase {
private static final Logger log = LoggerFactory.getLogger(DecompBuilder.class.getName());
private double xMax = 10;
private double alpha = 0.75;
private int symDim = 50;
private int skewDim = 10;
private FastRandom rng = new FastRandom();
public DecompBuilder xMax(int xMax) {
this.xMax = xMax;
return this;
}
public DecompBuilder alpha(double alpha) {
this.alpha = alpha;
return this;
}
public DecompBuilder dimSym(int dim) {
this.symDim = dim;
return this;
}
public DecompBuilder dimSkew(int dim) {
this.skewDim = dim;
return this;
}
public DecompBuilder seed(long seed) {
rng = new FastRandom(seed);
return this;
}
private double weightingFunc(double x) {
return x < xMax ? Math.pow((x / xMax), alpha) : 1;
}
@Override
public Embedding fit() {
final int size = dict().size();
final Mx symDecomp = new VecBasedMx(size, symDim);
final Mx skewsymDecomp = new VecBasedMx(size, skewDim);
final Vec bias = new ArrayVec(size);
for (int i = 0; i < size; i++) {
bias.set(i, initializeValue(symDim));
for (int j = 0; j < symDim; j++) {
symDecomp.set(i, j, initializeValue(symDim));
}
for (int j = 0; j < skewDim; j++) {
skewsymDecomp.set(i, j, initializeValue(skewDim));
}
}
final Mx softMaxSym = new VecBasedMx(symDecomp.rows(), symDecomp.columns());
final Mx softMaxSkewsym = new VecBasedMx(skewsymDecomp.rows(), skewsymDecomp.columns());
final Vec softMaxBias = new ArrayVec(bias.dim());
VecTools.fill(softMaxSym, 1);
VecTools.fill(softMaxSkewsym, 1);
VecTools.fill(softMaxBias, 1);
final TIntArrayList order = new TIntArrayList(IntStream.range(0, size).toArray());
rng = new FastRandom();
for (int iter = 0; iter < T(); iter++) {
Interval.start();
order.shuffle(rng);
final ScoreCalculator scoreCalculator = new ScoreCalculator(size);
IntStream.range(0, size).parallel().map(order::get).forEach(i -> {
final Vec sym_i = symDecomp.row(i);
final Vec skew_i = skewsymDecomp.row(i);
final Vec softMaxSym_i = softMaxSym.row(i);
final Vec softMaxSkew_i = softMaxSkewsym.row(i);
cooc(i, (j, X_ij) -> {
final Vec sym_j = symDecomp.row(j);
final Vec skew_j = skewsymDecomp.row(j);
final Vec softMaxSym_j = softMaxSym.row(j);
final Vec softMaxSkew_j = softMaxSkewsym.row(j);
final double b_i = bias.get(i);
final double b_j = bias.get(j);
double asum = VecTools.multiply(sym_i, sym_j);
double bsum = VecTools.multiply(skew_i, skew_j);
final int sign = i > j ? -1 : 1;
final double minfo = Math.log(X_ij);
final double diff = b_i + b_j + asum + sign * bsum - minfo;
final double weight = weightingFunc(X_ij);
final double biasStep = weight * diff;
scoreCalculator.adjust(i, j, weight, 0.5 * weight * MathTools.sqr(diff));
update(sym_i, softMaxSym_i, sym_j, softMaxSym_j, diff * weight);
update(skew_i, softMaxSkew_i, skew_j, softMaxSkew_j, diff * weight * sign);
bias.adjust(i, -step() * biasStep / Math.sqrt(softMaxBias.get(i)));
softMaxBias.adjust(i, biasStep * biasStep);
bias.adjust(j, -step() * biasStep / Math.sqrt(softMaxBias.get(j)));
softMaxBias.adjust(j, biasStep * biasStep);
});
});
project(skewsymDecomp);
log.info("Iteration: " + iter + ", score: " + scoreCalculator.gloveScore() + ", time: " + Interval.time());
// Interval.stopAndPrint("Iteration: " + iter + ", score: " + scoreCalculator.gloveScore());
}
final Map mapping = new HashMap<>();
for (int i = 0; i < dict().size(); i++) {
final CharSeq word = dict().get(i);
mapping.put(word, symDecomp.row(i));
}
// try (final BufferedWriter writer = Files.newBufferedWriter(Paths.get("/Users/solar/temp/skewsym.txt"))) {
// for (int i = 0; i < dict().size(); i++) {
// writer.write(dict().get(i).toString());
// writer.write('\t');
// writer.write(MathTools.CONVERSION.convert(skewsymDecomp.row(i), CharSequence.class).toString());
// writer.write('\n');
// }
// }
// catch (IOException e) {
// throw new RuntimeException(e);
// }
return new EmbeddingImpl<>(mapping);
}
private void project(Vec vec) {
final int dim = vec.dim();
for (int i = 0; i < dim; i++) {
final double v = vec.get(i);
double lambda = 1e-2;
if (v < -lambda) {
vec.set(i, v + lambda);
}
else if (v > lambda) {
vec.set(i, v - lambda);
}
else {
vec.set(i, 0.);
}
}
}
private void update(Vec x_i, Vec softMaxD_i, Vec x_j, Vec softMaxD_j, double step) {
IntStream.range(0, x_i.dim()).forEach(id -> {
final double dx_i = x_j.get(id) * step;
final double dx_j = x_i.get(id) * step;
final double maxL_i = softMaxD_i.get(id);
final double maxL_j = softMaxD_j.get(id);
x_i.adjust(id, -step() * dx_i / Math.sqrt(maxL_i));
x_j.adjust(id, -step() * dx_j / Math.sqrt(maxL_j));
softMaxD_i.set(id, maxL_i + MathTools.sqr(dx_i));
softMaxD_j.set(id, maxL_j + MathTools.sqr(dx_j));
});
}
private double initializeValue(int dim) {
return (Math.random() - 0.5) / dim;
}
}