org.lenskit.eval.traintest.recommend.RecommendEvalTask

/*
 * LensKit, an open source recommender systems toolkit.
 * Copyright 2010-2014 LensKit Contributors.  See CONTRIBUTORS.md.
 * Work on LensKit has been funded by the National Science Foundation under
 * grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 51
 * Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */
package org.lenskit.eval.traintest.recommend;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.longs.LongList;
import it.unimi.dsi.fastutil.longs.LongSet;
import org.grouplens.grapht.util.ClassLoaders;
import org.grouplens.lenskit.util.io.CompressionMode;
import org.lenskit.api.ItemRecommender;
import org.lenskit.api.Recommender;
import org.lenskit.api.Result;
import org.lenskit.api.ResultList;
import org.lenskit.eval.traintest.*;
import org.lenskit.eval.traintest.metrics.Metric;
import org.lenskit.eval.traintest.metrics.MetricLoaderHelper;
import org.lenskit.eval.traintest.metrics.MetricResult;
import org.lenskit.eval.traintest.predict.PredictEvalTask;
import org.lenskit.util.collections.LongUtils;
import org.lenskit.util.table.TableLayout;
import org.lenskit.util.table.TableLayoutBuilder;
import org.lenskit.util.table.writer.CSVWriter;
import org.lenskit.util.table.writer.TableWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;

/**
 * An eval task that attempts to recommend items for a test user.
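 * <p>An illustrative programmatic configuration (a sketch using the setters defined below, not
 * taken from the original source):
 * <pre>{@code
 * RecommendEvalTask task = new RecommendEvalTask();
 * task.setListSize(10);
 * task.setOutputFile(Paths.get("recommendations.csv"));
 * task.setLabelPrefix("TopN");
 * task.addMetric(new TopNNDCGMetric());
 * }</pre>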
 */
public class RecommendEvalTask implements EvalTask {
    private static final Logger logger = LoggerFactory.getLogger(RecommendEvalTask.class);
    private static final TopNMetric<?>[] DEFAULT_METRICS = {
            new TopNLengthMetric(),
            new TopNNDCGMetric()
    };

    private Path outputFile;
    private String labelPrefix;
    private int listSize = -1;
    private List<TopNMetric<?>> topNMetrics = Lists.newArrayList(DEFAULT_METRICS);
    private volatile ItemSelector candidateSelector = ItemSelector.allItems();
    private volatile ItemSelector excludeSelector = ItemSelector.userTrainItems();

    private ExperimentOutputLayout experimentOutputLayout;
    private TableWriter outputTable;

    /**
     * Create a new recommend eval task.
     */
    public RecommendEvalTask() {}

    /**
     * Parse a recommend task from JSON.
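     * <p>An illustrative example of the JSON this method accepts (the key names match those read
     * below; the selector expressions and metric names are assumptions, not taken from this file):
     * <pre>{@code
     * {
     *   "output_file": "recommendations.csv",
     *   "label_prefix": "TopN",
     *   "list_size": 10,
     *   "candidates": "allItems",
     *   "exclude": "user.trainItems",
     *   "metrics": ["length", "ndcg"]
     * }
     * }</pre>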
     * @param json The JSON data.
     * @param base The base URI (for resolving relative paths).
     * @return The task.
     * @throws IOException If there is an I/O error.
     */
    public static RecommendEvalTask fromJSON(JsonNode json, URI base) throws IOException {
        RecommendEvalTask task = new RecommendEvalTask();

        String outFile = json.path("output_file").asText(null);
        if (outFile != null) {
            task.setOutputFile(Paths.get(base.resolve(outFile)));
        }

        task.setLabelPrefix(json.path("label_prefix").asText(null));
        task.setListSize(json.path("list_size").asInt(-1));

        String sel = json.path("candidates").asText(null);
        if (sel != null) {
            task.setCandidateSelector(ItemSelector.compileSelector(sel));
        }
        sel = json.path("exclude").asText(null);
        if (sel != null) {
            task.setExcludeSelector(ItemSelector.compileSelector(sel));
        }

        JsonNode metrics = json.get("metrics");
        if (metrics != null && !metrics.isNull()) {
            task.topNMetrics.clear();
            MetricLoaderHelper mlh = new MetricLoaderHelper(ClassLoaders.inferDefault(PredictEvalTask.class),
                                                            "topn-metrics");
            for (JsonNode mn: metrics) {
                TopNMetric<?> metric = mlh.createMetric(TopNMetric.class, mn);
                if (metric != null) {
                    task.addMetric(metric);
                } else {
                    throw new RuntimeException("cannot build metric for " + mn.toString());
                }
            }
        }

        return task;
    }

    /**
     * Get the output file for writing recommendations.
     * @return The output file, or {@code null} if no file is configured.
     */
    public Path getOutputFile() {
        return outputFile;
    }

    /**
     * Set the output file for recommendations.
     * @param file The output file for writing recommendations.  The output is written as a CSV file.
     */
    public void setOutputFile(Path file) {
        outputFile = file;
    }

    /**
     * Get the prefix applied to column labels.
     * @return The column label prefix.
     */
    public String getLabelPrefix() {
        return labelPrefix;
    }

    /**
     * Set the prefix applied to column labels.  If provided, it is prepended to this task's column
     * labels, separated by a ".".
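     * For example, with prefix {@code "TopN"}, an illustrative metric column {@code "nDCG"} would be
     * written as {@code "TopN.nDCG"}.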
     * @param prefix The label prefix.
     */
    public void setLabelPrefix(String prefix) {
        labelPrefix = prefix;
    }

    /**
     * Get the list size to use.
     * @return The number of items to recommend per user.
     */
    public int getListSize() {
        return listSize;
    }

    /**
     * Set the list size to use.
     * @param n The number of items to recommend per user.
     */
    public void setListSize(int n) {
        listSize = n;
    }

    /**
     * Get the active candidate selector.
     * @return The candidate selector to use.
     */
    public ItemSelector getCandidateSelector() {
        return candidateSelector;
    }

    /**
     * Set the candidate selector.
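     * For example, {@code setCandidateSelector(ItemSelector.allItems())} considers every item in the
     * data set (the default configured above), while {@code ItemSelector.compileSelector} builds a
     * selector from an expression string, as done in {@link #fromJSON}.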
     * @param sel The candidate selector.
     */
    public void setCandidateSelector(ItemSelector sel) {
        candidateSelector = sel;
    }

    /**
     * Get the active exclude selector.
     * @return The exclude selector to use.
     */
    public ItemSelector getExcludeSelector() {
        return excludeSelector;
    }

    /**
     * Set the exclude selector.
     * @param sel The exclude selector.
     */
    public void setExcludeSelector(ItemSelector sel) {
        excludeSelector = sel;
    }

    /**
     * Get the list of top-N metrics.
     * @return The list of top-N metrics.  This list is live, not copied, so it can be modified or cleared.
     */
    public List<TopNMetric<?>> getTopNMetrics() {
        return topNMetrics;
    }

    /**
     * Get the list of all metrics.
     * @return A list containing all metrics used by this task.
     */
    public List<Metric<?>> getAllMetrics() {
        ImmutableList.Builder<Metric<?>> metrics = ImmutableList.builder();
        metrics.addAll(topNMetrics);
        return metrics.build();
    }

    /**
     * Add a top-N metric.
     * @param metric The metric to add.
     */
    public void addMetric(TopNMetric<?> metric) {
        topNMetrics.add(metric);
    }

    @Override
    public Set<Class<?>> getRequiredRoots() {
        return FluentIterable.from(getAllMetrics())
                             .transformAndConcat(new Function<Metric<?>, Iterable<Class<?>>>() {
                                 @Nullable
                                 @Override
                                 public Iterable<Class<?>> apply(Metric<?> input) {
                                     return input.getRequiredRoots();
                                 }
                             }).toSet();
    }

    @Override
    public List<String> getGlobalColumns() {
        ImmutableList.Builder<String> columns = ImmutableList.builder();
        for (Metric<?> m: getAllMetrics()) {
            for (String label: m.getAggregateColumnLabels()) {
                columns.add(prefixColumn(label));
            }
        }
        return columns.build();
    }

    @Override
    public List<String> getUserColumns() {
        ImmutableList.Builder<String> columns = ImmutableList.builder();
        for (TopNMetric<?> pm: getTopNMetrics()) {
            for (String label: pm.getColumnLabels()) {
                columns.add(prefixColumn(label));
            }
        }
        return columns.build();
    }

    private String prefixColumn(String input) {
        String pfx = getLabelPrefix();
        if (pfx == null) {
            return input;
        } else {
            return pfx + "." + input;
        }
    }

    @Override
    public void start(ExperimentOutputLayout outputLayout) {
        experimentOutputLayout = outputLayout;
        Path outFile = getOutputFile();
        if (outFile == null) {
            return;
        }

        TableLayoutBuilder tlb = TableLayoutBuilder.copy(outputLayout.getConditionLayout());
        TableLayout layout = tlb.addColumn("User")
                                .addColumn("Rank")
                                .addColumn("Item")
                                .addColumn("Score")
                                .build();
        try {
            logger.info("writing recommendations to {}", outFile);
            outputTable = CSVWriter.open(outFile.toFile(), layout, CompressionMode.AUTO);
        } catch (IOException e) {
            throw new EvaluationException("error opening prediction output file", e);
        }
    }

    @Override
    public void finish() {
        experimentOutputLayout = null;
        if (outputTable != null) {
            try {
                outputTable.close();
                outputTable = null;
            } catch (IOException e) {
                throw new EvaluationException("error closing prediction output file", e);
            }
        }
    }

    @Override
    public ConditionEvaluator createConditionEvaluator(AlgorithmInstance algorithm, DataSet dataSet, Recommender rec) {
        Preconditions.checkState(experimentOutputLayout != null, "experiment not started");
        TableWriter recTable = experimentOutputLayout.prefixTable(outputTable, dataSet, algorithm);
        LongSet items = dataSet.getAllItems();
        ItemRecommender irec = rec.getItemRecommender();
        if (irec == null) {
            logger.warn("algorithm {} has no item recommender", algorithm);
            return null;
        }

        // we need details to write recommendation output
        boolean useDetails = recTable != null;
        List<MetricContext<?>> contexts = new ArrayList<>(topNMetrics.size());
        for (TopNMetric<?> metric: topNMetrics) {
            logger.debug("setting up metric {}", metric);
            MetricContext<?> mc = MetricContext.create(metric, algorithm, dataSet, rec);
            contexts.add(mc);
            // does this metric require details?
            useDetails |= mc.usesDetails();
        }

        return new TopNConditionEvaluator(recTable, rec, irec, contexts, items, useDetails);
    }

    static class MetricContext<X> {
        final TopNMetric<X> metric;
        final X context;

        public MetricContext(TopNMetric<X> m, X ctx) {
            metric = m;
            context = ctx;
        }

        public boolean usesDetails() {
            return !(metric instanceof ListOnlyTopNMetric);
        }

        @Nonnull
        public MetricResult measureUser(TestUser user, int n, ResultList recommendations) {
            return metric.measureUser(user, n, recommendations, context);
        }

        @Nonnull
        public MetricResult measureUser(TestUser user, int n, LongList recommendations) {
            return ((ListOnlyTopNMetric<X>) metric).measureUser(user, n, recommendations, context);
        }

        @Nonnull
        public MetricResult getAggregateMeasurements() {
            return metric.getAggregateMeasurements(context);
        }

        /**
         * Create a new metric context. Indirected through this method to help the type checker.
         */
        public static <X> MetricContext<X> create(TopNMetric<X> metric, AlgorithmInstance algorithm, DataSet dataSet, Recommender rec) {
            X ctx = metric.createContext(algorithm, dataSet, rec);
            return new MetricContext<>(metric, ctx);
        }
    }

    class TopNConditionEvaluator implements ConditionEvaluator {
        private final TableWriter writer;
        private final Recommender recommender;
        private final ItemRecommender itemRecommender;
        private final List<MetricContext<?>> predictMetricContexts;
        private final LongSet allItems;
        private final boolean useDetails;

        public TopNConditionEvaluator(TableWriter tw, Recommender rec, ItemRecommender irec,
                                      List<MetricContext<?>> mcs, LongSet items, boolean details) {
            writer = tw;
            recommender = rec;
            itemRecommender = irec;
            predictMetricContexts = mcs;
            allItems = items;
            useDetails = details;
        }

        @Nonnull
        @Override
        public Map<String, Object> measureUser(TestUser testUser) {
            LongSet candidates = getCandidateSelector().selectItems(allItems, recommender, testUser);
            LongSet excludes = getExcludeSelector().selectItems(allItems, recommender, testUser);
            int n = getListSize();
            ResultList results = null;
            LongList items = null;
            if (useDetails) {
                results = itemRecommender.recommendWithDetails(testUser.getUserId(), n,
                                                               candidates, excludes);
            } else {
                // no one needs details, save time collecting them
                items = LongUtils.asLongList(itemRecommender.recommend(testUser.getUserId(), n,
                                                                       candidates, excludes));
            }

            // Measure the user results
            Map<String, Object> row = new HashMap<>();
            for (MetricContext<?> mc: predictMetricContexts) {
                MetricResult res;
                if (useDetails) {
                    res = mc.measureUser(testUser, n, results);
                } else {
                    res = mc.measureUser(testUser, n, items);
                }
                row.putAll(res.withPrefix(getLabelPrefix())
                              .getValues());
            }

            // Write the recommendation lists, if an output file is configured
            if (writer != null) {
                assert results != null; // we use details when writer is nonnull
                int rank = 0;
                for (Result rec : results) {
                    try {
                        rank += 1;
                        writer.writeRow(testUser.getUserId(), rank, rec.getId(), rec.getScore());
                    } catch (IOException ex) {
                        throw new EvaluationException("error writing prediction row", ex);
                    }
                }
            }

            return row;
        }

        @Nonnull
        @Override
        public Map<String, Object> finish() {
            Map<String, Object> results = new HashMap<>();
            for (MetricContext<?> mc: predictMetricContexts) {
                logger.debug("finishing metric {}", mc.metric);
                results.putAll(mc.getAggregateMeasurements()
                                 .withPrefix(getLabelPrefix())
                                 .getValues());
            }
            return results;
        }
    }
}