All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.codelibs.elasticsearch.taste.eval.AbstractDifferenceEvaluator Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.codelibs.elasticsearch.taste.eval;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

import org.apache.mahout.common.RandomUtils;
import org.codelibs.elasticsearch.taste.common.FastByIDMap;
import org.codelibs.elasticsearch.taste.common.LongPrimitiveIterator;
import org.codelibs.elasticsearch.taste.exception.NoSuchItemException;
import org.codelibs.elasticsearch.taste.exception.NoSuchUserException;
import org.codelibs.elasticsearch.taste.exception.TasteException;
import org.codelibs.elasticsearch.taste.model.DataModel;
import org.codelibs.elasticsearch.taste.model.GenericDataModel;
import org.codelibs.elasticsearch.taste.model.GenericPreference;
import org.codelibs.elasticsearch.taste.model.GenericUserPreferenceArray;
import org.codelibs.elasticsearch.taste.model.Preference;
import org.codelibs.elasticsearch.taste.model.PreferenceArray;
import org.codelibs.elasticsearch.taste.recommender.Recommender;
import org.codelibs.elasticsearch.taste.writer.ResultWriter;
import org.codelibs.elasticsearch.util.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;

/**
 * Abstract superclass of a couple implementations, providing shared functionality.
 */
public abstract class AbstractDifferenceEvaluator implements Evaluator {

    private static final Logger log = LoggerFactory
            .getLogger(AbstractDifferenceEvaluator.class);

    protected final Random random;

    protected float maxPreference;

    protected float minPreference;

    protected ResultWriter resultWriter;

    protected String id;

    private boolean interrupted = false;

    protected AbstractDifferenceEvaluator() {
        random = RandomUtils.getRandom();
        maxPreference = Float.NaN;
        minPreference = Float.NaN;
    }

    @Override
    public void setId(final String id) {
        this.id = id;
    }

    @Override
    public String getId() {
        return id;
    }

    @Override
    public void setResultWriter(final ResultWriter resultWriter) {
        this.resultWriter = resultWriter;
    }

    public final float getMaxPreference() {
        return maxPreference;
    }

    public final void setMaxPreference(final float maxPreference) {
        this.maxPreference = maxPreference;
    }

    public final float getMinPreference() {
        return minPreference;
    }

    public final void setMinPreference(final float minPreference) {
        this.minPreference = minPreference;
    }

    @Override
    public Evaluation evaluate(final RecommenderBuilder recommenderBuilder,
            final DataModel dataModel, final EvaluationConfig config) {
        Preconditions.checkNotNull(recommenderBuilder);
        Preconditions.checkNotNull(dataModel);
        final double trainingPercentage = config.getTrainingPercentage();
        final double evaluationPercentage = config.getEvaluationPercentage();
        Preconditions.checkArgument(trainingPercentage >= 0.0
                && trainingPercentage <= 1.0, "Invalid trainingPercentage: "
                + trainingPercentage
                + ". Must be: 0.0 <= trainingPercentage <= 1.0");
        Preconditions.checkArgument(evaluationPercentage >= 0.0
                && evaluationPercentage <= 1.0,
                "Invalid evaluationPercentage: " + evaluationPercentage
                        + ". Must be: 0.0 <= evaluationPercentage <= 1.0");

        log.info("Beginning evaluation using {} of {}", trainingPercentage,
                dataModel);

        final int numUsers = dataModel.getNumUsers();
        final FastByIDMap trainingPrefs = new FastByIDMap(
                1 + (int) (evaluationPercentage * numUsers));
        final FastByIDMap testPrefs = new FastByIDMap(
                1 + (int) (evaluationPercentage * numUsers));

        final LongPrimitiveIterator it = dataModel.getUserIDs();
        while (it.hasNext()) {
            final long userID = it.nextLong();
            if (random.nextDouble() < evaluationPercentage) {
                splitOneUsersPrefs(trainingPercentage, trainingPrefs,
                        testPrefs, userID, dataModel);
            }
        }

        final DataModel trainingModel = new GenericDataModel(trainingPrefs);

        final Recommender recommender = recommenderBuilder
                .buildRecommender(trainingModel);

        final Evaluation result = getEvaluation(testPrefs, recommender,
                config.getMarginForError());
        result.setTraining(trainingPrefs.size());
        result.setTest(testPrefs.size());
        log.info("Evaluation result: {}", result);

        if (resultWriter != null) {
            IOUtils.closeQuietly(resultWriter);
        }

        return result;
    }

    private void splitOneUsersPrefs(final double trainingPercentage,
            final FastByIDMap trainingPrefs,
            final FastByIDMap testPrefs, final long userID,
            final DataModel dataModel) {
        List oneUserTrainingPrefs = null;
        List oneUserTestPrefs = null;
        final PreferenceArray prefs = dataModel.getPreferencesFromUser(userID);
        final int size = prefs.length();
        for (int i = 0; i < size; i++) {
            final Preference newPref = new GenericPreference(userID,
                    prefs.getItemID(i), prefs.getValue(i));
            if (random.nextDouble() < trainingPercentage) {
                if (oneUserTrainingPrefs == null) {
                    oneUserTrainingPrefs = Lists.newArrayListWithCapacity(3);
                }
                oneUserTrainingPrefs.add(newPref);
            } else {
                if (oneUserTestPrefs == null) {
                    oneUserTestPrefs = Lists.newArrayListWithCapacity(3);
                }
                oneUserTestPrefs.add(newPref);
            }
        }
        if (oneUserTrainingPrefs != null) {
            trainingPrefs.put(userID, new GenericUserPreferenceArray(
                    oneUserTrainingPrefs));
            if (oneUserTestPrefs != null) {
                testPrefs.put(userID, new GenericUserPreferenceArray(
                        oneUserTestPrefs));
            }
        }
    }

    protected Evaluation getEvaluation(
            final FastByIDMap testPrefs,
            final Recommender recommender, final float marginForError) {
        reset();
        final Collection> estimateCallables = Lists
                .newArrayList();
        for (final Map.Entry entry : testPrefs
                .entrySet()) {
            estimateCallables.add(new PreferenceEstimateCallable(recommender,
                    entry.getKey(), entry.getValue(), marginForError));
        }
        log.info("Beginning evaluation of {} users", estimateCallables.size());

        final int numProcessors = Runtime.getRuntime().availableProcessors();
        final ExecutorService executor = Executors
                .newFixedThreadPool(numProcessors);
        log.info("Starting timing of {} tasks in {} threads",
                estimateCallables.size(), numProcessors);
        EstimateStatsResult finalResult = null;
        try {
            final List> futures = executor
                    .invokeAll(estimateCallables);
            int count = 0;
            // Go look for exceptions here, really
            for (final Future future : futures) {
                final EstimateStatsResult result = future.get();
                if (Thread.currentThread().isInterrupted()) {
                    throw new TasteException("Interrupted evaluator.");
                }
                if (finalResult == null) {
                    finalResult = result;
                } else {
                    finalResult.merge(result);
                }
                if (count % 1000 == 0) {
                    final Runtime runtime = Runtime.getRuntime();
                    final long totalMemory = runtime.totalMemory();
                    final long memory = totalMemory - runtime.freeMemory();
                    log.info("Approximate memory used: {}MB / {}MB",
                            memory / 1000000L, totalMemory / 1000000L);
                }
                count++;
            }

        } catch (final InterruptedException ie) {
            throw new TasteException(ie);
        } catch (final ExecutionException ee) {
            throw new TasteException(ee.getCause());
        } finally {
            executor.shutdown();
            try {
                executor.awaitTermination(10, TimeUnit.SECONDS);
            } catch (final InterruptedException e) {
                throw new TasteException(e.getCause());
            }
        }

        final Evaluation evaluation = new Evaluation();
        evaluation.setScore(computeFinalEvaluation());
        if (finalResult != null) {
            evaluation.setTotalProcessingTime(finalResult
                    .getTotalProcessingTime());
            evaluation.setAverageProcessingTime(finalResult
                    .getAverageProcessingTime());
            evaluation.setMaxProcessingTime(finalResult.getMaxProcessingTime());
            evaluation.setSuccessful(finalResult.getSuccessful());
            evaluation.setFailure(finalResult.getFailure());
            evaluation.setNoEstimate(finalResult.getNoEstimate());
            evaluation.setTotalPreference(finalResult.getTotalPreference());
        }
        return evaluation;
    }

    protected float capEstimatedPreference(final float estimate) {
        if (estimate > maxPreference) {
            return maxPreference;
        }
        if (estimate < minPreference) {
            return minPreference;
        }
        return estimate;
    }

    protected abstract void reset();

    protected abstract void processOneEstimate(float estimatedPreference,
            Preference realPref);

    protected abstract double computeFinalEvaluation();

    protected static class EstimateStatsResult {

        private int noEstimate = 0;

        private int successful = 0;

        private int failure = 0;

        private int numOfTime = 0;

        private long totalTime = 0;

        private long maxTime = 0;

        public void incrementNoEstimate() {
            noEstimate++;
        }

        public void incrementSuccess() {
            successful++;
        }

        public void incrementFailure() {
            failure++;
        }

        public int getNoEstimate() {
            return noEstimate;
        }

        public int getSuccessful() {
            return successful;
        }

        public int getFailure() {
            return failure;
        }

        public int getTotalPreference() {
            return noEstimate + successful + failure;
        }

        public long getTotalProcessingTime() {
            return totalTime;
        }

        public long getAverageProcessingTime() {
            return totalTime / numOfTime;
        }

        public long getMaxProcessingTime() {
            return maxTime;
        }

        public void addDuration(final long time) {
            numOfTime++;
            totalTime += time;
            if (maxTime < time) {
                maxTime = time;
            }
        }

        public void merge(final EstimateStatsResult result) {
            noEstimate += result.noEstimate;
            successful += result.successful;
            failure += result.failure;
            numOfTime += result.numOfTime;
            totalTime += result.totalTime;
            maxTime += result.maxTime;
        }
    }

    protected class PreferenceEstimateCallable implements
            Callable {

        private final Recommender recommender;

        private final long testUserID;

        private final PreferenceArray prefs;

        private final float marginForError;

        public PreferenceEstimateCallable(final Recommender recommender,
                final long testUserID, final PreferenceArray prefs,
                final float marginForError) {
            this.recommender = recommender;
            this.testUserID = testUserID;
            this.prefs = prefs;
            this.marginForError = marginForError;
        }

        @Override
        public EstimateStatsResult call() {
            final EstimateStatsResult stats = new EstimateStatsResult();
            for (final Preference realPref : prefs) {
                if (interrupted) {
                    break;
                }
                float estimatedPreference = Float.NaN;
                final float actualPreference = realPref.getValue();
                final long start = System.currentTimeMillis();
                final long time;
                try {
                    estimatedPreference = recommender.estimatePreference(
                            testUserID, realPref.getItemID());
                } catch (final NoSuchUserException nsue) {
                    // It's possible that an item exists in the test data but not training data in which case
                    // NSEE will be thrown. Just ignore it and move on.
                    log.info(
                            "User exists in test data but not training data: {}",
                            testUserID);
                } catch (final NoSuchItemException nsie) {
                    log.info(
                            "Item exists in test data but not training data: {}",
                            realPref.getItemID());
                } finally {
                    time = System.currentTimeMillis() - start;
                    stats.addDuration(time);
                }

                String estimateResultType;
                if (Float.isNaN(estimatedPreference)) {
                    estimateResultType = "no_estimate";
                    stats.incrementNoEstimate();
                } else {
                    estimatedPreference = capEstimatedPreference(estimatedPreference);
                    processOneEstimate(estimatedPreference, realPref);
                    if (Math.abs(estimatedPreference - realPref.getValue()) < marginForError) {
                        estimateResultType = "success";
                        stats.incrementSuccess();
                    } else {
                        estimateResultType = "failure";
                        stats.incrementFailure();
                    }
                }
                if (resultWriter != null) {
                    resultWriter.write(id, testUserID, realPref.getItemID(),
                            estimateResultType, actualPreference,
                            estimatedPreference, time);
                }
            }

            return stats;
        }

    }

    @Override
    public void interrupt() {
        interrupted = true;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy