// org.lenskit.eval.crossfold.Crossfolder
/*
* LensKit, an open source recommender systems toolkit.
* Copyright 2010-2014 LensKit Contributors. See CONTRIBUTORS.md.
* Work on LensKit has been funded by the National Science Foundation under
* grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* this program; if not, write to the Free Software Foundation, Inc., 51
* Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package org.lenskit.eval.crossfold;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import com.google.common.base.Charsets;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import it.unimi.dsi.fastutil.longs.LongSet;
import org.grouplens.lenskit.util.io.UpToDateChecker;
import org.lenskit.data.dao.DataAccessException;
import org.lenskit.data.dao.DataAccessObject;
import org.lenskit.data.dao.file.EntitySource;
import org.lenskit.data.dao.file.StaticDataSource;
import org.lenskit.data.dao.file.TextEntitySource;
import org.lenskit.data.entities.CommonAttributes;
import org.lenskit.data.entities.CommonTypes;
import org.lenskit.data.entities.EntityType;
import org.lenskit.data.output.OutputFormat;
import org.lenskit.data.output.RatingWriter;
import org.lenskit.data.output.RatingWriters;
import org.lenskit.eval.traintest.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nullable;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
/**
* Partitions a data set for cross-validation.
*
* The resulting data is placed in an output directory with the following files:
*
* - `datasets.yaml` - a manifest file listing all the data sets
* - `partNN.train.csv` - a CSV file containing the train data for part *NN*
* - `partNN.train.yaml` - a YAML manifest for the training data for part *NN*
* - `partNN.test.csv` - a CSV file containing the test data for part *NN*
* - `partNN.test.yaml` - a YAML manifest for the test data for part *NN*
*
* @author GroupLens Research
*/
public class Crossfolder {
    public static final String ITEM_FILE_NAME = "items.txt";
    private static final Logger logger = LoggerFactory.getLogger(Crossfolder.class);

    private Random rng;
    private String name;
    private StaticDataSource source;
    private EntityType entityType = CommonTypes.RATING;
    private int partitionCount = 5;
    private Path outputDir;
    private OutputFormat outputFormat = OutputFormat.CSV;
    private boolean skipIfUpToDate = false;
    private CrossfoldMethod method = CrossfoldMethods.partitionUsers(SortOrder.RANDOM, HistoryPartitions.holdout(10));
    private boolean writeTimestamps = true;
    private boolean executed = false;

    /**
     * Construct a new, unnamed crossfolder.  The name will be derived from the data source.
     */
    public Crossfolder() {
        this(null);
    }

    /**
     * Construct a new crossfolder.
     * @param n The crossfolder name, or {@code null} to derive the name from the data source.
     */
    public Crossfolder(String n) {
        name = n;
        rng = new Random();
    }

    /**
     * Get the entity type that this crossfolder will crossfold.
     * @return The entity type to crossfold.
     */
    public EntityType getEntityType() {
        return entityType;
    }

    /**
     * Set the entity type that this crossfolder will crossfold.
     * @param entityType The entity type to crossfold.
     */
    public void setEntityType(EntityType entityType) {
        this.entityType = entityType;
    }

    /**
     * Set the number of partitions to generate.
     *
     * @param partition The number of partitions
     * @return The CrossfoldCommand object (for chaining)
     */
    public Crossfolder setPartitionCount(int partition) {
        partitionCount = partition;
        return this;
    }

    /**
     * Get the partition count.
     * @return The number of partitions that will be generated.
     */
    public int getPartitionCount() {
        return partitionCount;
    }

    /**
     * Set the output format for the crossfolder.
     * @param format The output format.
     * @return The crossfolder (for chaining).
     */
    public Crossfolder setOutputFormat(OutputFormat format) {
        outputFormat = format;
        return this;
    }

    /**
     * Get the output format for the crossfolder.
     * @return The format the crossfolder will use for writing its output.
     */
    public OutputFormat getOutputFormat() {
        return outputFormat;
    }

    /**
     * Set the output directory for this crossfold operation.
     * @param dir The output directory.
     * @return The crossfolder (for chaining).
     */
    public Crossfolder setOutputDir(Path dir) {
        outputDir = dir;
        return this;
    }

    /**
     * Set the output directory for this crossfold operation.
     * @param dir The output directory.
     * @return The crossfolder (for chaining).
     */
    public Crossfolder setOutputDir(File dir) {
        return setOutputDir(dir.toPath());
    }

    /**
     * Set the output directory for this crossfold operation.
     * @param dir The output directory.
     * @return The crossfolder (for chaining).
     */
    public Crossfolder setOutputDir(String dir) {
        return setOutputDir(Paths.get(dir));
    }

    /**
     * Get the output directory.  If no directory has been configured, a default of
     * {@code <name>.split} is used.
     * @return The directory into which crossfolding output will be placed.
     */
    public Path getOutputDir() {
        if (outputDir != null) {
            return outputDir;
        } else {
            return Paths.get(getName() + ".split");
        }
    }

    /**
     * Set the data source.
     * @param src The data source to be crossfolded.
     * @return The crossfolder (for chaining)
     */
    public Crossfolder setSource(StaticDataSource src) {
        source = src;
        return this;
    }

    /**
     * Set the method to be used by the crossfolder.
     * @param meth The method to use.
     * @return The crossfolder (for chaining).
     */
    public Crossfolder setMethod(CrossfoldMethod meth) {
        method = meth;
        return this;
    }

    /**
     * Get the method to be used for crossfolding.
     * @return The configured crossfold method.
     */
    public CrossfoldMethod getMethod() {
        return method;
    }

    /**
     * Configure whether to include timestamps in the output file.
     * @param pack {@code true} to include timestamps (the default), {@code false} otherwise.
     * @return The task (for chaining).
     */
    public Crossfolder setWriteTimestamps(boolean pack) {
        writeTimestamps = pack;
        return this;
    }

    /**
     * Query whether timestamps will be written.
     * @return {@code true} if output will include timestamps.
     */
    public boolean getWriteTimestamps() {
        return writeTimestamps;
    }

    /**
     * Get the visible name of this crossfold split.
     *
     * @return The name of the crossfold split.
     */
    public String getName() {
        if (name == null) {
            return source.getName();
        } else {
            return name;
        }
    }

    /**
     * Set a name for this crossfolder. It will be used to generate the names of individual data sets, for example.
     * @param n The crossfolder name.
     * @return The crossfolder (for chaining).
     */
    public Crossfolder setName(String n) {
        name = n;
        return this;
    }

    /**
     * Get the data source backing this crossfold manager.
     *
     * @return The underlying data source.
     */
    public StaticDataSource getSource() {
        return source;
    }

    /**
     * Set whether the crossfolder should skip if all files are up to date. The default is to always re-crossfold, even
     * if the files are up to date.
     *
     * @param skip `true` to skip crossfolding if files are up to date.
     * @return The crossfolder (for chaining).
     */
    public Crossfolder setSkipIfUpToDate(boolean skip) {
        skipIfUpToDate = skip;
        return this;
    }

    /**
     * Query whether the crossfolder will skip crossfolding when outputs are up to date.
     * @return {@code true} if up-to-date output files will be reused.
     */
    public boolean getSkipIfUpToDate() {
        return skipIfUpToDate;
    }

    /**
     * Run the crossfold command. Write the partition files to the disk by reading in the source file.
     * If {@link #getSkipIfUpToDate()} is set and all outputs are newer than all inputs, nothing is
     * written and the previous output is reused.
     */
    public void execute() {
        try {
            if (skipIfUpToDate) {
                UpToDateChecker check = new UpToDateChecker();
                for (EntitySource src: source.getSources()) {
                    // only file-backed sources have a timestamp we can check
                    if (src instanceof TextEntitySource) {
                        Path path = ((TextEntitySource) src).getFile();
                        check.addInput(Files.getLastModifiedTime(path).toMillis());
                    }
                }
                for (Path p: Iterables.concat(getTrainingFiles(), getTestFiles(), getSpecFiles())) {
                    check.addOutput(p.toFile());
                }
                if (check.isUpToDate()) {
                    logger.info("crossfold {} up to date", getName());
                    executed = true;
                    return;
                }
            }
            // use the accessor so the name-derived default directory is honored,
            // consistent with getTrainingFiles() etc.
            Path dir = getOutputDir();
            logger.info("ensuring output directory {} exists", dir);
            Files.createDirectories(dir);
            logger.info("making sure item list is available");
            JsonNode itemDataInfo = writeItemFile(source);
            logger.info("writing train-test split files");
            createTTFiles(source);
            logger.info("writing manifests and specs");
            Map<String, Object> metadata = new HashMap<>();
            for (EntitySource src: source.getSourcesForType(entityType)) {
                metadata.putAll(src.getMetadata());
            }
            writeManifests(source, metadata, itemDataInfo);
            executed = true;
        } catch (IOException ex) {
            // TODO Use application-specific exception
            throw new RuntimeException("Error writing data sets", ex);
        }
    }

    /**
     * Get the list of training-data files that will be written (one per partition).
     */
    List<Path> getTrainingFiles() {
        return getFileList("part%02d.train." + outputFormat.getExtension());
    }

    /**
     * Get the list of training-data YAML manifest files (one per partition).
     */
    List<Path> getTrainingManifestFiles() {
        return getFileList("part%02d.train.yaml");
    }

    /**
     * Get the list of test-data files that will be written (one per partition).
     */
    List<Path> getTestFiles() {
        return getFileList("part%02d.test." + outputFormat.getExtension());
    }

    /**
     * Get the list of test-data YAML manifest files (one per partition).
     */
    List<Path> getTestManifestFiles() {
        return getFileList("part%02d.test.yaml");
    }

    /**
     * Get the list of per-partition spec files (used for up-to-date checking).
     */
    List<Path> getSpecFiles() {
        return getFileList("part%02d.json");
    }

    /**
     * Expand a 1-based partition-number pattern into one path per partition.
     * @param pattern A {@link String#format} pattern with one integer placeholder.
     * @return Paths resolved against the output directory, for partitions 1..partitionCount.
     */
    private List<Path> getFileList(String pattern) {
        List<Path> files = new ArrayList<>(partitionCount);
        for (int i = 1; i <= partitionCount; i++) {
            files.add(getOutputDir().resolve(String.format(pattern, i)));
        }
        return files;
    }

    /**
     * Write the items to a file.
     * @param data The input data.
     * @return The JSON data to include in the manifest to describe the item file, or {@code null}
     *         if the input data already provides an item source.
     * @throws IOException if there's a problem writing the file.
     */
    @Nullable
    private JsonNode writeItemFile(StaticDataSource data) throws IOException {
        List<EntitySource> itemSources = data.getSourcesForType(CommonTypes.ITEM);
        if (itemSources.isEmpty()) {
            logger.info("writing item IDs to {}", ITEM_FILE_NAME);
            // resolve against getOutputDir() so the name-derived default directory works
            Path itemFile = getOutputDir().resolve(ITEM_FILE_NAME);
            DataAccessObject dao = data.get();
            LongSet items = dao.getEntityIds(CommonTypes.ITEM);
            try (BufferedWriter writer = Files.newBufferedWriter(itemFile, Charsets.UTF_8)) {
                for (Long item: items) { // escape analysis should elide allocations
                    writer.append(item.toString())
                          .append(System.lineSeparator());
                }
            }
            // make the node describing this
            JsonNodeFactory fac = JsonNodeFactory.instance;
            ObjectNode node = fac.objectNode();
            node.set("type", fac.textNode("textfile"));
            node.set("format", fac.textNode("tsv"));
            node.set("file", fac.textNode(ITEM_FILE_NAME));
            node.set("entity_type", fac.textNode(CommonTypes.ITEM.getName()));
            ArrayNode cols = fac.arrayNode();
            cols.add(CommonAttributes.ENTITY_ID.getName());
            node.set("columns", cols);
            return node;
        } else {
            logger.info("input data specifies an item source, reusing that");
            return null;
        }
    }

    /**
     * Write train-test split files.
     *
     * @throws IOException if there is an error writing the files.
     * @param data The input data.
     */
    private void createTTFiles(StaticDataSource data) throws IOException {
        if (entityType != CommonTypes.RATING) {
            logger.warn("entity type is not 'rating', crossfolding may not work correctly");
            logger.warn("crossfolding non-rating data is a work in progress");
        }
        List<EntitySource> sources = data.getSourcesForType(entityType);
        logger.info("crossfolding {} data from {} sources", entityType, sources);
        for (EntitySource source: sources) {
            Set<EntityType> types = source.getTypes();
            if (types.size() > 1) {
                logger.warn("source {} has multiple entity types", source);
                logger.warn("the following types will be ignored: {}",
                            Sets.difference(types, ImmutableSet.of(entityType)));
            }
        }
        try (CrossfoldOutput out = new CrossfoldOutput(this, rng)) {
            logger.info("running crossfold method {}", method);
            method.crossfold(data.get(), out, entityType);
        }
    }

    /**
     * Write the YAML manifests: one per train partition, one per test partition, and the
     * top-level {@code datasets.yaml} listing all partitions.
     *
     * @param data The input data (for passing through non-crossfolded sources).
     * @param meta Metadata collected from the crossfolded entity sources.
     * @param itemData The item-file manifest node from {@link #writeItemFile}, or {@code null}.
     * @throws IOException if a manifest cannot be written.
     */
    private void writeManifests(StaticDataSource data, Map<String, Object> meta, JsonNode itemData) throws IOException {
        logger.debug("writing manifests");
        YAMLFactory ioFactory = new YAMLFactory();
        ObjectMapper mapper = new ObjectMapper(ioFactory);
        JsonNodeFactory nf = JsonNodeFactory.instance;
        List<Path> trainFiles = getTrainingFiles();
        List<Path> trainManifestFiles = getTrainingManifestFiles();
        List<Path> testFiles = getTestFiles();
        List<Path> testManifestFiles = getTestManifestFiles();
        // use the accessor so the name-derived default directory is honored
        Path dir = getOutputDir();
        Path dataSetFile = dir.resolve("datasets.yaml");
        ObjectNode dsNode = nf.objectNode();
        // use getName() so a null name falls back to the source's name,
        // consistent with the logging in execute()
        dsNode.set("name", nf.textNode(getName()));
        ArrayNode dsList = nf.arrayNode();
        for (int i = 0; i < partitionCount; i++) {
            ObjectNode dsListEntry = nf.objectNode();
            dsListEntry.set("train", nf.textNode(dir.relativize(trainManifestFiles.get(i)).toString()));
            dsListEntry.set("test", nf.textNode(dir.relativize(testManifestFiles.get(i)).toString()));
            dsList.add(dsListEntry);
            // TODO Support various columns in crossfold output
            logger.debug("writing train manifest {}", i);
            Path trainFile = trainManifestFiles.get(i);
            ArrayNode trainList = nf.arrayNode();
            ObjectNode train = nf.objectNode();
            train.set("type", nf.textNode("textfile"));
            train.set("file", nf.textNode(dir.relativize(trainFiles.get(i)).toString()));
            train.set("format", nf.textNode("csv"));
            train.set("entity_type", nf.textNode(entityType.getName()));
            train.set("metadata", mapper.valueToTree(meta));
            trainList.add(train);
            // write the item output
            if (itemData != null) {
                trainList.add(itemData);
            }
            // write the other data files
            for (EntitySource source: data.getSources()) {
                if (source.getTypes().contains(entityType)) {
                    continue; // this one was crossfolded
                }
                if (source instanceof TextEntitySource) {
                    trainList.add(((TextEntitySource) source).toJSON(trainFile.toUri()));
                } else {
                    logger.warn("ignoring non-file data source {}", source);
                }
            }
            mapper.writeValue(trainFile.toFile(), trainList);
            logger.debug("writing test manifest {}", i);
            ObjectNode test = nf.objectNode();
            test.set("type", nf.textNode("textfile"));
            test.set("file", nf.textNode(dir.relativize(testFiles.get(i)).toString()));
            test.set("format", nf.textNode("csv"));
            test.set("entity_type", nf.textNode(entityType.getName()));
            test.set("metadata", mapper.valueToTree(meta));
            mapper.writeValue(testManifestFiles.get(i).toFile(), test);
        }
        dsNode.set("datasets", dsList);
        mapper.writeValue(dataSetFile.toFile(), dsNode);
    }

    /**
     * Get the train-test splits as data sets.
     *
     * @return The data sets produced by this crossfolder.
     * @throws IllegalStateException if the crossfolder has not been executed.
     */
    public List<DataSet> getDataSets() {
        Preconditions.checkState(executed, "crossfolder has not been executed");
        // use the accessor so the name-derived default directory is honored
        Path dataSetFile = getOutputDir().resolve("datasets.yaml");
        try {
            return DataSet.load(dataSetFile);
        } catch (IOException e) {
            throw new DataAccessException("cannot load data sets", e);
        }
    }

    /**
     * Open a rating writer for an output file.
     * @param file The file to write.
     * @return A CSV rating writer; honors {@link #getWriteTimestamps()}.
     *         NOTE(review): always writes CSV regardless of {@link #getOutputFormat()} —
     *         presumably RatingWriters.csv handles compression from the extension; confirm.
     * @throws IOException if the file cannot be opened.
     */
    RatingWriter openWriter(Path file) throws IOException {
        return RatingWriters.csv(file.toFile(), writeTimestamps);
    }

    @Override
    public String toString() {
        return String.format("{CXManager %s}", source);
    }
}
// © 2015 - 2025 Weber Informatics LLC