All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.librec.data.splitter.RatioDataSplitter Maven / Gradle / Ivy

/**
 * Copyright (C) 2016 LibRec
 * 

* This file is part of LibRec. * LibRec is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. *

* LibRec is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. *

* You should have received a copy of the GNU General Public License * along with LibRec. If not, see . */ package net.librec.data.splitter; import net.librec.common.LibrecException; import net.librec.conf.Configuration; import net.librec.data.DataConvertor; import net.librec.math.algorithm.Randoms; import net.librec.math.structure.MatrixEntry; import net.librec.math.structure.SparseMatrix; import net.librec.math.structure.SparseVector; import net.librec.math.structure.VectorEntry; import net.librec.util.RatingContext; import java.util.ArrayList; import java.util.Collections; import java.util.List; /** * Ratio Data Splitter.
* Split dataset into train set, test set, valid set by ratio.
* * @author WangYuFeng and Liuxz */ public class RatioDataSplitter extends AbstractDataSplitter { /** The rate dataset for splitting */ private SparseMatrix preferenceMatrix; /** The datetime dataset for splitting */ private SparseMatrix datetimeMatrix; /** * Empty constructor. */ public RatioDataSplitter() { } /** * Initializes a newly created {@code RatioDataSplitter} object * with convertor and configuration. * * @param dataConvertor * the convertor for the splitter. * @param conf * the configuration for the splitter. */ public RatioDataSplitter(DataConvertor dataConvertor, Configuration conf) { this.dataConvertor = dataConvertor; this.conf = conf; } /** * Split the dataset according to the configuration file.
* * @throws LibrecException if error occurs */ @Override public void splitData() throws LibrecException { this.preferenceMatrix = dataConvertor.getPreferenceMatrix(); this.datetimeMatrix = dataConvertor.getDatetimeMatrix(); String splitter = conf.get("data.splitter.ratio"); switch (splitter.toLowerCase()) { case "rating": { double ratio = Double.parseDouble(conf.get("data.splitter.trainset.ratio")); getRatioByRating(ratio); break; } case "user": { double ratio = Double.parseDouble(conf.get("data.splitter.trainset.ratio")); getRatioByUser(ratio); break; } case "userfixed": { double ratio = Double.parseDouble(conf.get("data.splitter.trainset.ratio")); getFixedRatioByUser(ratio); break; } case "item": { double ratio = Double.parseDouble(conf.get("data.splitter.trainset.ratio")); getRatioByItem(ratio); break; } case "valid": { double trainRatio = Double.parseDouble(conf.get("data.splitter.trainset.ratio")); double validationRaito = Double.parseDouble(conf.get("data.splitter.validset.ratio")); getRatio(trainRatio, validationRaito); break; } case "ratingdate": { double ratio = Double.parseDouble(conf.get("data.splitter.trainset.ratio")); getRatioByRatingDate(ratio); break; } case "userdate": { double ratio = Double.parseDouble(conf.get("data.splitter.trainset.ratio")); getRatioByUserDate(ratio); break; } case "itemdate": { double ratio = Double.parseDouble(conf.get("data.splitter.trainset.ratio")); getRatioByItemDate(ratio); break; } } } /** * Split ratings into two parts: (ratio) training, (1-ratio) test subsets. * * @param ratio the ratio of training data over all the ratings */ public void getRatioByRating(double ratio) { if (ratio > 0 && ratio < 1) { testMatrix = new SparseMatrix(preferenceMatrix); trainMatrix = new SparseMatrix(preferenceMatrix); for(MatrixEntry matrixEntry: preferenceMatrix){ int userIdx = matrixEntry.row(); int itemIdx = matrixEntry.column(); double rdm = Randoms.uniform(); if (rdm < ratio) { testMatrix.set(userIdx, itemIdx, 0.0); } else { trainMatrix.set(userIdx, itemIdx, 0.0); } } SparseMatrix.reshape(testMatrix); SparseMatrix.reshape(trainMatrix); } } /** * Split the ratings (by date) into two parts: (ratio) training, (1-ratio) * test subsets. * * @param ratio the ratio of training data */ public void getRatioByRatingDate(double ratio) { if (ratio > 0 && ratio < 1) { testMatrix = new SparseMatrix(preferenceMatrix); trainMatrix = new SparseMatrix(preferenceMatrix); List rcs = new ArrayList<>(datetimeMatrix.size()); for (MatrixEntry me : preferenceMatrix) rcs.add(new RatingContext(me.row(), me.column(), (long) datetimeMatrix.get(me.row(), me.column()))); Collections.sort(rcs); int trainSize = (int) (rcs.size() * ratio); for (int i = 0; i < rcs.size(); i++) { RatingContext rc = rcs.get(i); int u = rc.getUser(); int j = rc.getItem(); if (i < trainSize) testMatrix.set(u, j, 0.0); else trainMatrix.set(u, j, 0.0); } rcs = null; SparseMatrix.reshape(trainMatrix); SparseMatrix.reshape(testMatrix); } } /** * Split ratings into two parts: the training set consisting of user-item * ratings where {@code ratio} percentage of ratings are preserved for each * user, and the rest are used as the testing data. * * @param ratio the ratio of training data */ public void getRatioByUser(double ratio) { if (ratio > 0 && ratio < 1) { trainMatrix = new SparseMatrix(preferenceMatrix); testMatrix = new SparseMatrix(preferenceMatrix); for (int u = 0, um = preferenceMatrix.numRows(); u < um; u++) { List items = preferenceMatrix.getColumns(u); for (int j : items) { if (Randoms.uniform() < ratio) { testMatrix.set(u, j, 0.0); } else { trainMatrix.set(u, j, 0.0); } } } SparseMatrix.reshape(testMatrix); SparseMatrix.reshape(trainMatrix); } } /** * Split ratings into two parts: the training set consisting of user-item * ratings where a fixed number of ratings corresponding to the given * {@code ratio} are preserved for each user as training data with the rest * as test. * * @param ratio the ratio of training data */ public void getFixedRatioByUser(double ratio) { if (ratio > 0 && ratio < 1) { trainMatrix = new SparseMatrix(preferenceMatrix); testMatrix = new SparseMatrix(preferenceMatrix); for (int u = 0, um = preferenceMatrix.numRows(); u < um; u++) { List items = preferenceMatrix.getColumns(u); // k is the test set, this will be smaller, so we want these indices in the list int k = (int) Math.floor(items.size() * (1 - ratio)); try { List testIndexes = Randoms.randInts(k, 0, items.size()); for (int j : items) { if (testIndexes.contains(j)) { trainMatrix.set(u, j, 0.0); } else { testMatrix.set(u, j, 0.0); } } } catch (java.lang.Exception e) { LOG.error("This error should not happen because k cannot be outside of the range if ratio is " + ratio); } } SparseMatrix.reshape(testMatrix); SparseMatrix.reshape(trainMatrix); } } /** * Split the ratings of each user (by date) into two parts: (ratio) * training, (1-ratio) test subsets * * @param ratio the ratio of train data */ public void getRatioByUserDate(double ratio) { if (ratio > 0 && ratio < 1) { trainMatrix = new SparseMatrix(preferenceMatrix); testMatrix = new SparseMatrix(preferenceMatrix); for (int user = 0, um = preferenceMatrix.numRows(); user < um; user++) { List unsortedItems = preferenceMatrix.getColumns(user); int size = unsortedItems.size(); List rcs = new ArrayList<>(size); for (int item : unsortedItems) { rcs.add(new RatingContext(user, item, (long) datetimeMatrix.get(user, item))); } Collections.sort(rcs); int trainSize = (int) (rcs.size() * ratio); for (int i = 0; i < rcs.size(); i++) { RatingContext rc = rcs.get(i); int u = rc.getUser(); int j = rc.getItem(); if (i < trainSize) testMatrix.set(u, j, 0.0); else trainMatrix.set(u, j, 0.0); } } SparseMatrix.reshape(trainMatrix); SparseMatrix.reshape(testMatrix); } } /** * Split ratings into two parts: the training set consisting of user-item * ratings where {@code ratio} percentage of ratings are preserved for each * item, and the rest are used as the testing data. * * @param ratio the ratio of training data */ public void getRatioByItem(double ratio) { if (ratio > 0 && ratio < 1) { trainMatrix = new SparseMatrix(preferenceMatrix); testMatrix = new SparseMatrix(preferenceMatrix); for (int i = 0, im = preferenceMatrix.numColumns(); i < im; i++) { List users = preferenceMatrix.getRows(i); for (int u : users) { if (Randoms.uniform() < ratio) { testMatrix.set(u, i, 0.0); } else { trainMatrix.set(u, i, 0.0); } } } SparseMatrix.reshape(trainMatrix); SparseMatrix.reshape(testMatrix); } } /** * Split the ratings of each item (by date) into two parts: (ratio) * training, (1-ratio) test subsets. * * @param ratio the ratio of training data */ public void getRatioByItemDate(double ratio) { if (ratio > 0 && ratio < 1) { trainMatrix = new SparseMatrix(preferenceMatrix); testMatrix = new SparseMatrix(preferenceMatrix); for (int item = 0, im = preferenceMatrix.numColumns(); item < im; item++) { List unsortedUsers = preferenceMatrix.getRows(item); int size = unsortedUsers.size(); List rcs = new ArrayList<>(size); for (int user : unsortedUsers) { rcs.add(new RatingContext(user, item, (long) datetimeMatrix.get(user, item))); } Collections.sort(rcs); int trainSize = (int) (rcs.size() * ratio); for (int i = 0; i < rcs.size(); i++) { RatingContext rc = rcs.get(i); int u = rc.getUser(); int j = rc.getItem(); if (i < trainSize) testMatrix.set(u, j, 0.0); else trainMatrix.set(u, j, 0.0); } } SparseMatrix.reshape(testMatrix); SparseMatrix.reshape(trainMatrix); } } /** * Split the rating into : (train-ratio) training, (validation-ratio) * validation, and test three subsets. * * @param trainRatio training ratio * @param validationRatio validation ratio */ public void getRatio(double trainRatio, double validationRatio) { if ((trainRatio > 0 && validationRatio > 0) && (trainRatio + validationRatio) < 1) { trainMatrix = new SparseMatrix(preferenceMatrix); validationMatrix = new SparseMatrix(preferenceMatrix); testMatrix = new SparseMatrix(preferenceMatrix); for (int u = 0, um = preferenceMatrix.numRows(); u < um; u++) { SparseVector uv = preferenceMatrix.row(u); for (VectorEntry j : uv) { double rdm = Randoms.uniform(); if (rdm < trainRatio) { // training validationMatrix.set(u, j.index(), 0.0); testMatrix.set(u, j.index(), 0.0); } else if (rdm < trainRatio + validationRatio) { // validation trainMatrix.set(u, j.index(), 0.0); testMatrix.set(u, j.index(), 0.0); } else { // test trainMatrix.set(u, j.index(), 0.0); validationMatrix.set(u, j.index(), 0.0); } } } SparseMatrix.reshape(trainMatrix); SparseMatrix.reshape(validationMatrix); SparseMatrix.reshape(testMatrix); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy