// com.expleague.ml.methods.seq.automaton.IncrementalAutomatonBuilder (Maven / Gradle / Ivy)
package com.expleague.ml.methods.seq.automaton;
import com.expleague.commons.math.MathTools;
import com.expleague.ml.methods.seq.automaton.transform.*;
import com.expleague.commons.math.vectors.SingleValueVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.seq.Seq;
import com.expleague.commons.seq.regexp.Alphabet;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.loss.L2;
import com.expleague.ml.methods.SeqOptimization;
import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Function;
public class IncrementalAutomatonBuilder implements SeqOptimization {
private final int maxStateCount;
private final Alphabet alphabet;
private final Function, Double> stateEvaluation;
private final int maxIterations;
private final ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() - 1);
public IncrementalAutomatonBuilder(final Alphabet alphabet,
final Function, Double> stateEvaluation,
final int maxStateCount,
final int maxIterations) {
this.alphabet = alphabet;
this.stateEvaluation = stateEvaluation;
this.maxStateCount = maxStateCount;
this.maxIterations = maxIterations;
}
@Override
public Function, Vec> fit(final DataSet> learn, final Loss loss) {
AutomatonStats automatonStats = new AutomatonStats<>(alphabet, learn, loss);
double oldCost = stateEvaluation.apply(automatonStats);
for (int iter = 0; iter < maxIterations; iter++) {
final AutomatonStats automatonStats1 = automatonStats;
final List>> futures = new ArrayList<>();
for (Transform transform: getTransforms(automatonStats)) {
futures.add(executorService.submit(() -> transform.applyTransform(automatonStats1)));
/*
final AutomatonStats newAutomatonStats = transform.applyTransform(automatonStats);
final double newCost = stateEvaluation.compute(newAutomatonStats);
if (newCost < optCost) {
optCost = newCost;
optTransform = transform;
}*/
}
final AutomatonStats optNewStats = futures.stream().map(future -> {
try {
return future.get();
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
return null;
}
}).min(Comparator.comparingDouble(stateEvaluation::apply)).orElse(null);
final double optCost = stateEvaluation.apply(optNewStats);
if (optNewStats == null || (optCost >= oldCost - 1e-9)) {
System.out.println("Elapsed " + iter + " iterations");
break;
}
automatonStats = optNewStats ;
removeUnreachableStates(automatonStats);
if (iter % 100 == 0 && iter != 0) {
System.out.printf("Iter=%d, newCost=%f, state count=%d\n",
iter, optCost, automatonStats.getAutomaton().getStateCount());
/*
System.out.printf("Iter=%d, transform=%s, newCost=%f, state count=%d\n",
iter, optTransform.getDescription(), optCost, automatonStats.getAutomaton().getStateCount());
*/
System.out.flush();
}
oldCost = optCost;
}
final DFA automaton = automatonStats.getAutomaton();
final double[] stateValue = new double[automaton.getStateCount()];
for (int i = 0; i < automaton.getStateCount(); i++) {
if (automatonStats.getStateWeight().get(i) > MathTools.EPSILON) {
stateValue[i] = automatonStats.getStateSum().get(i) / automatonStats.getStateWeight().get(i);
}
}
System.out.println("Cur cost = " + stateEvaluation.apply(automatonStats));
return argument -> new SingleValueVec(stateValue[automaton.run(argument)]);
}
private List> getTransforms(final AutomatonStats automatonStats) {
final DFA automaton = automatonStats.getAutomaton();
final int stateCount = automaton.getStateCount();
final Alphabet alphabet = automatonStats.getAlphabet();
final List> transforms = new ArrayList<>();
for (int from = 0; from < stateCount; from++) {
for (int c = 0; c < alphabet.size(); c++) {
if (automaton.hasTransition(from, alphabet.getT(alphabet.condition(c)))) {
// todo commented out to improve performance
transforms.add(new RemoveTransitionTransform<>(from, alphabet.getT(alphabet.condition(c))));
for (int to = 0; to < stateCount; to++) {
if (to != from) {
// todo commented out to improve performance
transforms.add(new ReplaceTransitionTransform<>(from, to, alphabet.getT(alphabet.condition(c))));
}
}
} else {
final T cT = alphabet.getT(alphabet.condition(c));
if (stateCount < maxStateCount) {
transforms.add(new SplitStateTransform<>(from, cT));
}
for (int to = 0; to < stateCount; to++) {
transforms.add(new AddTransitionTransform<>(from, to, cT));
}
}
}
if (stateCount < maxStateCount) {
for (int to = 0; to < stateCount; to++) {
for (int c = 0; c < alphabet.size(); c++) {
final T cT = alphabet.getT(alphabet.condition(c));
if (!automaton.hasTransition(from, cT)) {
for (int c1 = 0; c1 < alphabet.size(); c1++) {
transforms.add(new AddNewStateTransform<>(from, to, cT, alphabet.getT(alphabet.condition(c1))));
}
}
}
}
}
}
return transforms;
}
private void removeUnreachableStates(final AutomatonStats automatonStats) {
final DFA automaton = automatonStats.getAutomaton();
final Queue queue = new LinkedList<>();
final Alphabet alphabet = automatonStats.getAlphabet();
queue.add(automaton.getStartState());
final boolean[] reached = new boolean[automaton.getStateCount()];
reached[automaton.getStartState()] = true;
while (!queue.isEmpty()) {
final int v = queue.poll();
for (int c = 0; c < automatonStats.getAlphabet().size(); c++) {
final int to = automaton.getTransition(v, alphabet.getT(alphabet.condition(c)));
if (to != -1 && !reached[to]) {
queue.add(to);
reached[to] = true;
}
}
}
for (int i = automaton.getStateCount() - 1; i >= 0; i--) {
if (!reached[i] && i != automaton.getStartState()) {
automaton.removeState(i);
automatonStats.getSamplesEndState().remove(i);
automatonStats.getStateWeight().remove(i);
automatonStats.getStateSum().remove(i);
automatonStats.getStateSum2().remove(i);
automatonStats.getSamplesViaState().remove(i);
}
}
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy