// Artifact: lenskit-eval — org.grouplens.lenskit.eval.traintest.TrainTestEvalTask
// Facilities for evaluating recommender algorithms.
/*
* LensKit, an open source recommender systems toolkit.
* Copyright 2010-2014 LensKit Contributors. See CONTRIBUTORS.md.
* Work on LensKit has been funded by the National Science Foundation under
* grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* this program; if not, write to the Free Software Foundation, Inc., 51
* Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package org.grouplens.lenskit.eval.traintest;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.*;
import com.google.common.io.Closer;
import org.apache.commons.lang3.tuple.Pair;
import org.grouplens.grapht.Component;
import org.grouplens.grapht.Dependency;
import org.grouplens.grapht.graph.DAGNode;
import org.grouplens.grapht.graph.MergePool;
import org.grouplens.lenskit.Recommender;
import org.grouplens.lenskit.core.LenskitConfiguration;
import org.grouplens.lenskit.core.RecommenderConfigurationException;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.eval.AbstractTask;
import org.grouplens.lenskit.eval.ExecutionInfo;
import org.grouplens.lenskit.eval.TaskExecutionException;
import org.grouplens.lenskit.eval.algorithm.AlgorithmInstance;
import org.grouplens.lenskit.eval.algorithm.AlgorithmInstanceBuilder;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.Metric;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.util.parallel.TaskGraphExecutor;
import org.grouplens.lenskit.util.table.Table;
import org.grouplens.lenskit.util.table.TableBuilder;
import org.grouplens.lenskit.util.table.TableLayout;
import org.grouplens.lenskit.util.table.writer.CSVWriter;
import org.grouplens.lenskit.util.table.writer.MultiplexedTableWriter;
import org.grouplens.lenskit.util.table.writer.TableWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
/**
 * Task that runs each algorithm instance over each train-test data set, writing the
 * prediction output file(s) and the evaluation result file.
 *
 * @author GroupLens Research
 */
public class TrainTestEvalTask extends AbstractTask {
private static final Logger logger = LoggerFactory.getLogger(TrainTestEvalTask.class);
private List dataSets;
private List algorithms;
private List externalAlgorithms;
private List> metrics;
private boolean isolate;
private boolean separateAlgorithms;
private File outputFile;
private File userOutputFile;
private File cacheDir;
private File taskGraphFile;
private File taskStatusFile;
private boolean cacheAll = false;
private OutputPredictMetric.FactoryBuilder defaultPredict = new OutputPredictMetric.FactoryBuilder();
private OutputTopNMetric.FactoryBuilder defaultRecommend = new OutputTopNMetric.FactoryBuilder();
private ExperimentSuite experiments;
private MeasurementSuite measurements;
private ExperimentOutputLayout layout;
private ExperimentOutputs outputs;
public TrainTestEvalTask() {
this("train-test");
}
public TrainTestEvalTask(String name) {
super(name);
dataSets = Lists.newArrayList();
algorithms = Lists.newArrayList();
externalAlgorithms = Lists.newArrayList();
metrics = Lists.newArrayList();
outputFile = new File("train-test-results.csv");
isolate = false;
}
public TrainTestEvalTask addDataset(TTDataSet source) {
dataSets.add(source);
return this;
}
public TrainTestEvalTask addAlgorithm(AlgorithmInstance algorithm) {
algorithms.add(algorithm);
return this;
}
public TrainTestEvalTask addAlgorithm(Map attrs, String file) throws IOException, RecommenderConfigurationException {
algorithms.add(new AlgorithmInstanceBuilder().setProject(getProject())
.configureFromFile(attrs, new File(file))
.build());
return this;
}
public TrainTestEvalTask addExternalAlgorithm(ExternalAlgorithm algorithm) {
externalAlgorithms.add(algorithm);
return this;
}
/**
* Add a metric.
* @param metric The metric to add.
* @return The task (for chaining).
*/
public TrainTestEvalTask addMetric(Metric> metric) {
metrics.add(MetricFactory.forMetric(metric));
return this;
}
/**
* Add a metric by factory.
* @param factory The metric factory to add.
* @return The task (for chaining).
*/
public TrainTestEvalTask addMetric(MetricFactory> factory) {
metrics.add(factory);
return this;
}
/**
* Add a metric by class.
* @param metricClass The metric class.
* @param The metric's return type.
* @return The task (for chaining).
*/
public TrainTestEvalTask addMetric(Class extends Metric> metricClass) {
metrics.add(MetricFactory.forMetricClass(metricClass));
return this;
}
/**
* Add a metric that may write multiple columns per algorithmInfo.
* @param file The output file.
* @param columns The column headers.
* @param metric The metric function. It should return a list of table rows, each of which has
* entries corresponding to each column.
* @return The command (for chaining).
*/
public TrainTestEvalTask addMultiMetric(File file, List columns, Function>> metric) {
metrics.add(new FunctionMultiModelMetric.Factory(file, columns, metric));
return this;
}
/**
* Add a metric that takes some metric of an algorithmInfo.
* @param columns The column headers.
* @param metric The metric function. It should return a list of table rows, each of which has
* entries corresponding to each column.
* @return The command (for chaining).
*/
public TrainTestEvalTask addMetric(List columns, Function> metric) {
addMetric(new FunctionModelMetric(columns, metric));
return this;
}
/**
* Add a channel to be recorded with predict output.
*
* @param channelSym The channel to output.
* @return The command (for chaining)
* @see #addWritePredictionChannel
*/
public TrainTestEvalTask addWritePredictionChannel(@Nonnull Symbol channelSym) {
return addWritePredictionChannel(channelSym, null);
}
/**
* Add a channel to be recorded with predict output.
*
* @param channelSym The channel to record, if present in the prediction output vector.
* @param label The column label. If {@code null}, the channel symbol's name is used.
* @return The command (for chaining).
* @see #setPredictOutput(File)
* @deprecated Use the {@code predictions} metric.
*/
@Deprecated
public TrainTestEvalTask addWritePredictionChannel(@Nonnull Symbol channelSym,
@Nullable String label) {
Preconditions.checkNotNull(channelSym, "channel is null");
if (label == null) {
label = channelSym.getName();
}
defaultPredict.addChannel(channelSym, label);
return this;
}
public TrainTestEvalTask setOutput(File file) {
outputFile = file;
return this;
}
public TrainTestEvalTask setOutput(String fn) {
return setOutput(new File(fn));
}
public TrainTestEvalTask setUserOutput(File file) {
userOutputFile = file;
return this;
}
public TrainTestEvalTask setUserOutput(String fn) {
return setUserOutput(new File(fn));
}
@Deprecated
public TrainTestEvalTask setPredictOutput(File file) {
logger.warn("predictOutput is deprecated, use predictions metric");
defaultPredict.setFile(file);
return this;
}
@Deprecated
public TrainTestEvalTask setPredictOutput(String fn) {
return setPredictOutput(new File(fn));
}
@Deprecated
public TrainTestEvalTask setRecommendOutput(File file) {
logger.warn("recommendOutput is deprecated, use predictions metric");
defaultRecommend.setFile(file);
return this;
}
@Deprecated
public TrainTestEvalTask setRecommendOutput(String fn) {
return setRecommendOutput(new File(fn));
}
/**
* Set the component cache directory. This directory is used for caching components that
* might take a lot of memory but are shared between runs. Only meaningful if algorithms
* are not separated.
*
* @param file The cache directory.
* @return The task (for chaining)
* @see #setSeparateAlgorithms(boolean)
*/
public TrainTestEvalTask setComponentCacheDirectory(File file) {
cacheDir = file;
return this;
}
/**
* Set the component cache directory by name.
*
* @see #setComponentCacheDirectory(java.io.File)
*/
public TrainTestEvalTask setComponentCacheDirectory(String fn) {
return setComponentCacheDirectory(new File(fn));
}
public File getComponentCacheDirectory() {
return cacheDir;
}
/**
* Control whether all components are cached. Caching all components will use more disk space,
* but may allow future evaluations to re-use component instances from prior runs.
*
* @param flag {@code true} to cache all components.
* @return The task (for chaining).
*/
public TrainTestEvalTask setCacheAllComponents(boolean flag) {
cacheAll = flag;
return this;
}
public boolean getCacheAllComponents() {
return cacheAll;
}
/**
* Control whether the train-test evaluator will isolate data sets. If set to {@code true},
* then each data set will be run in turn, with no inter-data-set parallelism. This can
* reduce memory usage for some large runs, keeping the data from only a single data set in
* memory at a time. Otherwise (the default), individual algorithmInfo/data-set runs may be freely
* intermingled.
*
* @param iso Whether to isolate data sets.
* @return The task (for chaining).
* @deprecated Isolate data sets instead.
*/
@Deprecated
public TrainTestEvalTask setIsolate(boolean iso) {
logger.warn("Eval task isolation is deprecated. Isolate data sets instead.");
isolate = iso;
return this;
}
/**
* Control whether the evaluator separates or combines algorithms. If set to {@code true}, each
* algorithm instance is built without reusing components from other algorithm instances.
* By default, LensKit will try to reuse common components between algorithm configurations
* (one downside of this is that it makes build timings meaningless).
*
* @param sep If {@code true}, separate algorithms.
* @return The task (for chaining).
*/
public TrainTestEvalTask setSeparateAlgorithms(boolean sep) {
separateAlgorithms = sep;
return this;
}
List dataSources() {
return dataSets;
}
List getAlgorithms() {
return algorithms;
}
List getExternalAlgorithms() {
return externalAlgorithms;
}
List> getMetricFactories() {
return metrics;
}
List> getPredictionChannels() {
return defaultPredict.getChannels();
}
File getOutput() {
return outputFile;
}
File getPredictOutput() {
return defaultPredict.getFile();
}
File getRecommendOutput() {
return defaultRecommend.getFile();
}
/**
* Set an output file for a description of the task graph.
* @param f The output file.
* @return The task (for chaining).
*/
public TrainTestEvalTask setTaskGraphFile(File f) {
taskGraphFile = f;
return this;
}
/**
* Set an output file for a description of the task graph.
* @param f The output file name
* @return The task (for chaining).
*/
public TrainTestEvalTask setTaskGraphFile(String f) {
return setTaskGraphFile(new File(f));
}
public File getTaskGraphFile() {
return taskGraphFile;
}
/**
* Set an output file for task status updates.
* @param f The output file.
* @return The task (for chaining).
*/
public TrainTestEvalTask setTaskStatusFile(File f) {
taskStatusFile = f;
return this;
}
/**
* Set an output file for task status updates.
* @param f The output file name
* @return The task (for chaining).
*/
public TrainTestEvalTask setTaskStatusFile(String f) {
return setTaskStatusFile(new File(f));
}
public File getTaskStatusFile() {
return taskStatusFile;
}
/**
* Run the evaluation on the train test data source files
*
* @return The summary output table.
* @throws org.grouplens.lenskit.eval.TaskExecutionException
* Failure of the evaluation
*/
@Override
@SuppressWarnings("PMD.AvoidCatchingThrowable")
public Table perform() throws TaskExecutionException, InterruptedException {
try {
experiments = createExperimentSuite();
measurements = createMeasurementSuite();
layout = ExperimentOutputLayout.create(experiments, measurements);
TableBuilder resultsBuilder = new TableBuilder(layout.getResultsLayout());
logger.info("Starting evaluation of {} algorithms ({} from LensKit) on {} data sets",
Iterables.size(experiments.getAllAlgorithms()),
experiments.getAlgorithms().size(),
experiments.getDataSets().size());
Closer closer = Closer.create();
try {
outputs = openExperimentOutputs(layout, measurements, resultsBuilder, closer);
DAGNode jobGraph;
try {
jobGraph = makeJobGraph(experiments);
} catch (RecommenderConfigurationException ex) {
throw new TaskExecutionException("Recommender configuration error", ex);
}
if (taskGraphFile != null) {
logger.info("writing task graph to {}", taskGraphFile);
JobGraph.writeGraphDescription(jobGraph, taskGraphFile);
}
registerTaskListener(jobGraph);
// tell all metrics to get started
runEvaluations(jobGraph);
} catch (Throwable th) { // NOSONAR using a closer
throw closer.rethrow(th, TaskExecutionException.class, InterruptedException.class);
} finally {
closer.close();
}
logger.info("evaluation {} completed", getName());
return resultsBuilder.build();
} catch (IOException e) {
throw new TaskExecutionException("I/O error", e);
} finally {
experiments = null;
measurements = null;
outputs = null;
layout = null;
}
}
private void registerTaskListener(DAGNode jobGraph) {
if (taskStatusFile != null) {
ImmutableSet.Builder jobs = ImmutableSet.builder();
for (DAGNode node: jobGraph.getReachableNodes()) {
TrainTestJob job = node.getLabel().getJob();
if (job != null) {
jobs.add(job);
}
}
JobStatusWriter monitor = new JobStatusWriter(this, jobs.build(), taskStatusFile);
getProject().getEventBus().register(monitor);
}
}
public ExperimentSuite getExperiments() {
Preconditions.checkState(experiments != null, "evaluation not in progress");
return experiments;
}
public MeasurementSuite getMeasurements() {
Preconditions.checkState(measurements != null, "evaluation not in progress");
return measurements;
}
public ExperimentOutputLayout getOutputLayout() {
Preconditions.checkState(layout != null, "evaluation not in progress");
return layout;
}
public ExperimentOutputs getOutputs() {
Preconditions.checkState(outputs != null, "evaluation not in progress");
return outputs;
}
ExperimentSuite createExperimentSuite() {
return new ExperimentSuite(algorithms, externalAlgorithms, dataSets);
}
MeasurementSuite createMeasurementSuite() {
ImmutableList.Builder> activeMetrics = ImmutableList.builder();
activeMetrics.addAll(metrics);
if (defaultRecommend.getFile() != null) {
activeMetrics.add(defaultRecommend.build());
}
if (defaultPredict.getFile() != null) {
activeMetrics.add(defaultPredict.build());
}
return new MeasurementSuite(activeMetrics.build());
}
private void runEvaluations(DAGNode graph) throws TaskExecutionException, InterruptedException {
int nthreads = getProject().getConfig().getThreadCount();
TaskGraphExecutor exec;
logger.info("Running evaluator with {} threads", nthreads);
if (nthreads == 1) {
exec = TaskGraphExecutor.singleThreaded();
} else {
exec = TaskGraphExecutor.create(nthreads);
}
try {
exec.execute(graph);
} catch (ExecutionException e) {
Throwables.propagateIfInstanceOf(e.getCause(), TaskExecutionException.class);
throw new TaskExecutionException("error in evaluation job task", e.getCause());
}
}
DAGNode makeJobGraph(ExperimentSuite experiments) throws RecommenderConfigurationException {
Multimap grouped = LinkedHashMultimap.create();
for (TTDataSet dataset : experiments.getDataSets()) {
UUID grp = dataset.getIsolationGroup();
if (isolate) {
// force isolation
grp = UUID.randomUUID();
}
grouped.put(grp, dataset);
}
ComponentCache cache = new ComponentCache(cacheDir, getProject().getClassLoader());
JobGraphBuilder builder = new JobGraphBuilder(this, cache);
for (UUID groupId: grouped.keySet()) {
Collection dss = grouped.get(groupId);
String groupName = dss.size() == 1 ? dss.iterator().next().getName() : groupId.toString();
for (TTDataSet dataset: dss) {
// Add LensKit algorithms
addAlgorithmNodes(builder, dataset, experiments.getAlgorithms(), cache);
// Add external algorithms
for (ExternalAlgorithm algo: experiments.getExternalAlgorithms()) {
builder.addExternalJob(algo, dataset);
}
}
builder.fence(groupName);
}
return builder.getGraph();
}
private void addAlgorithmNodes(JobGraphBuilder builder, TTDataSet dataset,
List algorithms,
ComponentCache cache) throws RecommenderConfigurationException {
MergePool pool = MergePool.create();
for (AlgorithmInstance algo: algorithms) {
logger.debug("building graph for algorithm {}", algo);
LenskitConfiguration dataConfig = new LenskitConfiguration();
ExecutionInfo info = ExecutionInfo.newBuilder()
.setAlgorithm(algo)
.setDataSet(dataset)
.build();
dataConfig.addComponent(info);
dataset.configure(dataConfig);
// we need the user event DAO in the test phase
dataConfig.addRoot(UserEventDAO.class);
// Build the graph
DAGNode graph = algo.buildRecommenderGraph(dataConfig);
if (!separateAlgorithms) {
logger.debug("merging algorithm {} with previous graphs", algo);
graph = pool.merge(graph);
}
if (cacheAll && cache != null) {
cache.registerSharedNodes(graph.getReachableNodes());
}
builder.addLenskitJob(algo, dataset, graph);
}
}
/**
* Prepare the evaluation by opening all outputs and initializing metrics.
*/
ExperimentOutputs openExperimentOutputs(ExperimentOutputLayout layouts, MeasurementSuite measures, TableWriter results, Closer closer) throws IOException {
TableLayout resultLayout = layouts.getResultsLayout();
TableWriter allResults = results;
if (outputFile != null) {
TableWriter disk = closer.register(CSVWriter.open(outputFile, resultLayout));
allResults = new MultiplexedTableWriter(resultLayout, allResults, disk);
}
TableWriter user = null;
if (userOutputFile != null) {
user = closer.register(CSVWriter.open(userOutputFile, layouts.getUserLayout()));
}
List> metrics = Lists.newArrayList();
for (MetricFactory> metric : measures.getMetricFactories()) {
metrics.add(closer.register(metric.createMetric(this)));
}
return new ExperimentOutputs(layouts, allResults, user, metrics);
}
}
// © 2015 - 2024 Weber Informatics LLC (scraper footer; not part of the original source)