All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.seq.DictExpansionOptimization Maven / Gradle / Ivy

package com.expleague.ml.methods.seq;

import com.expleague.commons.io.codec.seq.DictExpansion;
import com.expleague.commons.io.codec.seq.Dictionary;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.seq.Seq;
import com.expleague.ml.TargetFunc;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.methods.SeqOptimization;
import com.fasterxml.jackson.annotation.JsonIgnore;

import java.io.PrintStream;
import java.util.Collection;
import java.util.HashSet;
import java.util.function.Function;

public class DictExpansionOptimization, Loss extends TargetFunc>
    implements SeqOptimization {

  private final SeqOptimization optimization;
  private final int maxAlphabetSize;
  private Collection alphabet;
  private PrintStream tracePrint;

  public DictExpansionOptimization(
      SeqOptimization optimization,
      int maxAlphabetSize,
      Collection alphabet,
      PrintStream tracePrint
  ) {
    this.optimization = optimization;
    this.maxAlphabetSize = maxAlphabetSize;
    this.alphabet = alphabet;
    this.tracePrint = tracePrint;
  }

  @Override
  public Function, Vec> fit(DataSet> learn, Loss loss) {
    final long startTime = System.nanoTime();

    Collection realAlphabet = alphabet;
    if (alphabet == null) {
      realAlphabet = new HashSet<>();
      for (Seq seq: learn) {
        for (T t: seq) {
          realAlphabet.add(t);
        }
      }
    }

    final DictExpansion de = new DictExpansion<>(realAlphabet, maxAlphabetSize, tracePrint);

    for (int iter = 0; iter < 1000; iter++) {
      for (Seq seq: learn) {
        de.accept(seq);
        if (de.result() != null && de.result().size() == maxAlphabetSize) {
          break;
        }
      }
    }

    final Dictionary dict = de.result();

    if (dict.size() < maxAlphabetSize) {
      throw new IllegalStateException("Cannot build alphabet of size " + maxAlphabetSize + ". Actual size is " + dict.size());
    }

    if (tracePrint != null) {
      tracePrint.println("Time to build dictexpansion: " + (System.nanoTime() - startTime) / 1e9 + "s");
    }

    final DataSet> expandedLearn = new DataSet.Stub>(null) {

      @Override
      public Seq at(int i) {
        return dict.parse(learn.at(i));
      }

      @Override
      public int length() {
        return learn.length();
      }

      @Override
      public Class> elementType() {
        return null;
      }
    };

    Function, Vec> model = optimization.fit(expandedLearn, loss);
    return new DictExpansionModel<>(model, dict);
  }

  static class DictExpansionModel> implements Function, Vec> {
    private Function, Vec> model;
    @JsonIgnore
    private Dictionary dict;

    public DictExpansionModel(Function, Vec> model, Dictionary dict) {
      this.model = model;
      this.dict = dict;
    }

    public Function, Vec> getModel() {
      return model;
    }

    public void setModel(Function, Vec> model) {
      this.model = model;
    }

    public Dictionary getDict() {
      return dict;
    }

    public void setDict(Dictionary dict) {
      this.dict = dict;
    }

    @Override
    public Vec apply(Seq seq) {
      return model.apply(dict.parse(seq));
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy