All downloads are free. The search and download features use the official Maven repository.

com.expleague.ml.methods.seq.automaton.AutomatonBruteforce Maven / Gradle / Ivy

package com.expleague.ml.methods.seq.automaton;

import com.expleague.commons.func.Computable;
import com.expleague.commons.math.vectors.SingleValueVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.seq.Seq;
import com.expleague.commons.seq.regexp.Alphabet;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.loss.L2;
import com.expleague.ml.methods.SeqOptimization;
import com.expleague.ml.methods.seq.automaton.transform.AddTransitionTransform;

public class AutomatonBruteforce implements SeqOptimization {
  private final Alphabet alphabet;
  private final Computable, Double> stateEvaluation;
  private final int maxStateCount;

  public AutomatonBruteforce(final Alphabet alphabet,
                             final Computable, Double> stateEvaluation,
                             final int maxStateCount) {
    this.alphabet = alphabet;
    this.stateEvaluation = stateEvaluation;
    this.maxStateCount = maxStateCount;
  }

  private AutomatonStats bruteforce(int curState, int curAlpha, final AutomatonStats automatonStats) {
    if (curAlpha == alphabet.size()) {
      curAlpha = 0;
      curState++;
    }

    if (curState == maxStateCount) {
      return automatonStats;
    }

    double optCost = Double.MAX_VALUE;
    AutomatonStats optStats = automatonStats;

    for (int toState = 0; toState < maxStateCount; toState++) {
      final AutomatonStats newAutomatonStats = new AddTransitionTransform<>(
              curState, toState, alphabet.getT(alphabet.get(curAlpha))
      ).applyTransform(automatonStats);

      final AutomatonStats curAutomatonStats = bruteforce(curState, curAlpha + 1, newAutomatonStats);
      final double curCost = stateEvaluation.compute(curAutomatonStats);

      if (curCost < optCost) {
        optCost = curCost;
        optStats = curAutomatonStats;
      }
    }

    return optStats;
  }

  @Override
  public Computable, Vec> fit(DataSet> learn, Loss loss) {
    AutomatonStats automatonStats = new AutomatonStats(alphabet, learn, loss);
    for (int i = 1; i < maxStateCount; i++) {
      automatonStats.addNewState();
    }

    automatonStats = bruteforce(0, 0, automatonStats);
    final DFA automaton = automatonStats.getAutomaton();

    final double[] stateValue = new double[automaton.getStateCount()];
    for (int i = 0; i < automaton.getStateCount(); i++) {
      stateValue[i] = automatonStats.getStateSum().get(i) / automatonStats.getStateWeight().get(i);
    }
    System.out.println("Cur cost = " + stateEvaluation.compute(automatonStats));
    return argument -> new SingleValueVec(stateValue[automaton.run(argument)]);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy