All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.seq.RunSpliceData Maven / Gradle / Ivy

package com.expleague.ml.methods.seq;

import com.expleague.commons.func.Action;
import com.expleague.commons.func.Computable;
import com.expleague.commons.io.codec.seq.DictExpansion;
import com.expleague.commons.io.codec.seq.Dictionary;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.SingleValueVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.seq.CharSeq;
import com.expleague.commons.seq.CharSeqArray;
import com.expleague.commons.seq.IntSeq;
import com.expleague.commons.seq.Seq;
import com.expleague.commons.seq.regexp.Alphabet;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.loss.LLLogit;
import com.expleague.ml.methods.SeqOptimization;
import com.expleague.ml.optimization.impl.AdamDescent;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.zip.GZIPInputStream;

public class RunSpliceData {
  private static final List CLASSES = Arrays.asList("EI", "IE", "N");
  private static final int ALPHABET_SIZE = 15;
  private static final int BOOST_ITERS = 450;
  private static final double BOOST_STEP = 1;
  private static final int MAX_STATE_COUNT = 6;
  private static final double GRAD_STEP = 0.1;
  private static final FastRandom random = new FastRandom(239);
  private static final int THREAD_COUNT = 1;
  private static final int EPOCH_COUNT = 100;

  final List> trainData = new ArrayList<>();
  final List trainDataAsMx = new ArrayList<>();
  Vec trainTarget;

  final List> testData = new ArrayList<>();
  final List testDataAsMx = new ArrayList<>();
  Vec testTarget;

  final Alphabet alphabet = new IntAlphabet(ALPHABET_SIZE);
  private Dictionary dictionary;

  public void loadData() throws IOException {
    System.out.println("Number of cores: " + Runtime.getRuntime().availableProcessors());
    System.out.println("Alphabet size: " + ALPHABET_SIZE);
    System.out.println("States count: " + MAX_STATE_COUNT);
    System.out.println("GradBoost step: " + BOOST_STEP);
    System.out.println("GradBoost iters: " + BOOST_ITERS);
    System.out.println("GradDesc step: " + GRAD_STEP);
    System.out.println("Grad epoch: " + EPOCH_COUNT);

    final List data = new ArrayList<>();
    final TIntList classes = new TIntArrayList();
    final int[] classCount = new int[CLASSES.size()];

    new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(
        Paths.get("ml/src/splice.data.txt.gz").toFile())), StandardCharsets.UTF_8)
    ).lines().forEach(line -> {
      final String[] tokens = line.split(",");
      final int clazz = CLASSES.indexOf(tokens[0]);
      if (clazz == -1) {
        throw new IllegalStateException("Unknown class " + tokens[0]);
      }
      classes.add(clazz);
      classCount[clazz]++;
      data.add(new CharSeqArray(tokens[2].trim().toCharArray()));
    });

    List alpha = Arrays.asList('A', 'C', 'G', 'T');
    DictExpansion de = new DictExpansion<>(alpha, ALPHABET_SIZE);
    for (int i = 0; i < 50; i++) {
      data.forEach(de::accept);
    }

    dictionary = de.result();

    System.out.println("New dictionary: " + dictionary.alphabet().toString());

    int[] freqs = new int[ALPHABET_SIZE];
    int[] sum = new int[1];
    data.forEach(word -> {
      Seq seq = dictionary.parse(word);
      for (int i =0; i < seq.length(); i++) {
        freqs[seq.at(i)]++;
        sum[0]++;
      }
    });
    Map map = new TreeMap<>();
    for (int i = 0; i < ALPHABET_SIZE; i++) {
      map.put(dictionary.get(i).toString(), 1.0 * freqs[i] / sum[0]);
    }

    map.entrySet().stream().sorted(Map.Entry.comparingByValue(Collections.reverseOrder())).forEach(it -> System.out.printf("%s: %.5f, ", it.getKey(), it.getValue()));
    System.out.println();
    final int sampleCount = Arrays.stream(classCount).min().orElse(0);
    final int trainCount = sampleCount * 8 / 10;
    int[] trainClasses = new int[trainCount * CLASSES.size()];
    int[] testClasses = new int[(sampleCount - trainCount) * CLASSES.size()];
    for (int clazz = 0; clazz < CLASSES.size(); clazz++) {
      int cnt = 0;
      for (int i = 0; i < data.size(); i++) {
        if (classes.get(i) != clazz) {
          continue;
        }

        final Seq seq = dictionary.parse(data.get(i));
        final Mx mx = new VecBasedMx(seq.length(), dictionary.size());
        for (int j = 0; j < seq.length(); j++) {
          mx.set(j, seq.at(j), 1);
        }
        if (cnt < trainCount) {
          trainClasses[trainData.size()] = classes.get(i);
          trainData.add(seq);
          trainDataAsMx.add(mx);
        } else if (cnt < sampleCount){
          testClasses[testData.size()] = classes.get(i);
          testData.add(seq);
          testDataAsMx.add(mx);
        }
        cnt++;
      }
    }

    trainTarget = VecTools.fromIntSeq(new IntSeq(trainClasses));
    testTarget = VecTools.fromIntSeq(new IntSeq(testClasses));
  }

  public void test() {

    DataSet> data = new DataSet.Stub>(null) {
      @Override
      public Seq at(int i) {
        return trainData.get(i);
      }

      @Override
      public int length() {
        return trainData.size();
      }

      @Override
      public Class> elementType() {
        return null;
      }
    };

    List labels = new ArrayList<>();
    for (int i = 0; i < trainTarget.length(); i++) {
      labels.add((int) Math.round(trainTarget.get(i)));
    }

    long start = System.nanoTime();
    final GradientSeqBoosting boosting = new GradientSeqBoosting<>(
        new BootstrapSeqOptimization<>(
            new PNFA<>(MAX_STATE_COUNT, alphabet.size(), random,
                //new SAGADescent(GRAD_STEP, EPOCH_COUNT, random, THREAD_COUNT)
              new AdamDescent(random, 100, 4), 1
            )
            , random
        ), BOOST_ITERS, BOOST_STEP
    );

    Action, Vec>> listener = classifier -> {
      System.out.println("Current time: " + new SimpleDateFormat("yyyyMMdd_HH:mm:ss").format(Calendar.getInstance().getTime()));
      System.out.println("Train loss: " + getLoss(trainData, trainTarget, classifier));
      System.out.println("Test loss: " + getLoss(testData, testTarget, classifier));
    };
    boosting.addListener(listener);

    final Computable, Vec> classifier =
        new OneVsRest(boosting, CLASSES.size(), labels).fit(data, new LLLogit(trainTarget, null));
    long end = System.nanoTime();

    System.out.println(String.format("Elapsed %.2f minutes", (end - start) / 60e9));
    System.out.println("Train accuracy of " + getAccuracy(trainData, trainTarget, classifier));
    System.out.println("Test accuracy of  " + getAccuracy(testData, testTarget, classifier));
  }

  private class OneVsRest implements SeqOptimization {
    private final int classCount;
    private final SeqOptimization optimization;
    private final List labels;

    OneVsRest(SeqOptimization optimization, int classCount, List labels) {
      this.classCount = classCount;
      this.optimization = optimization;
      this.labels = labels;
    }

    @Override
    public Computable, Vec> fit(DataSet> learn, LLLogit llLogit) {
      List, Vec>> classifiers = new ArrayList<>();
      for (int i = 0; i < classCount; i++) {
        Vec newTarget = VecTools.copy(llLogit.labels());
        for (int j = 0; j < llLogit.labels().length(); j++) {
          if (labels.get(j) != i) {
            newTarget.set(j, 0);
          } else {
            newTarget.set(j, 1);
          }
        }
        classifiers.add(optimization.fit(learn, new LLLogit(newTarget, null)));
      }

      return argument -> {
        double max = -Double.MAX_VALUE;
        int clazz = -1;
        for (int i = 0; i < classCount; i++) {
          double prob = classifiers.get(i).compute(argument).get(0);
          if (prob > max) {
            max = prob;
            clazz = i;
          }
        }
        return new SingleValueVec(clazz);
      };
    }
  }

  private double getLoss(List> data, Vec target, Computable, Vec> classifier) {
    final LLLogit lllogit = new LLLogit(target, null);
    Vec values = new ArrayVec(target.dim());
    for (int i =0 ; i < target.dim(); i++) {
      values.set(i, classifier.compute(data.get(i)).get(0));
    }
    return lllogit.value(values);
  }


  private double getAccuracy(List> data, Vec target, Computable, Vec> classifier) {
    int passedCnt = 0;
    int classAccuracy[] = new int[3];
    int classSize[] = new int[3];
    for (int i = 0; i < data.size(); i++) {
      if (Math.round(classifier.compute(data.get(i)).get(0)) == target.get(i)) {
        passedCnt++;
        classAccuracy[(int) Math.round(target.get(i))]++;
      }
      classSize[(int) Math.round(target.get(i))]++;
    }
    for (int i = 0; i < 3; i++) {
      System.out.println("Class " + i + " accuracy:" + classAccuracy[i] * 1.0 / classSize[i]);
    }
    return 1.0 * passedCnt / data.size();
  }

  public static void main(String[] args) throws IOException {
    RunSpliceData test = new RunSpliceData();
    test.loadData();
    test.test();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy