// com.expleague.ml.methods.seq.automaton.AutomatonBruteforce Maven / Gradle / Ivy
package com.expleague.ml.methods.seq.automaton;
import com.expleague.commons.func.Computable;
import com.expleague.commons.math.vectors.SingleValueVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.seq.Seq;
import com.expleague.commons.seq.regexp.Alphabet;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.loss.L2;
import com.expleague.ml.methods.SeqOptimization;
import com.expleague.ml.methods.seq.automaton.transform.AddTransitionTransform;
public class AutomatonBruteforce implements SeqOptimization {
private final Alphabet alphabet;
private final Computable, Double> stateEvaluation;
private final int maxStateCount;
public AutomatonBruteforce(final Alphabet alphabet,
final Computable, Double> stateEvaluation,
final int maxStateCount) {
this.alphabet = alphabet;
this.stateEvaluation = stateEvaluation;
this.maxStateCount = maxStateCount;
}
private AutomatonStats bruteforce(int curState, int curAlpha, final AutomatonStats automatonStats) {
if (curAlpha == alphabet.size()) {
curAlpha = 0;
curState++;
}
if (curState == maxStateCount) {
return automatonStats;
}
double optCost = Double.MAX_VALUE;
AutomatonStats optStats = automatonStats;
for (int toState = 0; toState < maxStateCount; toState++) {
final AutomatonStats newAutomatonStats = new AddTransitionTransform<>(
curState, toState, alphabet.getT(alphabet.get(curAlpha))
).applyTransform(automatonStats);
final AutomatonStats curAutomatonStats = bruteforce(curState, curAlpha + 1, newAutomatonStats);
final double curCost = stateEvaluation.compute(curAutomatonStats);
if (curCost < optCost) {
optCost = curCost;
optStats = curAutomatonStats;
}
}
return optStats;
}
@Override
public Computable, Vec> fit(DataSet> learn, Loss loss) {
AutomatonStats automatonStats = new AutomatonStats(alphabet, learn, loss);
for (int i = 1; i < maxStateCount; i++) {
automatonStats.addNewState();
}
automatonStats = bruteforce(0, 0, automatonStats);
final DFA automaton = automatonStats.getAutomaton();
final double[] stateValue = new double[automaton.getStateCount()];
for (int i = 0; i < automaton.getStateCount(); i++) {
stateValue[i] = automatonStats.getStateSum().get(i) / automatonStats.getStateWeight().get(i);
}
System.out.println("Cur cost = " + stateEvaluation.compute(automatonStats));
return argument -> new SingleValueVec(stateValue[automaton.run(argument)]);
}
}