/*
 * LensKit, an open source recommender systems toolkit.
 * Copyright 2010-2014 LensKit Contributors.  See CONTRIBUTORS.md.
 * Work on LensKit has been funded by the National Science Foundation under
 * grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 51
 * Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */
package org.lenskit.eval.traintest;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import com.google.common.base.Preconditions;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.Sets;
import com.google.common.io.Closer;
import groovy.lang.Closure;
import org.grouplens.grapht.Component;
import org.grouplens.grapht.Dependency;
import org.grouplens.grapht.graph.MergePool;
import org.grouplens.grapht.util.ClassLoaders;
import org.grouplens.lenskit.util.io.CompressionMode;
import org.grouplens.lenskit.util.io.LKFileUtils;
import org.lenskit.LenskitConfiguration;
import org.lenskit.config.ConfigHelpers;
import org.lenskit.eval.traintest.predict.PredictEvalTask;
import org.lenskit.eval.traintest.recommend.RecommendEvalTask;
import org.lenskit.util.parallel.TaskGroup;
import org.lenskit.util.table.Table;
import org.lenskit.util.table.TableBuilder;
import org.lenskit.util.table.TableLayout;
import org.lenskit.util.table.TableLayoutBuilder;
import org.lenskit.util.table.writer.CSVWriter;
import org.lenskit.util.table.writer.MultiplexedTableWriter;
import org.lenskit.util.table.writer.TableWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ForkJoinPool;

/**
 * Sets up and runs train-test evaluations.  This class can be used directly, but it will usually be controlled from
 * the `train-test` command line tool, which is in turn driven by a Gradle script.  For a simpler way to programmatically
 * run an evaluation, see {@link org.lenskit.eval.traintest.SimpleEvaluator}, which provides a simplified interface
 * to train-test evaluations with cross-validation.
 *
 * A train-test experiment consists of three things:
 *
 * - A collection of algorithms.
 * - A collection of train-test data sets.
 * - A collection of tasks, each of which performs an action on the recommender (e.g. predict users' test
 * ratings, or produce recommendations) and measures the recommender's performance on that task using one
 * or more metrics.
 *
 * Global output is aggregated into a CSV file; individual tasks or metrics may produce additional files.
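 *
 * As a rough sketch of direct programmatic use (the file names below are illustrative, and `dataSet`
 * stands for a previously built {@link DataSet}):
 *
 * ```java
 * TrainTestExperiment experiment = new TrainTestExperiment();
 * experiment.setOutputFile(Paths.get("eval-results.csv"));   // aggregate CSV output
 * experiment.addDataSet(dataSet);                            // a previously built DataSet
 * experiment.addAlgorithms(Paths.get("algorithms.groovy"));  // one or more algorithm configs
 * experiment.addTask(new PredictEvalTask());                 // predict users' test ratings
 * Table results = experiment.execute();                      // run and collect global results
 * ```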
 */
public class TrainTestExperiment {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestExperiment.class);
    private Path outputFile;
    private Path userOutputFile;
    private Path cacheDir;
    private boolean shareModelComponents = true;
    private int threadCount = 1;
    private ClassLoader classLoader = ClassLoaders.inferDefault(TrainTestExperiment.class);

    private List<AlgorithmInstance> algorithms = new ArrayList<>();
    private List<DataSet> dataSets = new ArrayList<>();
    private List<EvalTask> tasks = new ArrayList<>();

    private TableWriter globalOutput;
    private TableWriter userOutput;
    private TableBuilder resultBuilder;
    private Closer resultCloser;
    private ExperimentOutputLayout outputLayout;
    private List<ExperimentJob> allJobs;
    private TaskGroup rootJob;

    /**
     * Set the primary output file.
     * @param out The file where the primary aggregate output should go.
     */
    public void setOutputFile(Path out) {
        outputFile = out;
    }

    /**
     * Get the primary output file.
     * @return The primary output file.
     */
    public Path getOutputFile() {
        return outputFile;
    }

    /**
     * Get the per-user output file.
     * @return The output file for per-user measurements.
     */
    public Path getUserOutputFile() {
        return userOutputFile;
    }

    /**
     * Set the per-user output file.
     * @param file The file for per-user measurements.
     */
    public void setUserOutputFile(Path file) {
        userOutputFile = file;
    }

    /**
     * Get the algorithm instances.
     * @return The algorithms to run.
     */
    public List<AlgorithmInstance> getAlgorithms() {
        return algorithms;
    }

    /**
     * Add an algorithm to the experiment.
     * @param algo The algorithm to add.
     */
    public void addAlgorithm(AlgorithmInstance algo) {
        algorithms.add(algo);
    }

    /**
     * Add multiple algorithm instances.
     * @param algos The algorithm instances to add.
     */
    public void addAlgorithms(List<AlgorithmInstance> algos) {
        algorithms.addAll(algos);
    }

    /**
     * Add an algorithm configured by a Groovy closure.  Mostly useful for testing.
     * @param name The algorithm name.
     * @param block The algorithm configuration block.
     */
    public void addAlgorithm(String name, Closure block) {
        AlgorithmInstanceBuilder aib = new AlgorithmInstanceBuilder(name);
        LenskitConfiguration config = aib.getConfig();
        ConfigHelpers.configure(config, block);
        addAlgorithm(aib.build());
    }

    /**
     * Add one or more algorithms by loading a config file.
     * @param name The algorithm name.
     * @param file The config file to load.
     */
    public void addAlgorithm(String name, Path file) {
        addAlgorithms(AlgorithmInstance.load(file, name, classLoader));
    }

    /**
     * Add one or more algorithms from a configuration file.
     * @param file The configuration file.
     */
    public void addAlgorithms(Path file) {
        addAlgorithm(null, file);
    }

    /**
     * Get the list of data sets to use.
     * @return The list of data sets to use.
     */
    public List<DataSet> getDataSets() {
        return dataSets;
    }

    /**
     * Add a data set.
     * @param ds The data set to add.
     */
    public void addDataSet(DataSet ds) {
        dataSets.add(ds);
    }

    /**
     * Add several data sets.
     * @param dss The data sets to add.
     */
    public void addDataSets(List<DataSet> dss) {
        dataSets.addAll(dss);
    }

    /**
     * Query whether this experiment will cache and share components.
     *
     * @return {@code true} if model components will be shared.
     * @see #setShareModelComponents(boolean)
     */
    public boolean getShareModelComponents() {
        return shareModelComponents;
    }

    /**
     * Control whether model components will be shared.  If {@link #setCacheDirectory(Path)} is also set,
     * components will be cached on disk; otherwise, they will be opportunistically shared in memory.
     *
     * Sharing model components improves throughput and memory use, but makes build times effectively meaningless.  It
     * is turned on by default, but turn it off if you want to measure recommender build times.
     *
     * @param shares `true` to enable caching of shared model components.
     */
    public void setShareModelComponents(boolean shares) {
        shareModelComponents = shares;
    }

    /**
     * Get the cache directory for model components.
     * @return The directory where model components will be cached.
     */
    public Path getCacheDirectory() {
        return cacheDir;
    }

    /**
     * Set the cache directory for model components.
     * @param dir The directory where model components will be cached.
     */
    public void setCacheDirectory(Path dir) {
        cacheDir = dir;
    }

    /**
     * Get the number of threads that the experiment may use.
     *
     * @return The number of threads that the experiment may use.
     */
    public int getThreadCount() {
        int tc = threadCount;
        if (tc <= 0) {
            String prop = System.getProperty("lenskit.eval.threadCount");
            if (prop != null) {
                tc = Integer.parseInt(prop);
            }
        }
        if (tc <= 0) {
            tc = Runtime.getRuntime().availableProcessors();
        }
        return tc;
    }

    /**
     * Set the number of threads the experiment may use.
     *
     * @param tc The number of threads that the experiment may use.  If 0 (the default), consults the property
     *           `lenskit.eval.threadCount`, and if that is unset, uses as many threads as there
     *           are available processors according to {@link Runtime#availableProcessors()}.
     */
    public void setThreadCount(int tc) {
        threadCount = tc;
    }

    /**
     * Get the class loader for this experiment.
     * @return The class loader that will be used.
     */
    public ClassLoader getClassLoader() {
        return classLoader;
    }

    /**
     * Set the class loader for this experiment.
     * @param loader The class loader to use.
     */
    public void setClassLoader(ClassLoader loader) {
        classLoader = loader;
    }

    /**
     * Get the eval tasks to be used in this experiment.
     * @return The evaluation tasks to run.
     */
    public List<EvalTask> getTasks() {
        return tasks;
    }

    /**
     * Add an evaluation task.
     * @param task An evaluation task to run.
     */
    public void addTask(EvalTask task) {
        tasks.add(task);
    }

    /**
     * Convenience method to get the prediction task for the experiment.  If there is not yet a prediction task, then
     * one is added.
     * @return The experiment's prediction task.
     */
    PredictEvalTask getPredictionTask() {
        List<PredictEvalTask> taskList = FluentIterable.from(tasks)
                                                        .filter(PredictEvalTask.class)
                                                        .toList();
        if (taskList.isEmpty()) {
            PredictEvalTask task = new PredictEvalTask();
            addTask(task);
            return task;
        } else {
            if (taskList.size() > 1) {
                logger.warn("multiple prediction tasks configured");
            }
            return taskList.get(0);
        }
    }

    /**
     * Get the global output table.
     * @return The global output table.
     */
    @Nonnull
    TableWriter getGlobalOutput() {
        Preconditions.checkState(resultBuilder != null, "Experiment has not been started");
        assert globalOutput != null;
        return globalOutput;
    }

    /**
     * Get the per-user output table.
     * @return The per-user output table.
     */
    @Nullable
    TableWriter getUserOutput() {
        Preconditions.checkState(resultBuilder != null, "Experiment has not been started");
        return userOutput;
    }

    /**
     * Run the experiment.
     * @return The global aggregate results from the experiment.
     */
    public Table execute() {
        try {
            try {
                resultCloser = Closer.create();
                logger.debug("setting up output");
                ExperimentOutputLayout layout = makeExperimentOutputLayout();
                openOutputs(layout);
                for (EvalTask task: tasks) {
                    task.start(layout);
                }

                logger.debug("gathering jobs");
                buildJobGraph();
                int nthreads = getThreadCount();
                if (nthreads > 1) {
                    logger.info("running with {} threads", nthreads);
                    runJobGraph(nthreads);
                } else {
                    logger.info("running in a single thread");
                    runJobList();
                }

                logger.info("train-test evaluation complete");
                // done before closing, but that is ok
                return resultBuilder.build();
            } catch (Throwable th) { //NOSONAR using closer
                throw resultCloser.rethrow(th);
            } finally {
                outputLayout = null;
                // FIXME Handle exceptions in task shutdown cleanly
                for (EvalTask task: tasks) {
                    task.finish();
                }
                resultBuilder = null;
                resultCloser.close();
            }
        } catch (IOException ex) {
            throw new EvaluationException("I/O error in evaluation", ex);
        }
    }

    public ExperimentOutputLayout getOutputLayout() {
        if (outputLayout == null) {
            throw new IllegalStateException("experiment not started");
        }
        return outputLayout;
    }

    private ExperimentOutputLayout makeExperimentOutputLayout() {
        Set<String> dataColumns = Sets.newLinkedHashSet();
        Set<String> algoColumns = Sets.newLinkedHashSet();
        for (DataSet ds: getDataSets()) {
            dataColumns.addAll(ds.getAttributes().keySet());
        }
        for (AlgorithmInstance ai: getAlgorithms()) {
            algoColumns.addAll(ai.getAttributes().keySet());
        }
        return new ExperimentOutputLayout(dataColumns, algoColumns);
    }

    private void openOutputs(ExperimentOutputLayout eol) throws IOException {
        TableLayout globalLayout = makeGlobalResultLayout(eol);
        resultBuilder = resultCloser.register(new TableBuilder(globalLayout));
        if (outputFile != null) {
            TableWriter csvw = resultCloser.register(CSVWriter.open(outputFile.toFile(), globalLayout, CompressionMode.AUTO));
            globalOutput = resultCloser.register(new MultiplexedTableWriter(globalLayout, resultBuilder, csvw));
        } else {
            globalOutput = resultBuilder;
        }

        if (userOutputFile != null) {
            TableLayout ul = makeUserResultLayout(eol);
            userOutput = resultCloser.register(CSVWriter.open(userOutputFile.toFile(), ul, CompressionMode.AUTO));
        }
        outputLayout = eol;
    }

    private TableLayout makeGlobalResultLayout(ExperimentOutputLayout eol) {
        TableLayoutBuilder tlb = TableLayoutBuilder.copy(eol.getConditionLayout());
        tlb.addColumn("BuildTime")
           .addColumn("TestTime");
        for (EvalTask task: tasks) {
            tlb.addColumns(task.getGlobalColumns());
        }
        return tlb.build();
    }

    private TableLayout makeUserResultLayout(ExperimentOutputLayout eol) {
        TableLayoutBuilder tlb = TableLayoutBuilder.copy(eol.getConditionLayout());
        tlb.addColumn("User")
           .addColumn("TestTime");
        for (EvalTask task: tasks) {
            tlb.addColumns(task.getUserColumns());
        }
        return tlb.build();
    }


    /**
     * Create the tree of jobs to run in this experiment.  The jobs are collected in
     * {@link #allJobs}, and the job tree is stored in {@link #rootJob} as a root fork-join task.
     */
    private void buildJobGraph() {
        allJobs = new ArrayList<>();
        ComponentCache cache = null;
        if (shareModelComponents) {
            cache = new ComponentCache(cacheDir, classLoader);
        }
        Map<UUID, TaskGroup> groups = new HashMap<>();

        // set up the roots
        LenskitConfiguration config = new LenskitConfiguration();
        for (EvalTask task: tasks) {
            for (Class<?> cls: task.getRequiredRoots()) {
                config.addRoot(cls);
            }
        }

        // make tasks
        for (DataSet ds: getDataSets()) {
            // TODO support global isolation
            UUID gid = ds.getIsolationGroup();
            TaskGroup group = groups.get(gid);
            if (group == null) {
                group = new TaskGroup(true);
                groups.put(gid, group);
            }
            MergePool pool = null;
            if (cache != null) {
                pool = MergePool.create();
            }
            for (AlgorithmInstance ai: getAlgorithms()) {
                ExperimentJob job = new ExperimentJob(this, ai, ds, config, cache, pool);
                allJobs.add(job);
                group.addTask(job);
            }
        }

        TaskGroup root;
        if (groups.size() > 1) {
            root = new TaskGroup(false);
            for (TaskGroup g: groups.values()) {
                root.addTask(g);
            }
        } else {
            root = FluentIterable.from(groups.values()).first().orNull();
        }
        if (root == null) {
            throw new IllegalStateException("no jobs defined");
        }
        rootJob = root;
    }

    /**
     * Run the jobs in sequence.
     */
    private void runJobList() {
        Preconditions.checkState(allJobs != null, "job graph not built");

        for (ExperimentJob job: allJobs) {
            job.execute();
        }
    }

    private void runJobGraph(int nthreads) {
        Preconditions.checkState(rootJob != null, "job graph not built");
        ForkJoinPool pool = new ForkJoinPool(nthreads);
        pool.invoke(rootJob);
    }

    /**
     * Load a train-test experiment from a YAML file.
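     *
     * A minimal sketch of such a file, using the keys this loader recognizes (the names and
     * file paths in the values are illustrative):
     *
     * ```yaml
     * output_file: eval-results.csv
     * user_output_file: eval-users.csv
     * thread_count: 2
     * datasets:
     * - crossfold.yaml
     * algorithms:
     *   ItemItem: item-item.groovy
     * tasks:
     * - type: predict
     * ```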
     * @param file The file to load.
     * @return The train-test experiment.
     */
    public static TrainTestExperiment load(Path file) throws IOException {
        YAMLFactory factory = new YAMLFactory();
        ObjectMapper mapper = new ObjectMapper(factory);
        JsonNode node = mapper.readTree(file.toFile());

        return fromJSON(node, file.toUri());
    }

    /**
     * Configure a train-test experiment from JSON.
     * @param json The JSON node.
     * @param base The base URI for resolving relative paths.
     * @return The train-test experiment.
     * @throws IOException if there is an IO error.
     */
    static TrainTestExperiment fromJSON(JsonNode json, URI base) throws IOException {
        TrainTestExperiment exp = new TrainTestExperiment();

        // configure basic settings
        String outFile = json.path("output_file").asText(null);
        if (outFile != null) {
            exp.setOutputFile(Paths.get(base.resolve(outFile)));
        }
        outFile = json.path("user_output_file").asText(null);
        if (outFile != null) {
            exp.setUserOutputFile(Paths.get(base.resolve(outFile)));
        }
        String cacheDir = json.path("cache_directory").asText(null);
        if (cacheDir != null) {
            exp.setCacheDirectory(Paths.get(base.resolve(cacheDir)));
        }
        if (json.has("thread_count")) {
            exp.setThreadCount(json.get("thread_count").asInt(1));
        }
        if (json.has("share_model_components")) {
            exp.setShareModelComponents(json.get("share_model_components").asBoolean());
        }
        if (!json.has("datasets")) {
            throw new IllegalArgumentException("no data sets specified");
        }

        // configure data sets
        for (JsonNode ds: json.get("datasets")) {
            List<DataSet> dss;
            if (ds.isTextual()) {
                URI dsURI = base.resolve(ds.asText());
                dss = DataSet.load(dsURI.toURL());
            } else {
                dss = DataSet.fromJSON(ds, base);
            }
            exp.addDataSets(dss);
        }

        // configure the algorithms
        JsonNode algo = json.path("algorithms");
        if (algo.isTextual()) {
            // name of groovy file
            URI af = base.resolve(algo.asText());
            String aname = LKFileUtils.basename(af.getPath(), false);
            // FIXME Support algorithms from URLs
            exp.addAlgorithm(aname, Paths.get(af));
        } else if (algo.isObject()) {
            // mapping of names to groovy files
            Iterator<Map.Entry<String, JsonNode>> algoIter = algo.fields();
            while (algoIter.hasNext()) {
                Map.Entry<String, JsonNode> e = algoIter.next();
                URI algoUri = base.resolve(e.getValue().asText());
                // FIXME Support algorithms from URLs
                exp.addAlgorithm(e.getKey(), Paths.get(algoUri));
            }
        } else if (algo.isArray()) {
            // list of groovy file names
            for (JsonNode an: algo) {
                URI af = base.resolve(an.asText());
                String aname = LKFileUtils.basename(af.getPath(), false);
                // FIXME Support algorithms from URLs
                exp.addAlgorithm(aname, Paths.get(af));
            }
        } else if (!algo.isMissingNode()) {
            throw new IllegalArgumentException("unexpected type for algorithms config");
        }

        // configure the tasks and their metrics
        JsonNode tasks = json.get("tasks");
        for (JsonNode task: tasks) {
            exp.addTask(configureTask(task, base));
        }

        return exp;
    }

    private static EvalTask configureTask(JsonNode task, URI base) throws IOException {
        String type = task.path("type").asText(null);
        Preconditions.checkArgument(type != null, "no task type specified");
        switch (type) {
        case "predict":
            return PredictEvalTask.fromJSON(task, base);
        case "recommend":
            return RecommendEvalTask.fromJSON(task, base);
        default:
            throw new IllegalArgumentException("invalid eval task type " + type);
        }
    }
}