All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.seq.RunImdb Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.methods.seq;

import com.expleague.commons.io.codec.seq.DictExpansion;
import com.expleague.commons.io.codec.seq.Dictionary;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.seq.CharSeq;
import com.expleague.commons.seq.CharSeqArray;
import com.expleague.commons.seq.IntSeq;
import com.expleague.commons.seq.Seq;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.loss.LLLogit;
import com.expleague.ml.methods.seq.param.*;
import com.expleague.ml.optimization.impl.AdamDescent;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.cli.*;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class RunImdb {
  private static final int ALPHABET_SIZE = 10000;
  private static final int TRAIN_SIZE = 25000;
  private static final FastRandom random = new FastRandom(239);

  private static final int BOOST_ITERS = 4;
  private static final int VALUES_EPOCH_COUNT = 5;
  private static final double VALUE_GRAD_STEP = 0.3;
  private final double alpha;
  private final double addToDiag;
  private final int stateCount;
  private final double boostStep;
  private final int epochCount;
  private final double gradientStep;
  private final String wParam;
  private final String bParam;
  private final boolean validateEachEpoch;

  private List> train;
  private Vec trainTarget;

  private List> test;
  private Vec testTarget;
  private final List alphabet = new ArrayList<>();
  private int maxLen;

  private Dictionary dictionary;

  private final BettaParametrization bettaParametrization;
  private final WeightParametrization weightParametrization;

  public RunImdb(final int stateCount, final double alpha, final double addToDiag, final double boostStep, int epochCount, double gradientStep, String wParam, String bParam, boolean validateEachEpoch) {
    this.stateCount = stateCount;
    this.alpha = alpha;
    this.addToDiag = addToDiag;
    this.boostStep = boostStep;
    this.epochCount = epochCount;
    this.gradientStep = gradientStep;
    this.wParam = wParam;
    this.bParam = bParam;
    this.validateEachEpoch = validateEachEpoch;
    if (bParam.equals("mx")) {
      bettaParametrization = new BettaMxParametrization(addToDiag);
    } else if (bParam.equals("vec")) {
      bettaParametrization = new BettaTwoVecParametrization(addToDiag);
    } else {
      throw new IllegalArgumentException(bParam);
    }

    if (wParam.equals("sqr")) {
      weightParametrization = new WeightSquareParametrization(bettaParametrization);
    } else if (wParam.equals("exp")) {
      weightParametrization = new WeightExpParametrization(bettaParametrization);
    } else {
      throw new IllegalArgumentException(wParam);
    }
  }

  public void loadWordData() throws IOException {
    train = new ArrayList<>(TRAIN_SIZE);
    test = new ArrayList<>(TRAIN_SIZE);
    trainTarget = new ArrayVec(TRAIN_SIZE);
    testTarget = new ArrayVec(TRAIN_SIZE);

    readWordData("src/train.txt", train, trainTarget);
    readWordData("src/test.txt", test, testTarget);
    //    loadData();
  }

  public void loadData() throws IOException {
    System.out.println("Number of cores: " + Runtime.getRuntime().availableProcessors());
    System.out.println("Alphabet size: " + ALPHABET_SIZE);
    System.out.println("States count: " + stateCount);
    System.out.println("GradBoost step: " + boostStep);
    System.out.println("GradBoost iters: " + BOOST_ITERS);
    System.out.println("GradDesc step: " + gradientStep);
    System.out.println("Grad iters: " + epochCount);
    System.out.println("Train size: " + TRAIN_SIZE);


    List positiveRaw = readData("ml/src/aclImdb/train/pos");
    List negativeRaw = readData("ml/src/aclImdb/train/neg");

    List all = new ArrayList<>(positiveRaw);
    all.addAll(negativeRaw);
    DictExpansion de = new DictExpansion<>(all.stream().flatMapToInt(CharSequence::chars)
        .sorted()
        .distinct()
        .mapToObj(i -> (char) i)
        .collect(Collectors.toList()), ALPHABET_SIZE);
    for (int i = 0; i < 10; i++) {
      positiveRaw.forEach(de::accept);
      negativeRaw.forEach(de::accept);
    }
    dictionary = de.result();
    //System.out.println("New dictionary: " + result.alphabet().toString());
    System.out.println("New dictionary size: " + dictionary.alphabet().size());

    int size = 0;
    for (CharSeq seq: positiveRaw) {
      size += dictionary.parse(seq).length();
    }
    System.out.println(size + " " + size / positiveRaw.size());
    System.out.println("Real alphabet size = " + dictionary.size());
    Collections.shuffle(positiveRaw, random);
    Collections.shuffle(negativeRaw, random);
    positiveRaw = positiveRaw.stream().limit(TRAIN_SIZE).collect(Collectors.toList());
    negativeRaw = negativeRaw.stream().limit(TRAIN_SIZE).collect(Collectors.toList());

    train = positiveRaw.stream().map(dictionary::parse).collect(Collectors.toList());
    train.addAll(negativeRaw.stream().map(dictionary::parse).collect(Collectors.toList()));
    maxLen = 0;
    for (int i = 0; i < train.size(); i++) {
      maxLen = Math.max(maxLen, train.get(i).length());
    }

    int[] targetArray = new int[train.size()];
    for (int i = 0; i < train.size() / 2; i++) {
      targetArray[i] = 1;
    }
    for (int i = train.size() / 2; i < train.size(); i++) {
      targetArray[i] = 0;
    }
    trainTarget = VecTools.fromIntSeq(new IntSeq(targetArray));

    test = readData("ml/src/aclImdb/test/pos").stream().map(dictionary::parse).collect(Collectors
        .toList());
    test.addAll(readData("ml/src/aclImdb/test/neg").stream().map(dictionary::parse).collect
        (Collectors.toList()));
    for (int i = 0; i < train.size(); i++) {
      maxLen = Math.max(maxLen, train.get(i).length());
    }
    targetArray = new int[test.size()];
    for (int i = 0; i < test.size() / 2; i++) {
      targetArray[i] = 1;
    }
    for (int i = test.size() / 2; i < test.size(); i++) {
      targetArray[i] = 0;
    }
    testTarget = VecTools.fromIntSeq(new IntSeq(targetArray));
    System.out.println("Data loaded");
  }

  public void test() {

    DataSet> data = new DataSet.Stub>(null) {
      @Override
      public Seq at(int i) {
        return train.get(i);
      }

      @Override
      public int length() {
        return train.size();
      }

      @Override
      public Class> elementType() {
        return null;
      }
    };


    IntAlphabet alphabet = new IntAlphabet(ALPHABET_SIZE);
    AdamDescent adamDescent = new AdamDescent(random, epochCount, 4, gradientStep);
    if (validateEachEpoch) {
      adamDescent.setListener(x -> {
        PNFAModel model = new PNFAModel<>(
            x,
            stateCount,
            1,
            addToDiag,
            alpha,
            alphabet,
            bettaParametrization,
            weightParametrization
        );
        System.out.println("Train accuracy: " + (1 - getAccuracy(train, trainTarget, model)));
        System.out.println("Test accuracy: " + (1 - getAccuracy(test, testTarget, model)));
      });
    }
    final GradientSeqBoosting boosting = new GradientSeqBoosting<>(
        new BootstrapSeqOptimization<>(
            new PNFARegressor<>(stateCount, stateCount,1, alphabet, alpha, 0.000, addToDiag, 0.1, random,
               // new SAGADescent(0.002, WEIGHTS_EPOCH_COUNT * train.size(), random, System.out),
                adamDescent,
                bettaParametrization, weightParametrization),
            random, 1),
        BOOST_ITERS, boostStep
    );


    Consumer,Vec>> listener = classifier -> {
      try {
        Files.write(Paths.get("kek/1"), new ObjectMapper().writeValueAsString(classifier).getBytes());
      }
      catch (IOException e) {
        e.printStackTrace();
      }

      System.out.println("Current time: " + new SimpleDateFormat("yyyy/MM/dd_HH:mm:ss").format(Calendar.getInstance().getTime()));
      System.out.println("Current accuracy:");
      System.out.println("Train accuracy: " + getAccuracy(train, trainTarget, classifier));
      System.out.println("Test accuracy: " + getAccuracy(test, testTarget, classifier));
      System.out.println("Train loss: " + getLoss(train, trainTarget, classifier));
      System.out.println("Test loss: " + getLoss(test, testTarget, classifier));
    };

    boosting.addListener(listener);
    final Function, Vec> classifier = boosting.fit(data, new LLLogit(trainTarget, null));

    System.out.println("Train accuracy: " + getAccuracy(train, trainTarget, classifier));
    System.out.println("Test accuracy: " +  getAccuracy(test, testTarget, classifier));
  }

  private void readWordData(String path, List> data, Vec target) throws IOException {
    List list = Files.readAllLines(Paths.get(path));
    Collections.shuffle(list, random);

    for (int i = 0; i < TRAIN_SIZE; i++) {
      String[] tokens = list.get(i).split(" ");
      target.set(i, Integer.parseInt(tokens[0]));
      data.add(new IntSeq(Arrays.stream(tokens, 1, tokens.length).mapToInt(Integer::parseInt).toArray()));
    }
  }

  private List readData(final String filePath) throws IOException {
    long start = System.nanoTime();
    final List data = Files.list(Paths.get(filePath)).map(path -> {
      try {
        return  new CharSeqArray(Files.readAllLines(path)
            .stream()
            .map(String::toLowerCase)
            .map(str -> str.replaceAll("[^a-zA-Z]", ""))
            .collect(Collectors.joining("\n"))
            .toCharArray());
      } catch (IOException e) {
        e.printStackTrace();
        return null;
      }
    }).filter(x -> x != null).collect(Collectors.toList());
    System.out.printf("Data read in %.2f minutes\n", (System.nanoTime() - start) / 60e9);
    return data;
  }

  private double getAccuracy(List> data, Vec target, Function, Vec> classifier) {
    int passedCnt = 0;
    for (int i = 0; i < data.size(); i++) {
      final double val = classifier.apply(data.get(i)).get(0);
      if ((target.get(i) > 0 && val > 0) || (target.get(i) <= 0 && val <= 0)) {
        passedCnt++;
      }
    }
    return 1.0 * passedCnt / data.size();
  }

  private double getLoss(List> data, Vec target, Function, Vec> classifier) {
    final LLLogit lllogit = new LLLogit(target, null);
    Vec values = new ArrayVec(target.dim());
    for (int i =0 ; i < target.dim(); i++) {
      values.set(i, classifier.apply(data.get(i)).get(0));
    }
    return lllogit.value(values);
  }

  private static Options options = new Options();
  static {
    options.addOption("stateCount", true, "stateCount");
    options.addOption("alpha", true, "alpha");
    options.addOption("addToDiag", true, "addToDiag");
    options.addOption("boostStep", true, "boostStep");
    options.addOption("epochs", true, "epochs");
    options.addOption("step", true, "step");
    options.addOption("wParam", true, "wParam");
    options.addOption("bParam", true, "bParam");
    options.addOption("validateEpoch", true, "validateEpoch");
    options.addOption("alphabet", true, "alphabet");
  }
  public static void main(String[] args) throws IOException, ParseException {
    final CommandLineParser parser = new GnuParser();
    final CommandLine command = parser.parse(options, args);
    System.out.println(Arrays.toString(args));
    RunImdb test = new RunImdb(
        Integer.parseInt(command.getOptionValue("stateCount")),
        Double.parseDouble(command.getOptionValue("alpha")),
        Double.parseDouble(command.getOptionValue("addToDiag")),
        Double.parseDouble(command.getOptionValue("boostStep")),
        Integer.parseInt(command.getOptionValue("epochs")),
        Double.parseDouble(command.getOptionValue("step")),
        command.getOptionValue("wParam"),
        command.getOptionValue("bParam"),
        Boolean.parseBoolean(command.getOptionValue("validateEpoch"))
        );
    if (command.getOptionValue("alphabet").equals("vgram")) {
      test.loadData();
    } else {
      test.loadWordData();
    }
    test.test();
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy