com.feedzai.fos.impl.weka.WekaScorer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of fos-impl-weka Show documentation
Show all versions of fos-impl-weka Show documentation
Feedzai Open Scoring Server - Weka Implementation
/*
* $#
* FOS Weka
*
* Copyright (C) 2013 Feedzai SA
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program. If not, see
* <http://www.gnu.org/licenses/>.
* #$
*/
package com.feedzai.fos.impl.weka;
import com.feedzai.fos.api.FOSException;
import com.feedzai.fos.api.Scorer;
import com.feedzai.fos.common.validation.NotNull;
import com.feedzai.fos.impl.weka.config.WekaManagerConfig;
import com.feedzai.fos.impl.weka.config.WekaModelConfig;
import com.feedzai.fos.impl.weka.utils.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import static com.google.common.base.Preconditions.checkNotNull;
/**
* Implementation of the classification-api that supports multiple simultaneous classifiers (thread safe!).
*
* @author Marco Jorge ([email protected])
*/
public class WekaScorer implements Scorer {
// Logger shared by all instances of this class.
private static final Logger logger = LoggerFactory.getLogger(WekaScorer.class);
// Per-model scorers, keyed by model id: filled in the constructor from each config's getId() and
// read by getScorer(UUID).
// NOTE(review): the generic type arguments of this declaration appear to be missing in this copy
// (the values are used as WekaThreadSafeScorer without a cast) - confirm against the original source.
private Map wekaThreadSafeScorers = new HashMap<>();
// Pool used by the score methods to score several models in parallel; created in the constructor
// with wekaManagerConfig.getThreadPoolSize() threads and shut down by close().
private ExecutorService executorService;
// Fair read/write lock: the score methods hold the read lock while scoring and close() holds the
// write lock. NOTE(review): presumably model reloading (not visible here) also takes the write lock.
private ReentrantReadWriteLock reloadModelsLock = new ReentrantReadWriteLock(true /* fair */);
// Global configuration supplied to the constructor.
private WekaManagerConfig wekaManagerConfig;
/**
 * Looks up the scorer registered for a model.
 *
 * @param modelId the id of the model whose scorer is wanted
 * @return the scorer registered under {@code modelId}; never {@code null}
 * @throws FOSException if no scorer is registered under {@code modelId}
 */
private WekaThreadSafeScorer getScorer(UUID modelId) throws FOSException {
    WekaThreadSafeScorer scorer = wekaThreadSafeScorers.get(modelId);
    if (scorer != null) {
        return scorer;
    }
    logger.error("No model with ID '{}'", modelId);
    throw new FOSException("No model with ID " + modelId);
}
/**
 * Waits for an asynchronous scoring task to finish and returns its result.
 *
 * @param future  the pending result of a task submitted to the executor
 * @param modelId the model the task is scoring; used only for logging
 * @param <T>     the type of the task's result
 * @return the value produced by the task
 * @throws FOSException if the task failed or the waiting thread was interrupted
 */
private <T> T getFuture(Future<T> future, UUID modelId) throws FOSException {
    try {
        return future.get();
    } catch (InterruptedException e) {
        // Restore the interrupt flag so code further up the stack can still observe the interruption.
        Thread.currentThread().interrupt();
        logger.error("Could not score on model '{}'", modelId, e);
        throw new FOSException(e);
    } catch (ExecutionException e) {
        logger.error("Could not score on model '{}'", modelId, e);
        throw new FOSException(e);
    }
}
/**
* Creates a new scorer for the models identified int he configuration.
* If loading of a model was not possible, this logs a message but continues to load other models and does not throw any exception.
*
* @param modelConfigs the list of models to instantiate
* @param wekaManagerConfig the global configuration
*/
public WekaScorer(Map modelConfigs, WekaManagerConfig wekaManagerConfig) {
checkNotNull(modelConfigs, "Model configuration map cannot be null");
checkNotNull(wekaManagerConfig, "Manager config cannot be null");
this.wekaManagerConfig = wekaManagerConfig;
for (Map.Entry wekaModelConfigEntry : modelConfigs.entrySet()) {
try {
if (wekaModelConfigEntry.getValue().isClassifierThreadSafe()) {
wekaThreadSafeScorers.put(wekaModelConfigEntry.getValue().getId(), new WekaThreadSafeScorerPassthrough(wekaModelConfigEntry.getValue(), wekaManagerConfig));
} else {
wekaThreadSafeScorers.put(wekaModelConfigEntry.getValue().getId(), new WekaThreadSafeScorerPool(wekaModelConfigEntry.getValue(), wekaManagerConfig));
}
} catch (Exception e) {
logger.error("Could not load from '{}' (continuing to load others)", wekaModelConfigEntry.getKey(), e);
}
}
this.executorService = Executors.newFixedThreadPool(wekaManagerConfig.getThreadPoolSize());
}
/**
 * Shuts down the scoring thread pool and closes every loaded model scorer.
 * <p>
 * The write lock is held for the whole operation, so this waits for any in-progress
 * {@code score} call (which holds the read lock) to finish first.
 * <p>
 * A runtime failure while closing one scorer does not stop the remaining scorers from being closed.
 * The first such failure is rethrown once all scorers have been attempted, with any later failures
 * attached as suppressed exceptions.
 */
@Override
public void close() {
    // Acquire before entering try so the finally never unlocks a lock this thread does not hold.
    reloadModelsLock.writeLock().lock();
    try {
        executorService.shutdown();
        RuntimeException failure = null;
        for (WekaThreadSafeScorer wekaThreadSafeScorer : this.wekaThreadSafeScorers.values()) {
            if (wekaThreadSafeScorer == null) {
                continue;
            }
            try {
                wekaThreadSafeScorer.close();
            } catch (RuntimeException e) {
                // keep going: one failing scorer must not leave the others open
                if (failure == null) {
                    failure = e;
                } else if (failure != e) {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    } finally {
        reloadModelsLock.writeLock().unlock();
    }
}
/**
 * Scores {@code scorable} with every model listed in {@code modelIds}.
 * <p>
 * A single model is scored on the calling thread. Otherwise the models are scored in parallel on
 * the executor. Every model id is resolved before any task is submitted, so an unknown id fails the
 * call without leaving scoring tasks running after the read lock has been released.
 *
 * @param modelIds the models to score with
 * @param scorable the instance to score
 * @return the scores, in the same order as {@code modelIds}
 * @throws FOSException if a model id is unknown or scoring fails
 * @throws NullPointerException if {@code modelIds} or {@code scorable} is {@code null}
 */
@Override
@NotNull
public List score(List modelIds, Object[] scorable) throws FOSException {
    checkNotNull(modelIds, "Models to score cannot be null");
    checkNotNull(scorable, "Instance cannot be null");
    List scores = new ArrayList<>(modelIds.size());
    // Acquire before entering try so the finally never unlocks a lock this thread does not hold.
    reloadModelsLock.readLock().lock();
    try {
        if (modelIds.size() == 1) {
            // if only one model, then don't parallelize scoring
            WekaThreadSafeScorer wekaThreadSafeScorer = getScorer(modelIds.get(0));
            scores.add(wekaThreadSafeScorer.score(scorable));
        } else {
            // resolve every model first: an unknown id must fail before any task is submitted
            List<WekaThreadSafeScorer> scorers = new ArrayList<>(modelIds.size());
            for (UUID modelId : modelIds) {
                scorers.add(getScorer(modelId));
            }
            // scatter
            List<Future<?>> futureScores = new ArrayList<>(scorers.size());
            for (WekaThreadSafeScorer wekaThreadSafeScorer : scorers) {
                futureScores.add(executorService.submit(new AsyncScoringTask(wekaThreadSafeScorer, scorable)));
            }
            // gather, preserving the order of modelIds
            int position = 0;
            for (UUID modelId : modelIds) {
                scores.add(getFuture(futureScores.get(position++), modelId));
            }
        }
    } finally {
        reloadModelsLock.readLock().unlock();
    }
    return scores;
}
/**
* Score each scorable
with the modelId that maps it in the map.
* If multiple models/scorables are given as parameters, they will be scored in parallel.
*
* @param modelIdsToScorables a map from modelId to scorable
* @return a map from modelId to score
* @throws FOSException when classification was not possible
*/
@Override
@NotNull
public Map score(Map modelIdsToScorables) throws FOSException {
checkNotNull(modelIdsToScorables, "Map of instances cannot be null");
Map scores = new HashMap<>(modelIdsToScorables.size());
try {
reloadModelsLock.readLock().lock();
if (modelIdsToScorables.size() == 1) {
// if only one model, then don't parallelize scoring
for (Map.Entry entry : modelIdsToScorables.entrySet()) { // easy way to access...
WekaThreadSafeScorer wekaThreadSafeScorer = getScorer(entry.getKey());
scores.put(entry.getKey(), wekaThreadSafeScorer.score(entry.getValue()));
}
} else {
Map> futureScores = new HashMap<>(modelIdsToScorables.size());
// scatter
for (Map.Entry entry : modelIdsToScorables.entrySet()) {
WekaThreadSafeScorer wekaThreadSafeScorer = getScorer(entry.getKey());
futureScores.put(entry.getKey(), executorService.submit(new AsyncScoringTask(wekaThreadSafeScorer, entry.getValue())));
}
// gather
for (Map.Entry> entry : futureScores.entrySet()) {
scores.put(entry.getKey(), getFuture(entry.getValue(), entry.getKey()));
}
}
} finally {
reloadModelsLock.readLock().unlock();
}
return scores;
}
/**
* Score each scorable
with the given modelId
.
* If multiple scorables
are given as parameters, they will be scored in parallel.
*
* @param modelId the id of the model
* @param scorables an array of instances to score
* @return a list of scores with the same order as the scorable array
* @throws FOSException when classification was not possible
*/
@Override
@NotNull
public List score(UUID modelId, List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy