// com.expleague.ml.methods.seq.automaton.AutomatonStats (artifact available via Maven / Gradle / Ivy)
package com.expleague.ml.methods.seq.automaton;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.seq.Seq;
import com.expleague.commons.seq.regexp.Alphabet;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.loss.L2;
import com.expleague.ml.loss.WeightedL2;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.List;
public class AutomatonStats {
private DFA automaton;
private final Alphabet alphabet;
private final DataSet> dataSet;
private final Vec target;
private final Vec weights;
private TDoubleList stateSum = new TDoubleArrayList();
private TDoubleList stateSum2 = new TDoubleArrayList();
private TDoubleList stateWeight = new TDoubleArrayList();
private List samplesViaState = new ArrayList<>();
private List samplesEndState = new ArrayList<>();
public AutomatonStats(AutomatonStats other) {
automaton = other.automaton.copy();
alphabet = other.alphabet;
dataSet = other.dataSet;
target = other.target;
weights = other.weights;
samplesEndState = new ArrayList<>(other.samplesEndState);
samplesViaState = new ArrayList<>(other.samplesViaState);
stateSum = new TDoubleArrayList(other.stateSum);
stateSum2 = new TDoubleArrayList(other.stateSum2);
stateWeight = new TDoubleArrayList(other.stateWeight);
}
public AutomatonStats(Alphabet alphabet, DataSet> dataSet, L2 loss) {
automaton = new DFA(alphabet);
this.dataSet = dataSet;
if (loss instanceof WeightedL2) {
final WeightedL2 weightedLoss = (WeightedL2) loss;
this.weights = weightedLoss.getWeights();
this.target = weightedLoss.target();
} else {
this.target = VecTools.copy(loss.target());
this.weights = new ArrayVec(target.length());
VecTools.fill(this.weights, 1);
VecTools.scale(this.target, this.weights);
}
this.alphabet = alphabet;
final TIntSet allIndicesSet = new TIntHashSet();
final TIntIntMap allIndicesMap = new TIntIntHashMap();
for (int i = 0; i < dataSet.length(); i++) {
allIndicesSet.add(i);
allIndicesMap.put(i, 0);
}
stateWeight.add(dataSet.length());
stateSum.add(VecTools.sum(target));
stateSum2.add(VecTools.sum2(target));
samplesEndState.add(allIndicesMap);
samplesViaState.add(allIndicesSet);
samplesEndState.add(new TIntIntHashMap());
samplesViaState.add(new TIntHashSet());
}
public Alphabet getAlphabet() {
return alphabet;
}
public DFA getAutomaton() {
return automaton;
}
public List getSamplesViaState() {
return samplesViaState;
}
public List getSamplesEndState() {
return samplesEndState;
}
public DataSet> getDataSet() {
return dataSet;
}
public TDoubleList getStateSum() {
return stateSum;
}
public TDoubleList getStateSum2() {
return stateSum2;
}
public void setSamplesViaState(List samplesViaState) {
this.samplesViaState = samplesViaState;
}
public void setSamplesEndState(List samplesEndState) {
this.samplesEndState = samplesEndState;
}
public TDoubleList getStateWeight() {
return stateWeight;
}
public Vec getTarget() {
return target;
}
public void setStateSum(TDoubleList stateSum) {
this.stateSum = stateSum;
}
public void setStateSum2(TDoubleList stateSum2) {
this.stateSum2 = stateSum2;
}
public void setStateWeight(TDoubleList stateWeight) {
this.stateWeight = stateWeight;
}
public void setAutomaton(DFA automaton) {
this.automaton = automaton;
}
public Vec getWeights() {
return weights;
}
public int addNewState() {
final int newState = automaton.createNewState();
samplesViaState.add(new TIntHashSet());
samplesEndState.add(new TIntIntHashMap());
stateWeight.add(0);
stateSum.add(0);
stateSum2.add(0);
return newState;
}
}