All downloads are free. Search and download functionality uses the official Maven repository.

com.feedzai.fos.impl.weka.WekaScorer Maven / Gradle / Ivy

There is a newer version: 1.0.10
Show newest version
/*
 * $#
 * FOS Weka
 *  
 * Copyright (C) 2013 Feedzai SA
 *  
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as
 * published by the Free Software Foundation, either version 3 of the 
 * License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public 
 * License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * #$
 */
package com.feedzai.fos.impl.weka;

import com.feedzai.fos.api.FOSException;
import com.feedzai.fos.api.Scorer;
import com.feedzai.fos.common.validation.NotNull;
import com.feedzai.fos.impl.weka.config.WekaManagerConfig;
import com.feedzai.fos.impl.weka.config.WekaModelConfig;
import com.feedzai.fos.impl.weka.utils.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import static com.google.common.base.Preconditions.checkNotNull;

/**
 * Implementation of the classification-api that supports multiple simultaneous classifiers (thread safe!).
 *
 * @author Marco Jorge ([email protected])
 */
public class WekaScorer implements Scorer {
    private static final Logger logger = LoggerFactory.getLogger(WekaScorer.class);

    private Map wekaThreadSafeScorers = new HashMap<>();
    private ExecutorService executorService;
    private ReentrantReadWriteLock reloadModelsLock = new ReentrantReadWriteLock(true /* fair */);
    private WekaManagerConfig wekaManagerConfig;

    private WekaThreadSafeScorer getScorer(UUID modelId) throws FOSException {
        WekaThreadSafeScorer wekaThreadSafeScorer = wekaThreadSafeScorers.get(modelId);
        if (wekaThreadSafeScorer == null) {
            logger.error("No model with ID '{}'", modelId);
            throw new FOSException("No model with ID " + modelId);
        }
        return wekaThreadSafeScorer;
    }

    private  T getFuture(Future future, UUID modelId) throws FOSException {
        try {
            return future.get();
        } catch (InterruptedException e) {
            logger.error("Could not score on model '{}'", modelId, e);
            throw new FOSException(e);
        } catch (ExecutionException e) {
            logger.error("Could not score on model '{}'", modelId, e);
            throw new FOSException(e);
        }
    }

    /**
     * Creates a new scorer for the models identified int he configuration.
     * 

If loading of a model was not possible, this logs a message but continues to load other models and does not throw any exception. * * @param modelConfigs the list of models to instantiate * @param wekaManagerConfig the global configuration */ public WekaScorer(Map modelConfigs, WekaManagerConfig wekaManagerConfig) { checkNotNull(modelConfigs, "Model configuration map cannot be null"); checkNotNull(wekaManagerConfig, "Manager config cannot be null"); this.wekaManagerConfig = wekaManagerConfig; for (Map.Entry wekaModelConfigEntry : modelConfigs.entrySet()) { try { if (wekaModelConfigEntry.getValue().isClassifierThreadSafe()) { wekaThreadSafeScorers.put(wekaModelConfigEntry.getValue().getId(), new WekaThreadSafeScorerPassthrough(wekaModelConfigEntry.getValue(), wekaManagerConfig)); } else { wekaThreadSafeScorers.put(wekaModelConfigEntry.getValue().getId(), new WekaThreadSafeScorerPool(wekaModelConfigEntry.getValue(), wekaManagerConfig)); } } catch (Exception e) { logger.error("Could not load from '{}' (continuing to load others)", wekaModelConfigEntry.getKey(), e); } } this.executorService = Executors.newFixedThreadPool(wekaManagerConfig.getThreadPoolSize()); } @Override public void close() { try { reloadModelsLock.writeLock().lock(); executorService.shutdown(); for (WekaThreadSafeScorer wekaThreadSafeScorer : this.wekaThreadSafeScorers.values()) { if (wekaThreadSafeScorer != null) { wekaThreadSafeScorer.close(); } } } finally { reloadModelsLock.writeLock().unlock(); } } /** * Score the scorable for each model ID identified by modelIds. *

If multiple models are given as parameters, they will be scored in parallel. * * @param modelIds the list of models to score * @param scorable the item to score * @return a List of scores with the same order and the received modelIds * @throws FOSException when classification was not possible */ @Override @NotNull public List score(List modelIds, Object[] scorable) throws FOSException { checkNotNull(modelIds, "Models to score cannot be null"); checkNotNull(scorable, "Instance cannot be null"); List scores = new ArrayList<>(modelIds.size()); try { reloadModelsLock.readLock().lock(); if (modelIds.size() == 1) { // if only one model, then don't parallelize scoring WekaThreadSafeScorer wekaThreadSafeScorer = getScorer(modelIds.get(0)); scores.add(wekaThreadSafeScorer.score(scorable)); } else { Map> futureScores = new HashMap<>(modelIds.size()); // scatter for (UUID modelId : modelIds) { WekaThreadSafeScorer wekaThreadSafeScorer = getScorer(modelId); futureScores.put(modelId, executorService.submit(new AsyncScoringTask(wekaThreadSafeScorer, scorable))); } // gather for (UUID modelId : modelIds) { scores.add(getFuture(futureScores.get(modelId), modelId)); } } } finally { reloadModelsLock.readLock().unlock(); } return scores; } /** * Score each scorable with the modelId that maps it in the map. *

If multiple models/scorables are given as parameters, they will be scored in parallel. * * @param modelIdsToScorables a map from modelId to scorable * @return a map from modelId to score * @throws FOSException when classification was not possible */ @Override @NotNull public Map score(Map modelIdsToScorables) throws FOSException { checkNotNull(modelIdsToScorables, "Map of instances cannot be null"); Map scores = new HashMap<>(modelIdsToScorables.size()); try { reloadModelsLock.readLock().lock(); if (modelIdsToScorables.size() == 1) { // if only one model, then don't parallelize scoring for (Map.Entry entry : modelIdsToScorables.entrySet()) { // easy way to access... WekaThreadSafeScorer wekaThreadSafeScorer = getScorer(entry.getKey()); scores.put(entry.getKey(), wekaThreadSafeScorer.score(entry.getValue())); } } else { Map> futureScores = new HashMap<>(modelIdsToScorables.size()); // scatter for (Map.Entry entry : modelIdsToScorables.entrySet()) { WekaThreadSafeScorer wekaThreadSafeScorer = getScorer(entry.getKey()); futureScores.put(entry.getKey(), executorService.submit(new AsyncScoringTask(wekaThreadSafeScorer, entry.getValue()))); } // gather for (Map.Entry> entry : futureScores.entrySet()) { scores.put(entry.getKey(), getFuture(entry.getValue(), entry.getKey())); } } } finally { reloadModelsLock.readLock().unlock(); } return scores; } /** * Score each scorable with the given modelId. *

If multiple scorables are given as parameters, they will be scored in parallel. * * @param modelId the id of the model * @param scorables an array of instances to score * @return a list of scores with the same order as the scorable array * @throws FOSException when classification was not possible */ @Override @NotNull public List score(UUID modelId, List scorables) throws FOSException { checkNotNull(scorables, "List of scorables cannot be null"); List scores = new ArrayList<>(scorables.size()); try { reloadModelsLock.readLock().lock(); if (scorables.size() == 1) { // if only one model, then don't parallelize scoring WekaThreadSafeScorer wekaThreadSafeScorer = getScorer(modelId); scores.add(wekaThreadSafeScorer.score(scorables.get(0))); } else { Map> futureScores = new HashMap<>(scorables.size()); // scatter for (Object[] scorable : scorables) { WekaThreadSafeScorer wekaThreadSafeScorer = wekaThreadSafeScorers.get(modelId); futureScores.put(scorable, executorService.submit(new AsyncScoringTask(wekaThreadSafeScorer, scorable))); } // gather for (Object[] scorable : scorables) { scores.add(getFuture(futureScores.get(scorable), modelId)); } } } finally { reloadModelsLock.readLock().unlock(); } return scores; } /** * Adds the given model to the managed models. *

* If the provided model id already exists, then the older model is removed and the new one is instantiated. * * @param wekaModelConfig the configuration of the new model * @throws FOSException when the new model could not be instantiated */ public void addOrUpdate(WekaModelConfig wekaModelConfig) throws FOSException { checkNotNull(wekaModelConfig, "Model config cannot be null"); WekaThreadSafeScorer newWekaThreadSafeScorer = new WekaThreadSafeScorerPool(wekaModelConfig, wekaManagerConfig); WekaThreadSafeScorer oldWekaThreadSafeScorer = quickSwitch(wekaModelConfig.getId(), newWekaThreadSafeScorer); WekaUtils.closeSilently(oldWekaThreadSafeScorer); } /** * Removes the given model from the managed models. *

If the model does not exist no exception will be thrown. * * @param modelId the id of the model to remove. */ public void removeModel(UUID modelId) { WekaThreadSafeScorer newWekaThreadSafeScorer = null; WekaThreadSafeScorer oldWekaThreadSafeScorer = quickSwitch(modelId, newWekaThreadSafeScorer); WekaUtils.closeSilently(oldWekaThreadSafeScorer); } private WekaThreadSafeScorer quickSwitch(UUID modelId, WekaThreadSafeScorer newWekaThreadSafeScorer) { try { // quick switch - do not do anything inside for performance reasons!! reloadModelsLock.writeLock().lock(); return wekaThreadSafeScorers.put(modelId, newWekaThreadSafeScorer); } finally { reloadModelsLock.writeLock().unlock(); } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy