All downloads are free. The search and download functionality uses the official Maven repository.
Please wait. This can take some minutes ...
Downloading a project requires many resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download or modify it as often as you want.
com.expleague.ml.embedding.decomp.MultiDecompBuilder Maven / Gradle / Ivy
package com.expleague.ml.embedding.decomp;
import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.seq.CharSeq;
import com.expleague.commons.seq.LongSeq;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.util.MultiMap;
import com.expleague.commons.util.logging.Interval;
import com.expleague.ml.embedding.Embedding;
import com.expleague.ml.embedding.impl.EmbeddingBuilderBase;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.list.array.TLongArrayList;
import gnu.trove.procedure.TLongProcedure;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
/**
 * Embedding builder that learns, for every dictionary entry, a "symmetric" vector, a
 * "skew-symmetric" vector and a scalar bias. In {@code fit()} the model
 * {@code b_i + b_j + <sym_i, sym_j> + sign * <skew_i, skew_j>} is fitted to the logarithm of
 * the co-occurrence value, where {@code sign} is -1 when {@code i > j} and 1 otherwise.
 */
public class MultiDecompBuilder extends EmbeddingBuilderBase {
private static final Logger log = LoggerFactory.getLogger(MultiDecompBuilder.class);
// Co-occurrence value at or above which weightingFunc returns 1.
private double xMax = 10;
// Exponent applied to (x / xMax) by weightingFunc for values below xMax.
private double alpha = 0.75;
// Dimension of the symmetric-part vectors created in fit().
private int symDim = 50;
// Dimension of the skew-symmetric-part vectors created in fit().
private int skewDim = 10;
// Vectors whose norm is below this value are skipped by the clustering code in fit().
private final double minimumNorm = 1;
// Random source used to shuffle the processing order in fit(); seed(long) replaces it.
private FastRandom rng = new FastRandom();
/**
 * Sets the saturation point of the weighting function: co-occurrence values at or above
 * this threshold receive weight 1.
 *
 * @param xMax new saturation threshold
 * @return this builder, for call chaining
 */
public MultiDecompBuilder xMax(int xMax) {
this.xMax = xMax;
return this;
}
/**
 * Sets the exponent the weighting function applies to {@code x / xMax} for values below
 * the saturation threshold.
 *
 * @param alpha new exponent
 * @return this builder, for call chaining
 */
public MultiDecompBuilder alpha(double alpha) {
this.alpha = alpha;
return this;
}
/**
 * Sets the dimension of the symmetric-part vectors allocated by {@code fit()}.
 *
 * @param dim number of components per symmetric vector
 * @return this builder, for call chaining
 */
public MultiDecompBuilder dimSym(int dim) {
this.symDim = dim;
return this;
}
/**
 * Sets the dimension of the skew-symmetric-part vectors allocated by {@code fit()}.
 *
 * @param dim number of components per skew-symmetric vector
 * @return this builder, for call chaining
 */
public MultiDecompBuilder dimSkew(int dim) {
this.skewDim = dim;
return this;
}
/**
 * Replaces this builder's random source with a new {@code FastRandom} constructed from
 * the given seed.
 *
 * @param seed seed passed to the {@code FastRandom} constructor
 * @return this builder, for call chaining
 */
public MultiDecompBuilder seed(long seed) {
rng = new FastRandom(seed);
return this;
}
/**
 * Weight given to a co-occurrence value in the training loss.
 *
 * @param x co-occurrence value
 * @return {@code (x / xMax) ^ alpha} while {@code x} is below {@code xMax}, otherwise 1
 */
private double weightingFunc(double x) {
    if (x < xMax) {
        final double ratio = x / xMax;
        return Math.pow(ratio, alpha);
    }
    return 1;
}
/**
 * Trains the decomposition.
 *
 * <p>For every dictionary entry a symmetric vector, a skew-symmetric vector and a bias are
 * initialised with small random values, then {@code T()} passes are made over all entries in
 * shuffled order. For each co-occurrence {@code (i, j, X_ij)} the residual
 * {@code b_i + b_j + <sym_i, sym_j> + sign * <skew_i, skew_j> - log(X_ij)} is computed and the
 * parameters of both entries are moved against its weighted gradient, each coordinate scaled
 * by the inverse square root of its accumulated squared gradients.
 *
 * <p>Entries are processed through a parallel stream and this method takes no locks around
 * the shared vectors and biases.
 *
 * <p>From the second pass on, entries with between 1000 and 10000 co-occurrences also get a
 * neighbour-clustering diagnostic (see {@link #reportNeighbourClusters}).
 *
 * @return currently always {@code null}; the word-to-vector mapping is assembled but not
 *     wrapped into an {@code Embedding}
 */
@Override
public Embedding fit() {
    final int size = dict().size();
    final List<Vec> symDecomp = new ArrayList<>(size);
    final List<Vec> softMaxSym = new ArrayList<>(size);
    final List<Vec> skewsymDecomp = new ArrayList<>(size);
    final List<Vec> softMaxSkewsym = new ArrayList<>(size);
    final TDoubleArrayList bias = new TDoubleArrayList(size);
    final TDoubleArrayList softMaxBias = new TDoubleArrayList(size);
    for (int i = 0; i < size; i++) {
        symDecomp.add(new ArrayVec(IntStream.range(0, symDim).mapToDouble(d -> initializeValue(symDim)).toArray()));
        softMaxSym.add(VecTools.fill(new ArrayVec(symDim), 1.));
        skewsymDecomp.add(new ArrayVec(IntStream.range(0, skewDim).mapToDouble(d -> initializeValue(skewDim)).toArray()));
        softMaxSkewsym.add(VecTools.fill(new ArrayVec(skewDim), 1.));
        bias.add(initializeValue(symDim));
        softMaxBias.add(1);
    }

    final TIntArrayList order = new TIntArrayList(IntStream.range(0, size).toArray());
    // Per-thread scratch buffer for the clustering diagnostic; it is reset before each use.
    final ThreadLocal<TLongArrayList> validPairsHolder = ThreadLocal.withInitial(TLongArrayList::new);
    // The random source is deliberately NOT re-created here, so a seed supplied through
    // seed(long) keeps controlling the shuffle order.
    for (int iter = 0; iter < T(); iter++) {
        Interval.start();
        order.shuffle(rng);
        final ScoreCalculator scoreCalculator = new ScoreCalculator(size);
        final boolean afterFirstPass = iter > 0;
        IntStream.range(0, size).parallel().map(order::get).forEach(i -> {
            final Vec sym_i = symDecomp.get(i);
            final Vec skew_i = skewsymDecomp.get(i);
            final Vec softMaxSym_i = softMaxSym.get(i);
            final Vec softMaxSkew_i = softMaxSkewsym.get(i);

            final LongSeq cooc = cooc(i);
            if (cooc.length() < 10000 && cooc.length() > 1000 && afterFirstPass
                    && VecTools.norm(sym_i) > minimumNorm) {
                reportNeighbourClusters(i, cooc, symDecomp, validPairsHolder.get());
            }

            cooc(i, (j, X_ij) -> {
                final Vec sym_j = symDecomp.get(j);
                final Vec skew_j = skewsymDecomp.get(j);
                final Vec softMaxSym_j = softMaxSym.get(j);
                final Vec softMaxSkew_j = softMaxSkewsym.get(j);
                final double b_i = bias.get(i);
                final double b_j = bias.get(j);

                final double asum = VecTools.multiply(sym_i, sym_j);
                final double bsum = VecTools.multiply(skew_i, skew_j);
                // The skew-symmetric term changes sign when the pair is seen in the other order.
                final int sign = i > j ? -1 : 1;
                final double minfo = Math.log(X_ij);
                final double diff = b_i + b_j + asum + sign * bsum - minfo;
                final double weight = weightingFunc(X_ij);
                final double biasStep = weight * diff;
                scoreCalculator.adjust(i, j, weight, 0.5 * weight * MathTools.sqr(diff));

                update(sym_i, softMaxSym_i, sym_j, softMaxSym_j, biasStep);
                update(skew_i, softMaxSkew_i, skew_j, softMaxSkew_j, biasStep * sign);

                bias.setQuick(i, b_i - step() * biasStep / Math.sqrt(softMaxBias.get(i)));
                // Accumulate into i's own slot (previously this read j's accumulator).
                softMaxBias.setQuick(i, softMaxBias.getQuick(i) + biasStep * biasStep);
                bias.setQuick(j, b_j - step() * biasStep / Math.sqrt(softMaxBias.get(j)));
                softMaxBias.setQuick(j, softMaxBias.getQuick(j) + biasStep * biasStep);
            });
        });
        Interval.stopAndPrint("Iteration: " + iter + " Score: " + scoreCalculator.gloveScore());
    }

    final MultiMap<CharSeq, Vec> mapping = new MultiMap<>();
    for (int i = 0; i < dict().size(); i++) {
        final CharSeq word = dict().get(i);
        mapping.put(word, symDecomp.get(i));
    }
    // NOTE(review): the mapping above is built but never returned. Wrapping it into an
    // Embedding needs an implementation class that is not visible from this file, so the
    // original null result is preserved.
    return null;
}

/**
 * Diagnostic helper: greedily groups the co-occurrence neighbours of entry {@code i} by the
 * cosine similarity of their symmetric vectors and prints the groups to standard output when
 * the entry looks ambiguous. It does not modify any model parameters.
 *
 * <p>All indices used internally ({@code u}, {@code v}, {@code counters}, packed pairs) are
 * positions inside {@code cooc}; they are translated to dictionary indices with
 * {@code unpackB} whenever a vector or a word is looked up.
 *
 * @param i dictionary index of the entry being analysed
 * @param cooc co-occurrence row of entry {@code i}
 * @param symDecomp symmetric vectors of all dictionary entries
 * @param validPairs scratch list; cleared here and filled with packed position pairs
 */
private void reportNeighbourClusters(int i, LongSeq cooc, List<Vec> symDecomp, TLongArrayList validPairs) {
    final double qualityThreshold = 0.7;
    final int n = cooc.length();
    validPairs.reset();

    // Look up every neighbour's vector and norm once instead of once per pair.
    final Vec[] neighbours = new Vec[n];
    final double[] norms = new double[n];
    // counters[u] stays NaN for neighbours that are excluded or already assigned to a cluster.
    final double[] counters = new double[n];
    Arrays.fill(counters, Double.NaN);
    for (int u = 0; u < n; u++) {
        neighbours[u] = symDecomp.get(unpackB(cooc, u));
        norms[u] = VecTools.norm(neighbours[u]);
        if (norms[u] < minimumNorm) {
            continue;
        }
        counters[u] = unpackWeight(cooc, u);
    }

    // Record every sufficiently similar pair in both directions as ((u + 1) << 32) | (v + 1).
    for (int u = 0; u < n; u++) {
        if (norms[u] < minimumNorm) {
            continue;
        }
        for (int v = u + 1; v < n; v++) {
            if (norms[v] < minimumNorm) {
                continue;
            }
            final double cosine = VecTools.multiply(neighbours[u], neighbours[v]) / norms[u] / norms[v];
            if (cosine > qualityThreshold) {
                validPairs.add(((long) (u + 1) << 32) | (v + 1));
                validPairs.add(((long) (v + 1) << 32) | (u + 1));
            }
        }
    }
    validPairs.sort();

    // Each neighbour's score is its own weight plus the weights of everything similar to it.
    validPairs.forEach(p -> {
        final int u = (int) (p >>> 32) - 1;
        final int v = (int) (p & 0xFFFFFFFFL) - 1;
        counters[u] += unpackWeight(cooc, v);
        return true;
    });

    final List<List<String>> wordClusters = new ArrayList<>();
    while (true) {
        final int max = ArrayTools.max(counters);
        if (max < 0) {
            break;
        }
        counters[max] = Double.NaN;
        final TIntHashSet cluster = new TIntHashSet();
        final List<String> wordsCluster = new ArrayList<>();
        cluster.add(max);
        wordsCluster.add(dict().get(unpackB(cooc, max)).toString());

        // The probe key has zero low bits and is therefore never present, so binarySearch
        // returns -(insertionPoint) - 1: the start of the run of pairs whose first member is max.
        int index = -validPairs.binarySearch((long) (max + 1) << 32) - 1;
        final long limit = (long) (max + 2) << 32;
        long p;
        while (index < validPairs.size() && (p = validPairs.getQuick(index)) < limit) {
            final int v = (int) (p & 0xFFFFFFFFL) - 1;
            if (!Double.isNaN(counters[v])) {
                counters[v] = Double.NaN;
                cluster.add(v);
                wordsCluster.add(dict().get(unpackB(cooc, v)).toString());
            }
            index++;
        }

        // Remove the contribution of the newly clustered neighbours from the remaining scores.
        validPairs.forEach(new TLongProcedure() {
            // -1 matches no position, so the weight is computed for the very first run too.
            int current = -1;
            float currentWeight;

            @Override
            public boolean execute(long pair) {
                final int u = (int) (pair >>> 32) - 1;
                final int v = (int) (pair & 0xFFFFFFFFL) - 1;
                if (u != current) {
                    current = u;
                    currentWeight = cluster.contains(u) ? unpackWeight(cooc, u) : 0.f;
                }
                if (currentWeight != 0f) {
                    counters[v] -= currentWeight;
                }
                return true;
            }
        });

        if (cluster.size() == 1) {
            continue;
        }
        wordClusters.add(wordsCluster);
    }

    final CharSeq word = dict().get(i);
    final boolean twoComparableClusters = wordClusters.size() > 1
            && wordClusters.get(0).size() / (double) wordClusters.get(1).size() < 10
            && wordClusters.get(1).size() > 10;
    if (word.equals("apple") || twoComparableClusters) {
        final StringBuilder builder = new StringBuilder();
        builder.append(word).append('\n');
        for (List<String> cluster : wordClusters) {
            builder.append('\t').append(cluster.size()).append('\t');
            for (int j = 0; j < cluster.size() && j < 10; j++) {
                builder.append(cluster.get(j)).append(',');
            }
            builder.append('\n');
        }
        System.out.println(builder);
    }
}
/**
 * Applies one scaled gradient step to a pair of vectors, coordinate by coordinate.
 *
 * <p>For every coordinate the gradient of one vector is the other vector's value times
 * {@code step}. Each vector is moved by {@code -step() * gradient / sqrt(accumulator)} and
 * the squared gradient is then added to that vector's accumulator. All four values of a
 * coordinate are read before any of them is written.
 *
 * @param x_i first vector, modified in place
 * @param softMaxD_i per-coordinate accumulator for {@code x_i}, modified in place
 * @param x_j second vector, modified in place
 * @param softMaxD_j per-coordinate accumulator for {@code x_j}, modified in place
 * @param step multiplier applied to the partner's coordinate to obtain the gradient
 */
private void update(Vec x_i, Vec softMaxD_i, Vec x_j, Vec softMaxD_j, double step) {
    final int dim = x_i.dim();
    for (int k = 0; k < dim; k++) {
        final double gradI = x_j.get(k) * step;
        final double gradJ = x_i.get(k) * step;
        final double accI = softMaxD_i.get(k);
        final double accJ = softMaxD_j.get(k);

        x_i.adjust(k, -step() * gradI / Math.sqrt(accI));
        x_j.adjust(k, -step() * gradJ / Math.sqrt(accJ));

        softMaxD_i.set(k, accI + MathTools.sqr(gradI));
        softMaxD_j.set(k, accJ + MathTools.sqr(gradJ));
    }
}
/**
 * Duplicates dictionary entry {@code i} under a new index and redistributes its
 * co-occurrence data between the old and the new entry. Not called from anywhere in the
 * visible part of this file.
 *
 * @param i dictionary index of the word to split
 * @param indices selection handed to the new entry; it is used both as the argument of
 *     {@code line.sub(...)} and as a set of values compared with the upper 32 bits of packed
 *     entries — NOTE(review): confirm whether these are row positions or word indices
 */
private synchronized void split(int i, int[] indices) {
CharSeq word = dict().get(i);
log.info("Splitting word: " + word);
// The duplicate gets the next free dictionary index.
final int newIndex = dict().size();
final LongSeq line = cooc(i);
final TIntSet removeSet = new TIntHashSet(indices);
// Presumably stores the selected part of the row as the row of the new entry — cooc(int, LongSeq) is defined elsewhere.
cooc(newIndex, line.sub(indices));
dict().add(word);
// NOTE(review): this filter KEEPS the entries whose upper 32 bits are in removeSet; given the
// set's name a negated test may have been intended — confirm before relying on split().
cooc(i, new LongSeq(line.stream().filter(pack -> removeSet.contains((int) (pack >>> 32))).toArray()));
// In the rows of the selected entries, redirect references to i towards the new index.
for (int index : indices) {
LongSeq cooc = cooc(index);
for (int j = 0; j < cooc.length(); j++) {
long entry = cooc.longAt(j);
if (entry >>> 32 == i) {
// Keep the lower 32 bits and overwrite the upper 32 bits with newIndex, in place.
cooc.data()[j] = (entry & 0x00000000FFFFFFFFL) | ((long)newIndex << 32);
}
}
}
}
/**
 * Draws a small random starting value for one vector coordinate or bias.
 *
 * <p>The value comes from this builder's own random source rather than the global
 * {@code Math.random()}, so a seed supplied through {@code seed(long)} also governs
 * parameter initialisation.
 *
 * @param dim dimension the value is scaled by
 * @return a value in {@code [-0.5 / dim, 0.5 / dim)}
 */
private double initializeValue(int dim) {
    return (rng.nextDouble() - 0.5) / dim;
}
}