// com.expleague.ml.methods.seq.DictExpansionOptimization Maven / Gradle / Ivy
package com.expleague.ml.methods.seq;
import com.expleague.commons.io.codec.seq.DictExpansion;
import com.expleague.commons.io.codec.seq.Dictionary;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.seq.Seq;
import com.expleague.ml.TargetFunc;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.methods.SeqOptimization;
import com.fasterxml.jackson.annotation.JsonIgnore;
import java.io.PrintStream;
import java.util.Collection;
import java.util.HashSet;
import java.util.function.Function;
public class DictExpansionOptimization, Loss extends TargetFunc>
implements SeqOptimization {
private final SeqOptimization optimization;
private final int maxAlphabetSize;
private Collection alphabet;
private PrintStream tracePrint;
public DictExpansionOptimization(
SeqOptimization optimization,
int maxAlphabetSize,
Collection alphabet,
PrintStream tracePrint
) {
this.optimization = optimization;
this.maxAlphabetSize = maxAlphabetSize;
this.alphabet = alphabet;
this.tracePrint = tracePrint;
}
@Override
public Function, Vec> fit(DataSet> learn, Loss loss) {
final long startTime = System.nanoTime();
Collection realAlphabet = alphabet;
if (alphabet == null) {
realAlphabet = new HashSet<>();
for (Seq seq: learn) {
for (T t: seq) {
realAlphabet.add(t);
}
}
}
final DictExpansion de = new DictExpansion<>(realAlphabet, maxAlphabetSize, tracePrint);
for (int iter = 0; iter < 1000; iter++) {
for (Seq seq: learn) {
de.accept(seq);
if (de.result() != null && de.result().size() == maxAlphabetSize) {
break;
}
}
}
final Dictionary dict = de.result();
if (dict.size() < maxAlphabetSize) {
throw new IllegalStateException("Cannot build alphabet of size " + maxAlphabetSize + ". Actual size is " + dict.size());
}
if (tracePrint != null) {
tracePrint.println("Time to build dictexpansion: " + (System.nanoTime() - startTime) / 1e9 + "s");
}
final DataSet> expandedLearn = new DataSet.Stub>(null) {
@Override
public Seq at(int i) {
return dict.parse(learn.at(i));
}
@Override
public int length() {
return learn.length();
}
@Override
public Class extends Seq> elementType() {
return null;
}
};
Function, Vec> model = optimization.fit(expandedLearn, loss);
return new DictExpansionModel<>(model, dict);
}
static class DictExpansionModel> implements Function, Vec> {
private Function, Vec> model;
@JsonIgnore
private Dictionary dict;
public DictExpansionModel(Function, Vec> model, Dictionary dict) {
this.model = model;
this.dict = dict;
}
public Function, Vec> getModel() {
return model;
}
public void setModel(Function, Vec> model) {
this.model = model;
}
public Dictionary getDict() {
return dict;
}
public void setDict(Dictionary dict) {
this.dict = dict;
}
@Override
public Vec apply(Seq seq) {
return model.apply(dict.parse(seq));
}
}
}