Please wait. This can take a few minutes...
Downloading a project requires significant server resources. Please understand that we need to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download or modify it as often as you want.
com.expleague.ml.methods.seq.RunImdb Maven / Gradle / Ivy
package com.expleague.ml.methods.seq;
import com.expleague.commons.io.codec.seq.DictExpansion;
import com.expleague.commons.io.codec.seq.Dictionary;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.seq.CharSeq;
import com.expleague.commons.seq.CharSeqArray;
import com.expleague.commons.seq.IntSeq;
import com.expleague.commons.seq.Seq;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.loss.LLLogit;
import com.expleague.ml.optimization.impl.AdamDescent;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class RunImdb {
private static final int ALPHABET_SIZE = 10000;
private static final int TRAIN_SIZE = 25000;
private static final FastRandom random = new FastRandom(239);
private static final int BOOST_ITERS = 1000;
private static final double BOOST_STEP = 0.2;
private static final int MAX_STATE_COUNT = 4;
private static final int EPOCH_COUNT = 1;
private static final double GRAD_STEP = 0.3;
private static final int THREAD_COUNT = 1;
private List> train;
private Vec trainTarget;
private List> test;
private Vec testTarget;
private final List alphabet = new ArrayList<>();
private int maxLen;
private Dictionary dictionary;
public void loadWordData() throws IOException {
train = new ArrayList<>(TRAIN_SIZE);
test = new ArrayList<>(TRAIN_SIZE);
trainTarget = new ArrayVec(TRAIN_SIZE);
testTarget = new ArrayVec(TRAIN_SIZE);
readWordData("src/train.txt", train, trainTarget);
readWordData("src/test.txt", test, testTarget);
}
public void loadData() throws IOException {
System.out.println("Number of cores: " + Runtime.getRuntime().availableProcessors());
System.out.println("Alphabet size: " + ALPHABET_SIZE);
System.out.println("States count: " + MAX_STATE_COUNT);
System.out.println("GradBoost step: " + BOOST_STEP);
System.out.println("GradBoost iters: " + BOOST_ITERS);
System.out.println("GradDesc step: " + GRAD_STEP);
System.out.println("Grad iters: " + EPOCH_COUNT);
System.out.println("Train size: " + TRAIN_SIZE);
List positiveRaw = readData("src/aclImdb/train/pos");
List negativeRaw = readData("src/aclImdb/train/neg");
List all = new ArrayList<>(positiveRaw);
all.addAll(negativeRaw);
DictExpansion de = new DictExpansion<>(all.stream().flatMapToInt(CharSequence::chars)
.sorted()
.distinct()
.mapToObj(i -> (char) i)
.collect(Collectors.toList()), ALPHABET_SIZE);
for (int i = 0; i < 10; i++) {
positiveRaw.forEach(de::accept);
negativeRaw.forEach(de::accept);
}
dictionary = de.result();
//System.out.println("New dictionary: " + result.alphabet().toString());
System.out.println("New dictionary size: " + dictionary.alphabet().size());
int size = 0;
for (CharSeq seq: positiveRaw) {
size += dictionary.parse(seq).length();
}
System.out.println(size + " " + size / positiveRaw.size());
System.out.println("Real alphabet size = " + dictionary.size());
Collections.shuffle(positiveRaw, random);
Collections.shuffle(negativeRaw, random);
positiveRaw = positiveRaw.stream().limit(TRAIN_SIZE).collect(Collectors.toList());
negativeRaw = negativeRaw.stream().limit(TRAIN_SIZE).collect(Collectors.toList());
train = positiveRaw.stream().map(dictionary::parse).collect(Collectors.toList());
train.addAll(negativeRaw.stream().map(dictionary::parse).collect(Collectors.toList()));
maxLen = 0;
for (int i = 0; i < train.size(); i++) {
maxLen = Math.max(maxLen, train.get(i).length());
}
int[] targetArray = new int[train.size()];
for (int i = 0; i < train.size() / 2; i++) {
targetArray[i] = 1;
}
for (int i = train.size() / 2; i < train.size(); i++) {
targetArray[i] = 0;
}
trainTarget = VecTools.fromIntSeq(new IntSeq(targetArray));
test = readData("src/aclImdb/test/pos").stream().map(dictionary::parse).collect(Collectors.toList());
test.addAll(readData("src/aclImdb/test/neg").stream().map(dictionary::parse).collect(Collectors.toList()));
for (int i = 0; i < train.size(); i++) {
maxLen = Math.max(maxLen, train.get(i).length());
}
targetArray = new int[test.size()];
for (int i = 0; i < test.size() / 2; i++) {
targetArray[i] = 1;
}
for (int i = test.size() / 2; i < test.size(); i++) {
targetArray[i] = 0;
}
testTarget = VecTools.fromIntSeq(new IntSeq(targetArray));
System.out.println("Data loaded");
}
public void test() {
DataSet> data = new DataSet.Stub>(null) {
@Override
public Seq at(int i) {
return train.get(i);
}
@Override
public int length() {
return train.size();
}
@Override
public Class> elementType() {
return null;
}
};
final GradientSeqBoosting boosting = new GradientSeqBoosting<>(
new BootstrapSeqOptimization<>(
new PNFA<>(MAX_STATE_COUNT, ALPHABET_SIZE, random,
//new SAGADescent(GRAD_STEP, EPOCH_COUNT, random, THREAD_COUNT)
new AdamDescent(random, EPOCH_COUNT, 4), 2
), random
),
BOOST_ITERS, BOOST_STEP
);
Consumer,Vec>> listener = classifier -> {
System.out.println("Current time: " + new SimpleDateFormat("yyyy/MM/dd_HH:mm:ss").format(Calendar.getInstance().getTime()));
System.out.println("Current accuracy:");
System.out.println("Train accuracy: " + getAccuracy(train, trainTarget, classifier));
System.out.println("Test accuracy: " + getAccuracy(test, testTarget, classifier));
System.out.println("Train loss: " + getLoss(train, trainTarget, classifier));
System.out.println("Test loss: " + getLoss(test, testTarget, classifier));
};
boosting.addListener(listener);
final Function, Vec> classifier = boosting.fit(data, new LLLogit(trainTarget, null));
System.out.println("Train accuracy: " + getAccuracy(train, trainTarget, classifier));
System.out.println("Test accuracy: " + getAccuracy(test, testTarget, classifier));
}
private void readWordData(String path, List> data, Vec target) throws IOException {
List list = Files.readAllLines(Paths.get(path));
Collections.shuffle(list, random);
for (int i = 0; i < TRAIN_SIZE; i++) {
String[] tokens = list.get(i).split(" ");
target.set(i, Integer.parseInt(tokens[0]));
data.add(new IntSeq(Arrays.stream(tokens, 1, tokens.length).mapToInt(Integer::parseInt).toArray()));
}
}
private List readData(final String filePath) throws IOException {
long start = System.nanoTime();
final List data = Files.list(Paths.get(filePath)).map(path -> {
try {
return new CharSeqArray(Files.readAllLines(path)
.stream()
.map(String::toLowerCase)
.map(str -> str.replaceAll("[^\\x00-\\x7F]", "").replaceAll("\\s{2,}", " ").trim())
.collect(Collectors.joining("\n"))
.toCharArray());
} catch (IOException e) {
e.printStackTrace();
return null;
}
}).filter(x -> x != null).collect(Collectors.toList());
System.out.printf("Data read in %.2f minutes\n", (System.nanoTime() - start) / 60e9);
return data;
}
private double getAccuracy(List> data, Vec target, Function, Vec> classifier) {
int passedCnt = 0;
for (int i = 0; i < data.size(); i++) {
final double val = classifier.apply(data.get(i)).get(0);
if ((target.get(i) > 0 && val > 0) || (target.get(i) <= 0 && val <= 0)) {
passedCnt++;
}
}
return 1.0 * passedCnt / data.size();
}
private double getLoss(List> data, Vec target, Function, Vec> classifier) {
final LLLogit lllogit = new LLLogit(target, null);
Vec values = new ArrayVec(target.dim());
for (int i =0 ; i < target.dim(); i++) {
values.set(i, classifier.apply(data.get(i)).get(0));
}
return lllogit.value(values);
}
public static void main(String[] args) throws IOException {
RunImdb test = new RunImdb();
test.loadWordData();
test.test();
}
}