org.lenskit.gradle.Crossfold.groovy Maven / Gradle / Ivy
/*
* LensKit, an open source recommender systems toolkit.
* Copyright 2010-2014 LensKit Contributors. See CONTRIBUTORS.md.
* Work on LensKit has been funded by the National Science Foundation under
* grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* this program; if not, write to the Free Software Foundation, Inc., 51
* Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package org.lenskit.gradle
import groovy.json.JsonOutput
import org.gradle.api.tasks.Input
import org.gradle.api.tasks.InputFiles
import org.gradle.api.tasks.OutputDirectory
import org.lenskit.gradle.traits.DataSources
/**
* Crossfold a data set. This task can only crossfold a single data set; multiple tasks must be used to produce
* multiple cross-validation splits.
*
* @see DataSources
*/
class Crossfold extends LenskitTask implements DataSources, DataSetProvider {
/**
* The output directory for cross-validation. Defaults to "$project.buildDir/$dataSetName.out"; the data set name
* in turn defaults to the name of the task (see the convention mappings set up in the constructor).
*/
def outputDir
// Inline data source specification; set by input(Map) and inputFile(Object), and written out as JSON
// when no manifest file has been configured.
private Object source
// Data source manifest file; set by input(Object). When non-null it is passed as --data-source.
private Object srcFile
// Extra command-line arguments selecting the per-user partition strategy; replaced wholesale by
// holdout(), retain() and holdoutFraction().
private List userPartitionArgs = []
// Crossfold method; emitted on the command line as "--$method". See method(String) for validation.
def String method = 'partition-users'
// NOTE(review): presumably the sample size for the sample-users method; its use is not visible here - confirm.
def Integer sampleSize
// Number of partitions; when set, passed with the --partition-count flag.
def Integer partitionCount
// NOTE(review): output format for the crossfold files; its use is not visible in this part of the file.
def String outputFormat
// Name of the data set, passed as --name; defaults to the task name via convention mapping.
def String dataSetName
// Deprecated flag; no use of it is visible in this part of the file.
@Deprecated
def boolean includeTimestamps = true
/**
* Create the task and register lazily-evaluated defaults: the data set name falls back to the task name,
* and the output directory falls back to "<buildDir>/<dataSetName>.out".
*/
public Crossfold() {
    // Default data set name: the name of this task.
    conventionMapping.dataSetName = { getName() }
    // Default output directory; goes through the getter so the data set name convention is honoured.
    conventionMapping.outputDir = { "${project.buildDir}/${getDataSetName()}.out" }
}
/**
* Use a data source manifest file as the crossfold input.
* @param file The path to an input source manifest file (in YAML format). It is resolved against the project
* when the command line is built.
*/
void input(Object file) {
    this.srcFile = file
}
/**
* Use an inline data source specification as the crossfold input.
* @param spec The data source description; it is serialized to JSON when the command line is built and no
* manifest file has been configured.
*/
void input(Map spec) {
    this.source = spec
}
/**
* Convenience method to crossfold a CSV file of ratings. It stores an inline "textfile" source specification,
* just as {@link #input(Map)} would; use that method directly for anything more general.
* @param csv A CSV file containing ratings; resolved to a URI relative to the project.
*/
void inputFile(Object csv) {
    def csvLocation = project.uri(csv).toString()
    source = [
            type  : "textfile",
            file  : csvLocation,
            format: "csv"
    ]
}
/**
* Get the files this task declares as inputs.
* @return A new set holding the manifest file if one has been configured; otherwise an empty set.
*/
@InputFiles
Set getInputFiles() {
    Set inputs = new HashSet()
    if (srcFile) {
        inputs.add(srcFile)
    }
    // TODO Extract source files
    inputs
}
/**
* Get the resolved output directory.
* @return The configured (or conventional) output directory, resolved against the project.
*/
@OutputDirectory
File getOutputDirectory() {
    // Go through the getter rather than the field so the convention mapping applies.
    def configuredDir = getOutputDir()
    project.file(configuredDir)
}
/**
* Get the name of the LensKit command this task runs.
* @return Always {@code crossfold}.
*/
@Override
String getCommand() {
    'crossfold'
}
/**
* Prepare for running the command by creating the output directory.
*/
@Override
void doPrepare() {
    project.mkdir(getOutputDirectory())
}
@Override
@Input
List getCommandArgs() {
def args = ["--output-dir", outputDirectory, "--name", getDataSetName()]
if (srcFile != null) {
args << "--data-source" << project.file(srcFile)
} else {
project.mkdir project.buildDir
project.file("$project.buildDir/$name-input.json").text = JsonOutput.toJson(source)
// FIXME Don't use JSON spec
args << "--data-source" << project.file("$project.buildDir/$name-input.json")
}
args << "--$method"
args.addAll userPartitionArgs
if (partitionCount) {
args << '--partition-count' <
* partition-users
* partition-entities
* sample-users
*
* @param m The method
*/
public void method(String m) {
    // Case-insensitive check; '_' and '-' are interchangeable as the separator.
    // partition-ratings is still accepted for backwards compatibility.
    def accepted = m =~ /^(?i:partition[_-](users|ratings|entities)|sample[_-]users)$/
    if (!accepted) {
        throw new IllegalArgumentException("invalid partition method " + m)
    }
    // Store the canonical form: hyphen-separated and lower case.
    def hyphenated = m.replaceAll('_', '-')
    method = hyphenated.toLowerCase()
}
/**
* Utility method to create a holdout-N user partition method: a fixed number of ratings is held out per user.
* Replaces any previously configured user partition arguments.
* @param n The number of ratings to hold out for each user.
* @param order The sort order. Defaults to `random`; only `timestamp` changes the generated arguments.
*/
public Object holdout(int n, String order = 'random') {
    List partitionArgs = ['--holdout-count', "${n}"]
    userPartitionArgs = partitionArgs
    if (order == 'timestamp') {
        partitionArgs << '--timestamp-order'
    }
}
/**
* Utility method to create a retain-N user partition method.
* Replaces any previously configured user partition arguments.
* @param n The number of ratings to retain for each user.
* @param order The sort order. Defaults to `random`; only `timestamp` changes the generated arguments.
*/
public Object retain(int n, String order = 'random') {
    List partitionArgs = ['--retain', "${n}"]
    userPartitionArgs = partitionArgs
    if (order == 'timestamp') {
        partitionArgs << '--timestamp-order'
    }
}
/**
* Utility method to create a holdout-fraction user partition method.
* Replaces any previously configured user partition arguments.
* @param f The fraction of ratings to hold out per user.
* @param order The sort order. Defaults to `random`; only `timestamp` changes the generated arguments.
*/
public Object holdoutFraction(double f, String order = 'random') {
    List partitionArgs = ['--holdout-fraction', "${f}"]
    userPartitionArgs = partitionArgs
    if (order == 'timestamp') {
        partitionArgs << '--timestamp-order'
    }
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy