// com.expleague.ml.embedding.impl.EmbeddingImpl Maven / Gradle / Ivy
package com.expleague.ml.embedding.impl;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.seq.CharSeq;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.embedding.Embedding;
import gnu.trove.map.TObjectIntMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class EmbeddingImpl implements Embedding {
private final Map mapping;
private final List vocab;
private final TObjectIntMap invVocab = new TObjectIntHashMap<>();
public EmbeddingImpl(Map mapping) {
this.mapping = mapping;
this.vocab = new ArrayList<>(mapping.keySet());
for (int i = 0; i < vocab.size(); i++) {
invVocab.put(vocab.get(i), i);
}
}
public boolean inVocab(T obj) {
return vocab.contains(obj);
}
public int vocabSize() {
return vocab.size();
}
public int getIndex(T obj) {
return invVocab.get(obj);
}
public T getObj(int i) {
return vocab.get(i);
}
@Override
public double distance(T a, T b) {
Vec vA = mapping.get(a);
Vec vB = mapping.get(b);
return vA == null || vB == null ? Double.POSITIVE_INFINITY : VecTools.cosine(vA, vB);
}
@Override
public Vec apply(T t) {
return mapping.get(t);
}
public void write(Writer to) {
mapping.forEach((word, vec) -> {
try {
to.append(DataTools.SERIALIZATION.write(word))
.append('\t')
.append(DataTools.SERIALIZATION.write(vec))
.append('\n');
}
catch (IOException e) {
throw new RuntimeException(e);
}
});
}
public static EmbeddingImpl read(Reader from, Class extends T> clazz) {
BufferedReader bufferedReader = new BufferedReader(from);
Map mapping = new HashMap<>();
bufferedReader.lines().forEach(line -> {
int partitionIdx = line.lastIndexOf('\t');
T word;
if (clazz.equals(CharSeq.class)) {
//noinspection unchecked
word = (T) CharSeq.intern(DataTools.SERIALIZATION.read(line.substring(0, partitionIdx), CharSequence.class));
} else {
word = DataTools.SERIALIZATION.read(line.substring(0, partitionIdx), clazz);
}
Vec vec = DataTools.SERIALIZATION.read(line.substring(partitionIdx + 1), ArrayVec.class);
mapping.put(word, vec);
});
return new EmbeddingImpl<>(mapping);
}
}