org.ranksys.javafm.example.WineQualityExample Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of JavaFM Show documentation
Show all versions of JavaFM Show documentation
Java 8 Factorization Machines Library.
The newest version!
/*
* Copyright (C) 2016 RankSys http://ranksys.org
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
package org.ranksys.javafm.example;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import static java.lang.Double.parseDouble;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.ranksys.javafm.data.SimpleFMData;
import org.ranksys.javafm.FMInstance;
import java.net.URL;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.ArrayList;
import org.ranksys.javafm.learner.gd.PointWiseGradientDescent;
import java.util.DoubleSummaryStatistics;
import static java.util.stream.IntStream.range;
import java.util.stream.Stream;
import org.ranksys.javafm.BoundedFM;
import static org.ranksys.javafm.learner.gd.PointWiseError.rmse;
import org.ranksys.javafm.data.FMData;
import static java.util.stream.Collectors.groupingBy;
/**
* Regression example with the Wine Quality dataset.
*
* https://archive.ics.uci.edu/ml/datasets/Wine+Quality
*
* @author Saúl Vargas ([email protected])
*/
public class WineQualityExample {
public static void main(String[] args) throws Exception {
FMData dataset = getWineQualityDataset();
List partition = getRandomPartition(dataset, 0.6, new Random(1L));
FMData train = partition.get(0);
FMData test = partition.get(1);
double learnRate = 0.001;
int numIter = 200;
double sdev = 1.0;
double regB = 0.01;
double[] regW = new double[train.numFeatures()];
Arrays.fill(regW, 0.01);
double[] regM = new double[train.numFeatures()];
Arrays.fill(regM, 0.01);
int K = 10;
BoundedFM fm = new BoundedFM(3.0, 9.0, train.numInstances(), K, new Random(), sdev);
new PointWiseGradientDescent(learnRate, numIter, rmse(), regB, regW, regM)
.learn(fm, train, test);
}
private static FMData getWineQualityDataset() throws IOException {
int columns = 11;
SimpleFMData data = new SimpleFMData(columns);
String filePath = "winequality-white.csv";
if (!new File(filePath).exists()) {
URL url = new URL("https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv");
ReadableByteChannel rbc = Channels.newChannel(url.openStream());
FileOutputStream fos = new FileOutputStream(filePath);
fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE);
}
InputStream is = new FileInputStream(filePath);
try (BufferedReader in = new BufferedReader(new InputStreamReader(is))) {
in.readLine();
String instance;
while ((instance = in.readLine()) != null) {
String[] tokens = instance.split(";");
double target = parseDouble(tokens[columns]);
int[] k = range(0, columns).toArray();
double[] v = Stream.of(tokens)
.limit(columns)
.mapToDouble(Double::parseDouble)
.toArray();
data.add(new FMInstance(target, k, v));
}
}
for (int _col = 0; _col < columns; _col++) {
int col = _col;
DoubleSummaryStatistics stats = data.stream()
.mapToDouble(x -> x.get(col))
.summaryStatistics();
double max = stats.getMax();
double min = stats.getMin();
if (max == min) {
data.stream().forEach(x -> x.set(col, 0.0));
} else {
data.stream().forEach(x -> x.set(col, (x.get(col) - min) / (max - min)));
}
}
return data;
}
private static List getRandomPartition(FMData dataset, double trainProp, Random rnd) {
List trainList = new ArrayList<>();
List testList = new ArrayList<>();
dataset.shuffle();
dataset.stream().forEach(instance -> {
if (rnd.nextDouble() < trainProp) {
trainList.add(instance);
} else {
testList.add(instance);
}
});
FMData train = new SimpleFMData(dataset.numFeatures(), rnd, trainList);
FMData test = new SimpleFMData(dataset.numFeatures(), rnd, testList);
return Arrays.asList(train, test);
}
}