Please wait. This can take a few minutes...
Downloading a project requires considerable server resources. Please understand that we need to cover our server costs. Thank you in advance.
Project price: only $1.
You can buy this project and download or modify it as often as you like.
com.expleague.ml.embedding.impl.EmbeddingBuilderBase Maven / Gradle / Ivy
package com.expleague.ml.embedding.impl;
import com.expleague.commons.csv.CsvRow;
import com.expleague.commons.csv.WritableCsvRow;
import com.expleague.commons.func.IntDoubleConsumer;
import com.expleague.commons.seq.CharSeq;
import com.expleague.commons.seq.CharSeqTools;
import com.expleague.commons.seq.LongSeq;
import com.expleague.commons.seq.LongSeqBuilder;
import com.expleague.ml.embedding.Embedding;
import gnu.trove.list.TLongList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.list.array.TLongArrayList;
import gnu.trove.map.TObjectIntMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.IntFunction;
import java.util.function.LongFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
public abstract class EmbeddingBuilderBase implements Embedding.Builder {
private static final Logger log = LoggerFactory.getLogger(EmbeddingBuilderBase.class.getName());
public static final int CAPACITY = 50_000_000;
private Path path;
private int minCount = 5;
private int windowLeft = 15;
private int windowRight = 15;
private Embedding.WindowType windowType = Embedding.WindowType.LINEAR;
private int iterations = 25;
private double step = 0.01;
private List wordsList = new ArrayList<>();
private TObjectIntMap wordsIndex = new TObjectIntHashMap<>(50_000, 0.6f, -1);
private boolean dictReady;
private List cooc;
private boolean coocReady = false;
@Override
public Embedding.Builder file(Path path) {
this.path = path;
return this;
}
@Override
public Embedding.Builder minWordCount(int count) {
this.minCount = count;
return this;
}
@Override
public Embedding.Builder window(Embedding.WindowType type, int left, int right) {
this.windowLeft = left;
this.windowRight = right;
this.windowType = type;
return this;
}
@Override
public Embedding.Builder step(double step) {
this.step = step;
return this;
}
@Override
public Embedding.Builder iterations(int count) {
iterations = count;
return this;
}
protected abstract Embedding fit();
protected List dict() {
return wordsList;
}
protected void cooc(int i, IntDoubleConsumer consumer) {
cooc.get(i).stream().forEach(packed ->
consumer.accept((int)(packed >>> 32), Float.intBitsToFloat((int)(packed & 0xFFFFFFFFL)))
);
}
protected int index(CharSequence word) {
return wordsIndex.get(CharSeq.create(word));
}
protected LongSeq cooc(int i) {
return cooc.get(i);
}
protected synchronized void cooc(int i, LongSeq set) {
if (i > cooc.size()) {
for (int k = cooc.size(); k <= i; k++) {
cooc.add(new LongSeq());
}
}
cooc.set(i, set);
}
protected int T() {
return iterations;
}
protected double step() {
return step;
}
@Override
public Embedding build() {
try {
log.info("==== Dictionary phase ====");
long time = System.nanoTime();
acquireDictionary();
log.info("==== " + TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - time) + "s ====");
log.info("==== Cooccurrences phase ====");
time = System.nanoTime();
acquireCooccurrences();
log.info("==== " + TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - time) + "s ====");
log.info("==== Training phase ====");
time = System.nanoTime();
final Embedding result = fit();
log.info("==== " + TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - time) + "s ====");
return result;
}
catch (Exception e) {
if (e instanceof RuntimeException)
throw (RuntimeException) e;
else
throw new RuntimeException(e);
}
}
private void acquireCooccurrences() throws IOException {
final Path coocPath = Paths.get(this.path.getParent().toString(), strip(this.path.getFileName()) + "." + windowType.name().toLowerCase() + "-" + windowLeft + "-" + windowRight + "-" + minCount + ".cooc");
try {
final LongSeq[] cooc = new LongSeq[wordsList.size()];
Reader coocReader = readExisting(coocPath);
if (coocReader != null) {
log.info("Reading existing cooccurrences");
CharSeqTools.llines(coocReader, true).forEach(line -> {
final LongSeqBuilder values = new LongSeqBuilder(wordsList.size());
final CharSeq[] wordWeightPair = new CharSeq[3];
CharSeqTools.split(line.line, " ", false)
.skip(1)
.map(part -> CharSeqTools.split(part, ':', wordWeightPair))
.forEach(split -> values.add(((long)CharSeqTools.parseInt(split[0])) << 32 | Float.floatToIntBits(CharSeqTools.parseFloat(split[2]))));
cooc[line.number] = values.build();
});
this.cooc = new ArrayList<>(Arrays.asList(cooc));
coocReady = true;
}
}
catch (IOException ioe) {
log.warn("Unable to read : " + coocPath, ioe);
}
if (!coocReady) {
log.info("Generating cooccurrences for " + this.path);
final long startTime = System.nanoTime();
final Lock[] rowLocks = IntStream.range(0, wordsList.size()).mapToObj(i -> new ReentrantLock()).toArray(Lock[]::new);
final List accumulators = new ArrayList<>();
cooc = IntStream.range(0, wordsList.size()).mapToObj(i -> LongSeq.empty()).collect(Collectors.toList());
final CharSeq newLine = CharSeq.create("777newline777");
wordsIndex.put(newLine, Integer.MAX_VALUE);
source().peek(new Consumer() {
long line = 0;
long time = System.nanoTime();
@Override
public synchronized void accept(CharSeq l) {
if ((++line) % 10000 == 0) {
log.info(line + " lines processed for " + TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - time) + "s");
time = System.nanoTime();
}
}
}).map(line -> (CharSeq)CharSeqTools.concat(line, " ", newLine)).flatMap(CharSeqTools::words).map(this::normalize).mapToInt(wordsIndex::get).filter(idx -> idx >= 0).mapToObj(new IntFunction() {
final TIntArrayList queue = new TIntArrayList(1000);
int offset = 0;
@Override
public synchronized LongStream apply(int idx) {
if (idx == Integer.MAX_VALUE) { // new line
queue.resetQuick();
offset = 0;
return LongStream.empty();
}
int pos = queue.size();
final long[] out = new long[windowLeft + windowRight];
int outIndex = 0;
for (int i = offset; i < pos; i++) {
byte distance = (byte)(pos - i);
if (distance == 0) {
log.warn("Zero distance occured! pos: " + pos + " i: " + i);
System.err.println("Zero distance occured! pos: " + pos + " i: " + i);
}
if (distance <= windowRight)
out[outIndex++] = pack(queue.getQuick(i), idx, distance);
if (distance <= windowLeft)
out[outIndex++] = pack(idx, queue.getQuick(i), (byte)-distance);
}
queue.add(idx);
if (queue.size() > Math.max(windowLeft, windowRight)) {
offset++;
if (offset > 1000 - Math.max(windowLeft, windowRight)) {
queue.remove(0, offset);
offset = 0;
}
}
return Arrays.stream(out, 0, outIndex);
}
}).flatMapToLong(entries -> entries).parallel()/*.peek(p -> {
System.out.println(dict().get(unpackA(p)) + "->" + dict().get(unpackB(p)) + "=" + unpackDist(p));
})*/.mapToObj(new LongFunction() {
volatile TLongList accumulator;
@Override
public TLongList apply(long value) {
if (accumulator == null || accumulator.size() >= CAPACITY) {
synchronized (this) {
if (accumulator == null || accumulator.size() >= CAPACITY) {
final TLongList accumulator = this.accumulator;
accumulators.add(this.accumulator = new TLongArrayList(CAPACITY));
return accumulator;
}
}
}
accumulator.add(value);
return null;
}
}).filter(Objects::nonNull).peek(accumulators::remove).peek(TLongList::sort).forEach(acc -> merge(rowLocks, (TLongArrayList)acc));
accumulators.parallelStream().peek(TLongList::sort).forEach(acc -> merge(rowLocks, (TLongArrayList)acc));
log.info("Generated for " + TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - startTime) + "s");
wordsIndex.remove(newLine);
final Path coocOut = Paths.get(coocPath.toString() + ".gz");
try (Writer coocWriter = new OutputStreamWriter(new GZIPOutputStream(Files.newOutputStream(coocOut)))) {
log.info("Writing cooccurrences to: " + coocOut);
for (int i = 0; i < this.cooc.size(); i++) {
final LongSeq row = this.cooc.get(i);
final StringBuilder builder = new StringBuilder();
builder.append(dict().get(i)).append('\t');
row.stream().forEach(packed -> {
final int wordId = (int)(packed >>> 32);
builder.append(wordId).append(':').append(dict().get(wordId)).append(':').append(CharSeqTools.ppDouble(Float.intBitsToFloat(((int)(packed & 0xFFFFFFFFL))))).append(' ');
});
coocWriter.append(builder, 0, builder.length() - 1).append('\n');
}
}
catch (IOException ioe) {
log.warn("Unable to write dictionary to " + coocOut, ioe);
}
coocReady = true;
}
}
private void merge(Lock[] rowLocks, TLongArrayList acc) {
final int size = acc.size();
final float[] weights = new float[256];
IntStream.range(0, 256).forEach(i -> weights[i] = (float)windowType.weight(i > 126 ? -256 + i : i));
LongSeq prevRow = null;
final LongSeqBuilder updatedRow = new LongSeqBuilder(wordsList.size());
int prevA = -1;
int pos = 0; // insertion point
int prevLength = 0;
try {
for (int i = 0; i < size; i++) {
long next = acc.getQuick(i);
final long currentPairMasked = next & 0xFFFFFFFFFFFFFF00L;
final int a = unpackA(next);
final int b = unpackB(next);
float weight = weights[unpackDist(next)];
while (++i < size && ((next = acc.getQuick(i)) & 0xFFFFFFFFFFFFFF00L) == currentPairMasked) {
weight += weights[unpackDist(next)];
}
if (i < size)
i--;
if (a != prevA) {
if (prevA >= 0) {
updatedRow.addAll(prevRow.sub(pos, prevLength));
cooc.set(prevA, updatedRow.build(prevRow.data(), 0.2, 100));
rowLocks[prevA].unlock();
}
prevA = a;
prevRow = cooc.get(a);
prevLength = prevRow.length();
pos = 0;
rowLocks[a].lock();
}
long prevPacked;
while (pos < prevLength) { // merging previous version of the cooc row with current data
prevPacked = prevRow.longAt(pos);
int prevB = (int)(prevPacked >>> 32);
if (prevB >= b) {
if (prevB == b) { // second entry matches with the merged one
weight += Float.intBitsToFloat((int) (prevPacked & 0xFFFFFFFFL));
pos++;
}
break;
}
updatedRow.append(prevPacked);
pos++;
}
final long repacked = (((long)b) << 32) | Float.floatToIntBits(weight);
updatedRow.append(repacked);
}
//noinspection ConstantConditions
updatedRow.addAll(prevRow.sub(pos, prevLength));
cooc.set(prevA, updatedRow.build(prevRow.data(), 0.2, 100));
}
finally {
rowLocks[prevA].unlock();
}
}
private int unpackA(long next) {
return (int)(next >>> 36);
}
private int unpackB(long next) {
return ((int)(next >>> 8)) & 0x0FFFFFFF;
}
private int unpackDist(long next) {
return (int)(0xFF & next);
}
private long pack(long a, long b, byte dist) {
return (a << 36) | (b << 8) | ((long) dist & 0xFF);
}
private void acquireDictionary() throws IOException {
final Path dictPath = Paths.get(this.path.getParent().toString(), strip(this.path.getFileName()) + ".dict");
try {
Reader dictReader = readExisting(dictPath);
if (dictReader != null) {
log.info("Reading existing dictionary");
try (Stream dictStream = CsvRow.read(dictReader)) {
dictStream.filter(row -> row.asInt("freq") >= minCount).map(row -> CharSeq.intern(row.at("word"))).forEach(word -> {
wordsIndex.put(word, wordsList.size());
wordsList.add(word);
});
}
dictReady = true;
}
}
catch (IOException ioe) {
log.warn("Unable to read dictionary: " + dictPath, ioe);
}
if (!dictReady) {
log.info("Generating dictionary for " + this.path);
TObjectIntMap wordsCount = new TObjectIntHashMap<>();
source().flatMap(CharSeqTools::words).filter(word -> word.stream().anyMatch(Character::isLetter)).map(this::normalize).forEach(w ->
wordsCount.adjustOrPutValue(w, 1, 1)
);
final Path dictOut = Paths.get(dictPath.toString() + ".gz");
final Supplier factory = CsvRow.factory("word", "freq");
final List words = new ArrayList<>(wordsCount.keySet());
words.sort(Comparator.comparingInt(wordsCount::get).reversed());
try (Writer dictWriter = new OutputStreamWriter(new GZIPOutputStream(Files.newOutputStream(dictOut)))) {
log.info("Writing dictionary to: " + dictOut);
dictWriter.append(factory.get().names().toString()).append('\n');
words.forEach(word ->
factory.get().set("word", word).set("freq", wordsCount.get(word)).writeln(dictWriter)
);
}
catch (IOException ioe) {
log.warn("Unable to write dictionary to " + dictOut, ioe);
}
words.forEach(word -> {
if (wordsCount.get(word) >= minCount) {
wordsIndex.put(word, wordsList.size());
wordsList.add(word);
}
});
dictReady = true;
}
}
private CharSeq normalize(CharSeq word) {
final int initialLength = word.length();
int len = initialLength;
int st = 0;
while ((st < len) && !Character.isLetterOrDigit(word.charAt(st))) {
st++;
}
while ((st < len) && !Character.isLetterOrDigit(word.charAt(len - 1))) {
len--;
}
word = ((st > 0) || (len < initialLength)) ? word.subSequence(st, len) : word;
return (CharSeq)CharSeqTools.toLowerCase(word);
}
protected Stream source() throws IOException {
if (path.getFileName().toString().endsWith(".gz"))
return CharSeqTools.lines(new InputStreamReader(new GZIPInputStream(Files.newInputStream(path)), StandardCharsets.UTF_8));
return CharSeqTools.lines(Files.newBufferedReader(path));
}
protected String strip(Path fileName) {
final String name = fileName.toString();
if (name.endsWith(".gz"))
return name.substring(0, name.length() - ".gz".length());
return name;
}
protected float unpackWeight(LongSeq cooc, int v) {
return Float.intBitsToFloat((int) (cooc.longAt(v) & 0xFFFFFFFFL));
}
protected int unpackB(LongSeq cooc, int v) {
return (int) (cooc.longAt(v) >>> 32);
}
@Nullable
private Reader readExisting(Path path) throws IOException {
if (Files.exists(path))
return Files.newBufferedReader(path);
else if (Files.exists(Paths.get(path.toString() + ".gz")))
return new InputStreamReader(new GZIPInputStream(Files.newInputStream(Paths.get(path.toString() + ".gz"))), StandardCharsets.UTF_8);
return null;
}
protected class ScoreCalculator {
private final double[] scores;
private final double[] weights;
private final long[] counts;
public ScoreCalculator(int dim) {
counts = new long[dim];
scores = new double[dim];
weights = new double[dim];
}
public void adjust(int i, int j, double weight, double value) {
weights[i] += weight;
scores[i] += value;
counts[i] ++;
}
public double gloveScore() {
return Arrays.stream(scores).sum() / Arrays.stream(counts).sum();
}
public long count() {
return Arrays.stream(counts).sum();
}
}
}