All downloads are free. The search and download functionality uses the official Maven repository.

org.grouplens.lenskit.eval.traintest.TrainTestEvalTask Maven / Gradle / Ivy

There is a newer version: 3.0-T5
Show newest version
/*
 * LensKit, an open source recommender systems toolkit.
 * Copyright 2010-2014 LensKit Contributors.  See CONTRIBUTORS.md.
 * Work on LensKit has been funded by the National Science Foundation under
 * grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 51
 * Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */
package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.*;
import com.google.common.io.Closer;
import org.apache.commons.lang3.tuple.Pair;
import org.grouplens.grapht.Component;
import org.grouplens.grapht.Dependency;
import org.grouplens.grapht.graph.DAGNode;
import org.grouplens.grapht.graph.MergePool;
import org.grouplens.lenskit.Recommender;
import org.grouplens.lenskit.core.LenskitConfiguration;
import org.grouplens.lenskit.core.RecommenderConfigurationException;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.eval.AbstractTask;
import org.grouplens.lenskit.eval.ExecutionInfo;
import org.grouplens.lenskit.eval.TaskExecutionException;
import org.grouplens.lenskit.eval.algorithm.AlgorithmInstance;
import org.grouplens.lenskit.eval.algorithm.AlgorithmInstanceBuilder;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.Metric;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.util.parallel.TaskGraphExecutor;
import org.grouplens.lenskit.util.table.Table;
import org.grouplens.lenskit.util.table.TableBuilder;
import org.grouplens.lenskit.util.table.TableLayout;
import org.grouplens.lenskit.util.table.writer.CSVWriter;
import org.grouplens.lenskit.util.table.writer.MultiplexedTableWriter;
import org.grouplens.lenskit.util.table.writer.TableWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ExecutionException;

/**
 * The command that run the algorithmInfo instance and output the prediction result file and the evaluation result file
 *
 * @author GroupLens Research
 */
public class TrainTestEvalTask extends AbstractTask {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestEvalTask.class);

    private List dataSets;
    private List algorithms;
    private List externalAlgorithms;
    private List> metrics;
    private boolean isolate;
    private boolean separateAlgorithms;
    private File outputFile;
    private File userOutputFile;
    private File cacheDir;
    private File taskGraphFile;
    private File taskStatusFile;
    private boolean cacheAll = false;
    private OutputPredictMetric.FactoryBuilder defaultPredict = new OutputPredictMetric.FactoryBuilder();
    private OutputTopNMetric.FactoryBuilder defaultRecommend = new OutputTopNMetric.FactoryBuilder();

    private ExperimentSuite experiments;
    private MeasurementSuite measurements;
    private ExperimentOutputLayout layout;
    private ExperimentOutputs outputs;

    public TrainTestEvalTask() {
        this("train-test");
    }

    public TrainTestEvalTask(String name) {
        super(name);
        dataSets = Lists.newArrayList();
        algorithms = Lists.newArrayList();
        externalAlgorithms = Lists.newArrayList();
        metrics = Lists.newArrayList();
        outputFile = new File("train-test-results.csv");
        isolate = false;
    }

    public TrainTestEvalTask addDataset(TTDataSet source) {
        dataSets.add(source);
        return this;
    }

    public TrainTestEvalTask addAlgorithm(AlgorithmInstance algorithm) {
        algorithms.add(algorithm);
        return this;
    }

    public TrainTestEvalTask addAlgorithm(Map attrs, String file) throws IOException, RecommenderConfigurationException {
        algorithms.add(new AlgorithmInstanceBuilder().setProject(getProject())
                                                     .configureFromFile(attrs, new File(file))
                                                     .build());
        return this;
    }

    public TrainTestEvalTask addExternalAlgorithm(ExternalAlgorithm algorithm) {
        externalAlgorithms.add(algorithm);
        return this;
    }

    /**
     * Add a metric.
     * @param metric The metric to add.
     * @return The task (for chaining).
     */
    public TrainTestEvalTask addMetric(Metric metric) {
        metrics.add(MetricFactory.forMetric(metric));
        return this;
    }

    /**
     * Add a metric by factory.
     * @param factory The metric factory to add.
     * @return The task (for chaining).
     */
    public TrainTestEvalTask addMetric(MetricFactory factory) {
        metrics.add(factory);
        return this;
    }

    /**
     * Add a metric by class.
     * @param metricClass The metric class.
     * @param  The metric's return type.
     * @return The task (for chaining).
     */
    public  TrainTestEvalTask addMetric(Class> metricClass) {
        metrics.add(MetricFactory.forMetricClass(metricClass));
        return this;
    }

    /**
     * Add a metric that may write multiple columns per algorithmInfo.
     * @param file The output file.
     * @param columns The column headers.
     * @param metric The metric function. It should return a list of table rows, each of which has
     *               entries corresponding to each column.
     * @return The command (for chaining).
     */
    public TrainTestEvalTask addMultiMetric(File file, List columns, Function>> metric) {
        metrics.add(new FunctionMultiModelMetric.Factory(file, columns, metric));
        return this;
    }

    /**
     * Add a metric that takes some metric of an algorithmInfo.
     * @param columns The column headers.
     * @param metric The metric function. It should return a list of table rows, each of which has
     *               entries corresponding to each column.
     * @return The command (for chaining).
     */
    public TrainTestEvalTask addMetric(List columns, Function> metric) {
        addMetric(new FunctionModelMetric(columns, metric));
        return this;
    }

    /**
     * Add a channel to be recorded with predict output.
     *
     * @param channelSym The channel to output.
     * @return The command (for chaining)
     * @see #addWritePredictionChannel
     */
    public TrainTestEvalTask addWritePredictionChannel(@Nonnull Symbol channelSym) {
        return addWritePredictionChannel(channelSym, null);
    }

    /**
     * Add a channel to be recorded with predict output.
     *
     * @param channelSym The channel to record, if present in the prediction output vector.
     * @param label   The column label. If {@code null}, the channel symbol's name is used.
     * @return The command (for chaining).
     * @see #setPredictOutput(File)
     * @deprecated Use the {@code predictions} metric.
     */
    @Deprecated
    public TrainTestEvalTask addWritePredictionChannel(@Nonnull Symbol channelSym,
                                                       @Nullable String label) {
        Preconditions.checkNotNull(channelSym, "channel is null");
        if (label == null) {
            label = channelSym.getName();
        }
        defaultPredict.addChannel(channelSym, label);
        return this;
    }

    public TrainTestEvalTask setOutput(File file) {
        outputFile = file;
        return this;
    }

    public TrainTestEvalTask setOutput(String fn) {
        return setOutput(new File(fn));
    }

    public TrainTestEvalTask setUserOutput(File file) {
        userOutputFile = file;
        return this;
    }

    public TrainTestEvalTask setUserOutput(String fn) {
        return setUserOutput(new File(fn));
    }

    @Deprecated
    public TrainTestEvalTask setPredictOutput(File file) {
        logger.warn("predictOutput is deprecated, use predictions metric");
        defaultPredict.setFile(file);
        return this;
    }

    @Deprecated
    public TrainTestEvalTask setPredictOutput(String fn) {
        return setPredictOutput(new File(fn));
    }

    @Deprecated
    public TrainTestEvalTask setRecommendOutput(File file) {
        logger.warn("recommendOutput is deprecated, use predictions metric");
        defaultRecommend.setFile(file);
        return this;
    }

    @Deprecated
    public TrainTestEvalTask setRecommendOutput(String fn) {
        return setRecommendOutput(new File(fn));
    }

    /**
     * Set the component cache directory.  This directory is used for caching components that
     * might take a lot of memory but are shared between runs.  Only meaningful if algorithms
     * are not separated.
     *
     * @param file The cache directory.
     * @return The task (for chaining)
     * @see #setSeparateAlgorithms(boolean)
     */
    public TrainTestEvalTask setComponentCacheDirectory(File file) {
        cacheDir = file;
        return this;
    }

    /**
     * Set the component cache directory by name.
     *
     * @see #setComponentCacheDirectory(java.io.File)
     */
    public TrainTestEvalTask setComponentCacheDirectory(String fn) {
        return setComponentCacheDirectory(new File(fn));
    }

    public File getComponentCacheDirectory() {
        return cacheDir;
    }

    /**
     * Control whether all components are cached.  Caching all components will use more disk space,
     * but may allow future evaluations to re-use component instances from prior runs.
     *
     * @param flag {@code true} to cache all components.
     * @return The task (for chaining).
     */
    public TrainTestEvalTask setCacheAllComponents(boolean flag) {
        cacheAll = flag;
        return this;
    }

    public boolean getCacheAllComponents() {
        return cacheAll;
    }

    /**
     * Control whether the train-test evaluator will isolate data sets.  If set to {@code true},
     * then each data set will be run in turn, with no inter-data-set parallelism.  This can
     * reduce memory usage for some large runs, keeping the data from only a single data set in
     * memory at a time.  Otherwise (the default), individual algorithmInfo/data-set runs may be freely
     * intermingled.
     *
     * @param iso Whether to isolate data sets.
     * @return The task (for chaining).
     * @deprecated Isolate data sets instead.
     */
    @Deprecated
    public TrainTestEvalTask setIsolate(boolean iso) {
        logger.warn("Eval task isolation is deprecated. Isolate data sets instead.");
        isolate = iso;
        return this;
    }

    /**
     * Control whether the evaluator separates or combines algorithms.  If set to {@code true}, each
     * algorithm instance is built without reusing components from other algorithm instances.
     * By default, LensKit will try to reuse common components between algorithm configurations
     * (one downside of this is that it makes build timings meaningless).
     *
     * @param sep If {@code true}, separate algorithms.
     * @return The task (for chaining).
     */
    public TrainTestEvalTask setSeparateAlgorithms(boolean sep) {
        separateAlgorithms = sep;
        return this;
    }

    List dataSources() {
        return dataSets;
    }

    List getAlgorithms() {
        return algorithms;
    }

    List getExternalAlgorithms() {
        return externalAlgorithms;
    }

    List> getMetricFactories() {
        return metrics;
    }

    List> getPredictionChannels() {
        return defaultPredict.getChannels();
    }

    File getOutput() {
        return outputFile;
    }

    File getPredictOutput() {
        return defaultPredict.getFile();
    }

    File getRecommendOutput() {
        return defaultRecommend.getFile();
    }

    /**
     * Set an output file for a description of the task graph.
     * @param f The output file.
     * @return The task (for chaining).
     */
    public TrainTestEvalTask setTaskGraphFile(File f) {
        taskGraphFile = f;
        return this;
    }

    /**
     * Set an output file for a description of the task graph.
     * @param f The output file name
     * @return The task (for chaining).
     */
    public TrainTestEvalTask setTaskGraphFile(String f) {
        return setTaskGraphFile(new File(f));
    }

    public File getTaskGraphFile() {
        return taskGraphFile;
    }

    /**
     * Set an output file for task status updates.
     * @param f The output file.
     * @return The task (for chaining).
     */
    public TrainTestEvalTask setTaskStatusFile(File f) {
        taskStatusFile = f;
        return this;
    }

    /**
     * Set an output file for task status updates.
     * @param f The output file name
     * @return The task (for chaining).
     */
    public TrainTestEvalTask setTaskStatusFile(String f) {
        return setTaskStatusFile(new File(f));
    }

    public File getTaskStatusFile() {
        return taskStatusFile;
    }

    /**
     * Run the evaluation on the train test data source files
     *
     * @return The summary output table.
     * @throws org.grouplens.lenskit.eval.TaskExecutionException
     *          Failure of the evaluation
     */
    @Override
    @SuppressWarnings("PMD.AvoidCatchingThrowable")
    public Table perform() throws TaskExecutionException, InterruptedException {
        try {
            experiments = createExperimentSuite();
            measurements = createMeasurementSuite();
            layout = ExperimentOutputLayout.create(experiments, measurements);
            TableBuilder resultsBuilder = new TableBuilder(layout.getResultsLayout());

            logger.info("Starting evaluation of {} algorithms ({} from LensKit) on {} data sets",
                        Iterables.size(experiments.getAllAlgorithms()),
                        experiments.getAlgorithms().size(),
                        experiments.getDataSets().size());
            Closer closer = Closer.create();

            try {
                outputs = openExperimentOutputs(layout, measurements, resultsBuilder, closer);
                DAGNode jobGraph;
                try {
                    jobGraph = makeJobGraph(experiments);
                } catch (RecommenderConfigurationException ex) {
                    throw new TaskExecutionException("Recommender configuration error", ex);
                }
                if (taskGraphFile != null) {
                    logger.info("writing task graph to {}", taskGraphFile);
                    JobGraph.writeGraphDescription(jobGraph, taskGraphFile);
                }
                registerTaskListener(jobGraph);

                // tell all metrics to get started
                runEvaluations(jobGraph);
            } catch (Throwable th) { // NOSONAR using a closer
                throw closer.rethrow(th, TaskExecutionException.class, InterruptedException.class);
            } finally {
                closer.close();
            }

            logger.info("evaluation {} completed", getName());

            return resultsBuilder.build();
        } catch (IOException e) {
            throw new TaskExecutionException("I/O error", e);
        } finally {
            experiments = null;
            measurements = null;
            outputs = null;
            layout = null;
        }
    }

    private void registerTaskListener(DAGNode jobGraph) {
        if (taskStatusFile != null) {
            ImmutableSet.Builder jobs = ImmutableSet.builder();
            for (DAGNode node: jobGraph.getReachableNodes()) {
                TrainTestJob job = node.getLabel().getJob();
                if (job != null) {
                    jobs.add(job);
                }
            }
            JobStatusWriter monitor = new JobStatusWriter(this, jobs.build(), taskStatusFile);
            getProject().getEventBus().register(monitor);
        }
    }

    public ExperimentSuite getExperiments() {
        Preconditions.checkState(experiments != null, "evaluation not in progress");
        return experiments;
    }

    public MeasurementSuite getMeasurements() {
        Preconditions.checkState(measurements != null, "evaluation not in progress");
        return measurements;
    }

    public ExperimentOutputLayout getOutputLayout() {
        Preconditions.checkState(layout != null, "evaluation not in progress");
        return layout;
    }

    public ExperimentOutputs getOutputs() {
        Preconditions.checkState(outputs != null, "evaluation not in progress");
        return outputs;
    }

    ExperimentSuite createExperimentSuite() {
        return new ExperimentSuite(algorithms, externalAlgorithms, dataSets);
    }

    MeasurementSuite createMeasurementSuite() {
        ImmutableList.Builder> activeMetrics = ImmutableList.builder();
        activeMetrics.addAll(metrics);
        if (defaultRecommend.getFile() != null) {
            activeMetrics.add(defaultRecommend.build());
        }
        if (defaultPredict.getFile() != null) {
            activeMetrics.add(defaultPredict.build());
        }
        return new MeasurementSuite(activeMetrics.build());
    }

    private void runEvaluations(DAGNode graph) throws TaskExecutionException, InterruptedException {
        int nthreads = getProject().getConfig().getThreadCount();
        TaskGraphExecutor exec;
        logger.info("Running evaluator with {} threads", nthreads);
        if (nthreads == 1) {
            exec = TaskGraphExecutor.singleThreaded();
        } else {
            exec = TaskGraphExecutor.create(nthreads);
        }

        try {
            exec.execute(graph);
        } catch (ExecutionException e) {
            Throwables.propagateIfInstanceOf(e.getCause(), TaskExecutionException.class);
            throw new TaskExecutionException("error in evaluation job task", e.getCause());
        }
    }

    DAGNode makeJobGraph(ExperimentSuite experiments) throws RecommenderConfigurationException {
        Multimap grouped = LinkedHashMultimap.create();
        for (TTDataSet dataset : experiments.getDataSets()) {
            UUID grp = dataset.getIsolationGroup();
            if (isolate) {
                // force isolation
                grp = UUID.randomUUID();
            }
            grouped.put(grp, dataset);
        }

        ComponentCache cache = new ComponentCache(cacheDir, getProject().getClassLoader());
        JobGraphBuilder builder = new JobGraphBuilder(this, cache);

        for (UUID groupId: grouped.keySet()) {
            Collection dss = grouped.get(groupId);
            String groupName = dss.size() == 1 ? dss.iterator().next().getName() : groupId.toString();
            for (TTDataSet dataset: dss) {
                // Add LensKit algorithms
                addAlgorithmNodes(builder, dataset, experiments.getAlgorithms(), cache);

                // Add external algorithms
                for (ExternalAlgorithm algo: experiments.getExternalAlgorithms()) {
                    builder.addExternalJob(algo, dataset);
                }
            }

            builder.fence(groupName);
        }
        return builder.getGraph();
    }

    private void addAlgorithmNodes(JobGraphBuilder builder, TTDataSet dataset,
                                   List algorithms,
                                   ComponentCache cache) throws RecommenderConfigurationException {
        MergePool pool = MergePool.create();

        for (AlgorithmInstance algo: algorithms) {
            logger.debug("building graph for algorithm {}", algo);
            LenskitConfiguration dataConfig = new LenskitConfiguration();
            ExecutionInfo info = ExecutionInfo.newBuilder()
                                              .setAlgorithm(algo)
                                              .setDataSet(dataset)
                                              .build();
            dataConfig.addComponent(info);
            dataset.configure(dataConfig);
            // we need the user event DAO in the test phase
            dataConfig.addRoot(UserEventDAO.class);
            // Build the graph
            DAGNode graph = algo.buildRecommenderGraph(dataConfig);

            if (!separateAlgorithms) {
                logger.debug("merging algorithm {} with previous graphs", algo);
                graph = pool.merge(graph);
            }
            if (cacheAll && cache != null) {
                cache.registerSharedNodes(graph.getReachableNodes());
            }

            builder.addLenskitJob(algo, dataset, graph);
        }
    }

    /**
     * Prepare the evaluation by opening all outputs and initializing metrics.
     */
    ExperimentOutputs openExperimentOutputs(ExperimentOutputLayout layouts, MeasurementSuite measures, TableWriter results, Closer closer) throws IOException {
        TableLayout resultLayout = layouts.getResultsLayout();
        TableWriter allResults = results;
        if (outputFile != null) {
            TableWriter disk = closer.register(CSVWriter.open(outputFile, resultLayout));
            allResults = new MultiplexedTableWriter(resultLayout, allResults, disk);
        }
        TableWriter user = null;
        if (userOutputFile != null) {
            user = closer.register(CSVWriter.open(userOutputFile, layouts.getUserLayout()));
        }
        List> metrics = Lists.newArrayList();
        for (MetricFactory metric : measures.getMetricFactories()) {
            metrics.add(closer.register(metric.createMetric(this)));
        }
        return new ExperimentOutputs(layouts, allResults, user, metrics);
    }
}