Downloading a project requires considerable resources. Please understand that we need to cover our server costs. Thank you in advance. Project price: only $1.
You can buy this project and download/modify it as often as you want.
Conformal AI package, including all data I/O, transformations, machine learning models and predictor classes, but without any chemistry-dependent code.
/*
* Copyright (C) Aros Bio AB.
*
* CPSign is an Open Source Software that is dual licensed to allow you to choose a license that best suits your requirements:
*
* 1) GPLv3 (GNU General Public License Version 3) with Additional Terms, including an attribution clause as well as a limitation to use the software for commercial purposes.
*
* 2) CPSign Proprietary License that allows you to use CPSign for commercial activities, such as in a revenue-generating operation or environment, or integrate CPSign in your proprietary software without worrying about disclosing the source code of your proprietary software, which is required if you choose to use the software under GPLv3 license. See arosbio.com/cpsign/commercial-license for details.
*/
package com.arosbio.ml.gridsearch;
import java.io.IOException;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.arosbio.commons.CollectionUtils;
import com.arosbio.commons.LazyListsPermutationIterator;
import com.arosbio.commons.MathUtils;
import com.arosbio.commons.Stopwatch;
import com.arosbio.commons.StringUtils;
import com.arosbio.commons.config.Configurable;
import com.arosbio.commons.config.Configurable.ConfigParameter;
import com.arosbio.commons.logging.LoggerUtils;
import com.arosbio.commons.mixins.ResourceAllocator;
import com.arosbio.data.DataUtils;
import com.arosbio.data.DataUtils.DataType;
import com.arosbio.data.Dataset;
import com.arosbio.data.MissingDataException;
import com.arosbio.io.CollectionsWriter;
import com.arosbio.io.DebugWriter;
import com.arosbio.io.IOUtils;
import com.arosbio.ml.algorithms.Classifier;
import com.arosbio.ml.algorithms.MLAlgorithm;
import com.arosbio.ml.algorithms.Regressor;
import com.arosbio.ml.cp.ConformalPredictor;
import com.arosbio.ml.gridsearch.utils.GSResComparator;
import com.arosbio.ml.interfaces.Predictor;
import com.arosbio.ml.metrics.Metric;
import com.arosbio.ml.metrics.MetricAggregation;
import com.arosbio.ml.metrics.MetricFactory;
import com.arosbio.ml.metrics.SingleValuedMetric;
import com.arosbio.ml.metrics.cp.CPAccuracy;
import com.arosbio.ml.metrics.cp.ConfidenceDependentMetric;
import com.arosbio.ml.testing.KFoldCV;
import com.arosbio.ml.testing.TestRunner;
import com.arosbio.ml.testing.TestingStrategy;
import com.arosbio.ml.testing.utils.EvaluationUtils;
import com.arosbio.ml.vap.VennABERSPredictor;
/**
* GridSearch takes a Dataset and the chosen
* {@link com.arosbio.ml.interfaces.Predictor Predictor}
* and performs an exhaustive search for the
* {@link com.arosbio.commons.config.Configurable.ConfigParameter
* ConfigParameter}s
* (available parameters depend on the predictor/NCM/scoring algorithm that is
* used). For instance, some NCMs in ICPRegressor use an error model whose
* parameters can be set independently of the scoring model. Thus the number
* of possible combinations can be very large!
*
* The {@link com.arosbio.ml.interfaces.Predictor Predictor} has a default
* {@link com.arosbio.ml.metrics.SingleValuedMetric SingleValuedMetric}
* that is used if not specified.
* Note that only a single confidence can be used for the evaluation, so it
* might be good to consider metrics that
* do not depend on the confidence.
*
* Another important parameter of {@link GridSearch} is the
* {@code tolerance} (Conformal Prediction only). The tolerance controls
* how far the accuracy is allowed to fall below
* the specified {@code confidence}. Conformal Predictors
* are meant to give the correct results in e.g. 70 % of the predictions if
* confidence is set to 0.7, but that is only guaranteed on average for a
* large test set, meaning that the observed accuracy may sometimes be lower or higher
* than the set confidence level. The tolerance will allow for a discrepancy
* in the accuracy of the Conformal Predictor (estimated by the
* {@link com.arosbio.ml.testing.TestingStrategy TestingStrategy}).
* The default {@code tolerance} is set to 0.05, e.g.
* if confidence is set to 0.7, the Grid Search will accept results with
* accuracy of 0.7 - 0.05 = 0.65 at worst. For Venn-ABERS predictors the
* accuracy is not checked so the {@code tolerance} parameter is not used.
* Note that {@code tolerance} can only be in the range [0..1].
*
*
* @author Aros Bio AB
* @author Staffan Arvidsson McShane
*
*/
public class GridSearch {
private static final Logger LOGGER = LoggerFactory.getLogger(GridSearch.class);
/** Smallest accepted value of the {@code tolerance} parameter (inclusive). */
public static final double MIN_ALLOWED_TOLERANCE = 0.0;
/** Largest accepted value of the {@code tolerance} parameter (inclusive). */
public static final double MAX_ALLOWED_TOLERANCE = 1.0;
/**
 * Status of the evaluation of a grid point (one parameter combination), as stored
 * in a {@code GSResult}. The {@link #toString()} of each constant is a short,
 * human-readable label.
 */
public enum EvalStatus {
    IN_PROGRESS("in progress"), VALID("valid"), NOT_VALID("not valid"), FAILED("failed");

    /** Human-readable label; also read directly by {@code GSResult#toString()}. */
    private final String textRep;

    private EvalStatus(String textRep) {
        this.textRep = textRep;
    }

    /**
     * Get the human-readable label of this status.
     *
     * @return the label, e.g. {@code "not valid"}
     */
    @Override
    public String toString() {
        return textRep;
    }
}
/**
* This is a callback interface that you can optionally register an instance of,
* in order to get information about the currently running grid search. This can be used
* in order to e.g. stop execution early to revise any parameters.
*/
public static interface ProgressCallback {
/**
* Get information about the currently running grid search (will be called after every param-combo finished)
* @param info {@link ProgressInfo} with the current state
*/
public void updatedInfo(ProgressInfo info);
}
/**
* This is a callback interface, similar to {@link ProgressCallback} but has the option to return an action
* given the current state - i.e. to stop execution (or continue) so that the JVM doesn't need to be stoped,
* and e.g. the parameter grid or other things can be altered programatically.
*/
public static interface ProgressMonitor {
public static enum Action {
CONTINUE, EXIT;
}
/**
* Allows the {@link ProgressMonitor} to act on the current state of execution,
* by either specifying that the grid search should continue or to exit it and
* return the currently best parameters and any information
* @param info the current information
* @return an {@link Action} for the grid search
*/
public Action actOnInfo(ProgressInfo info);
}
/**
 * Immutable snapshot of the progress of a running {@link GridSearch}.
 * Instances can only be created by the grid search itself (private constructor).
 * They are handed to registered {@link ProgressCallback} and {@link ProgressMonitor}
 * instances.
 */
public static class ProgressInfo {

    private final int numTotalGridPoints;
    private final int numProcessedGridPoints;
    // elapsed time in milliseconds
    private final long runtime;
    private final double currentBestScore;

    private ProgressInfo(int total, int processed, long elapsedMS, double best) {
        numTotalGridPoints = total;
        numProcessedGridPoints = processed;
        runtime = elapsedMS;
        currentBestScore = best;
    }

    /**
     * Get the total number of grid points (combinations of hyperparameters) in the search.
     *
     * @return the total number of grid points
     */
    public int getTotalNumGridPoints() {
        return this.numTotalGridPoints;
    }

    /**
     * Get how many grid points (combinations of hyperparameters) have been tested so far.
     *
     * @return the number of processed grid points
     */
    public int getNumProcessedGridPoints() {
        return this.numProcessedGridPoints;
    }

    /**
     * Get the time that has elapsed since the grid search started.
     *
     * @return elapsed time in milliseconds
     */
    public long getElapsedTimeMS() {
        return this.runtime;
    }

    /**
     * Get the current best score, as reported by the running grid search.
     *
     * @return the current best score
     */
    public double currentBestScore() {
        return this.currentBestScore;
    }
}
/**
 * The outcome of evaluating one grid point, i.e. one combination of parameters.
 * Instances are created through {@link Builder#success} or {@link Builder#failed}.
 */
public static class GSResult {
    private final Map parameters;
    private final double result;
    private final SingleValuedMetric optimizationType;
    private final long runtimeMS;
    private final List secondaryMetrics;
    private final EvalStatus status;
    private final String errorMessage;

    private GSResult(Builder b){
        this.parameters = Objects.requireNonNull(b.parameters, "parameters");
        this.result = b.result;
        this.optimizationType = Objects.requireNonNull(b.optimizationType, "optimization metric");
        this.runtimeMS = b.runtimeMS;
        this.secondaryMetrics = b.secondaryMetrics;
        // toString() dereferences the status when it is not VALID, so reject a missing one up front
        this.status = Objects.requireNonNull(b.status, "status");
        this.errorMessage = b.errorMessage;
    }

    /**
     * Builder for {@link GSResult}. Start from {@link #success} or {@link #failed},
     * optionally add secondary metrics, then call {@link #build()}.
     */
    static class Builder {
        private Map parameters;
        private double result;
        private SingleValuedMetric optimizationType;
        private long runtimeMS;
        private List secondaryMetrics;
        private EvalStatus status;
        private String errorMessage;

        /**
         * Start building the result of a successfully evaluated grid point,
         * with status {@link EvalStatus#VALID}.
         *
         * @param params the evaluated parameter combination, must not be {@code null}
         * @param optimizationResult the score of the optimization metric
         * @param type the optimization metric, must not be {@code null}
         * @param runtime runtime in milliseconds
         * @return a new Builder
         */
        public static Builder success(Map params,
            double optimizationResult,
            SingleValuedMetric type,
            long runtime) {
            Builder b = new Builder();
            b.parameters = params;
            b.result = optimizationResult;
            b.optimizationType = type;
            b.runtimeMS = runtime;
            b.status = EvalStatus.VALID;
            return b;
        }

        /**
         * Start building the result of a grid point that did not produce a valid score.
         * No score or runtime is recorded; both stay at {@code 0}.
         *
         * @param params the evaluated parameter combination, must not be {@code null}
         * @param optMetric the optimization metric, must not be {@code null}
         * @param status the status of the evaluation, must not be {@code null}
         * @param error an error message, may be {@code null}
         * @return a new Builder
         */
        public static Builder failed(Map params,
            SingleValuedMetric optMetric,
            EvalStatus status,
            String error) {
            Builder b = new Builder();
            b.parameters = params;
            b.optimizationType = optMetric;
            b.status = status;
            b.errorMessage = error;
            return b;
        }

        /**
         * Attach secondary metrics to the result.
         *
         * @param metrics the secondary metrics (stored as given, not copied)
         * @return the same Builder
         */
        public Builder secondary(List metrics){
            this.secondaryMetrics = metrics;
            return this;
        }

        /**
         * Create the result.
         *
         * @return a new {@link GSResult}
         * @throws NullPointerException if the parameters, optimization metric or status is {@code null}
         */
        public GSResult build(){
            return new GSResult(this);
        }
    }

    /**
     * Get the status of the evaluation.
     *
     * @return the status, never {@code null}
     */
    public EvalStatus getStatus() {
        return status;
    }

    /**
     * Get the metric used as the optimization target.
     *
     * @return the optimization metric
     */
    public SingleValuedMetric getOptimizationMetric() {
        return optimizationType;
    }

    /**
     * Get the secondary metrics of this result.
     *
     * @return the secondary metrics, or {@code null} if none were attached
     */
    public List getSecondaryMetrics() {
        return secondaryMetrics;
    }

    /**
     * Get the parameter combination that was evaluated.
     *
     * @return the parameters
     */
    public Map getParams() {
        return parameters;
    }

    /**
     * Get the score of the optimization metric. The score is only set for results
     * built with {@link Builder#success}; otherwise it is {@code 0}.
     *
     * @return the score
     */
    public double getResult() {
        return result;
    }

    /**
     * Runtime in milliseconds. Only set for results built with {@link Builder#success};
     * otherwise it is {@code 0}.
     *
     * @return the runtime
     */
    public long getRuntime() {
        return runtimeMS;
    }

    @Override
    public String toString() {
        // The score is shown only for VALID results; otherwise the status label is shown
        return String.format("GSResult using metric %s: %s, runtime: %sms, params: %s",
            optimizationType.getName(),
            (status == EvalStatus.VALID ? result : status.textRep),
            runtimeMS,
            parameters);
    }

    /**
     * Returns the error message (if any) that was encountered during the run.
     *
     * @return the error message, or an empty String if none was recorded
     */
    public String getErrorMessage() {
        return errorMessage != null ? errorMessage : "";
    }
}
// Optional writer for custom output of results (Builder#loggingWriter); may be null
private final Writer customResultsWriter;

// Evaluation / testing settings
private final TestingStrategy testStrategy;
private final SingleValuedMetric explicitMetric;
private final List secondaryMetrics;
private final boolean calcMeanAndSD;
private final double confidence;
private final double tolerance;
private final int maxNumGSresults;

// Progress listeners registered through the Builder; null when not registered
private final ProgressCallback callback;
private final ProgressMonitor monitor;
/**
 * Create a new instance from the settings collected by a {@link Builder}.
 *
 * @param builder the builder holding all settings
 * @throws IllegalArgumentException if the builder has no testing strategy, or if its
 *         tolerance is NaN or outside the range
 *         [{@link #MIN_ALLOWED_TOLERANCE}..{@link #MAX_ALLOWED_TOLERANCE}]
 */
private GridSearch(Builder builder) {
    if (builder.testStrategy == null)
        throw new IllegalArgumentException("Must specify a testing strategy");
    // Enforce the tolerance invariant here as well; the negated check also rejects NaN
    if (!(builder.tolerance >= MIN_ALLOWED_TOLERANCE && builder.tolerance <= MAX_ALLOWED_TOLERANCE))
        throw new IllegalArgumentException(String.format("Parameter tolerance must be in range [%s..%s]",
                MIN_ALLOWED_TOLERANCE, MAX_ALLOWED_TOLERANCE));

    customResultsWriter = builder.customWriter;
    testStrategy = builder.testStrategy;
    calcMeanAndSD = builder.computeMeanAndSD;
    explicitMetric = builder.optMetric;
    // Copy the list: the Builder is mutable and exposes its list, so sharing it would
    // let later changes to the builder leak into this instance
    if (builder.secondaryMetrics != null)
        secondaryMetrics = new ArrayList<>(builder.secondaryMetrics);
    else
        secondaryMetrics = null;
    confidence = builder.confidence;
    tolerance = builder.tolerance;
    maxNumGSresults = builder.maxNumGSresults;
    monitor = builder.monitor;
    callback = builder.callback;
}
/**
 * A mutable builder for {@link GridSearch}. Every setter returns the same builder
 * instance (fluent API), so calls can be chained. Unless another strategy is set,
 * {@link KFoldCV} (10-fold CV) is used as the {@link TestingStrategy}.
 */
public static class Builder {
    private TestingStrategy testStrategy = new KFoldCV();
    private boolean computeMeanAndSD = true;
    private SingleValuedMetric optMetric;
    private List secondaryMetrics;
    private Writer customWriter;
    private double confidence = ConfidenceDependentMetric.DEFAULT_CONFIDENCE;
    private double tolerance = 0.05;
    private int maxNumGSresults = 10;
    private ProgressCallback callback;
    private ProgressMonitor monitor;

    /**
     * Set the testing strategy used for evaluating the parameter combinations.
     *
     * @param strategy the testing strategy; {@link #build()} fails if this is {@code null}
     * @return the same Builder object
     */
    public Builder testStrategy(TestingStrategy strategy) {
        this.testStrategy = strategy;
        return this;
    }

    /**
     * Get the currently set testing strategy.
     *
     * @return the testing strategy
     */
    public TestingStrategy testStrategy() {
        return testStrategy;
    }

    /**
     * Set the metric that should be used for determining the best model.
     *
     * @param metric the metric
     * @return the same Builder object
     */
    public Builder optimizationMetric(SingleValuedMetric metric) {
        this.optMetric = metric;
        return this;
    }

    /**
     * Set the metric that should be used for determining the best model
     * (alias of {@link #optimizationMetric(SingleValuedMetric)}).
     *
     * @param metric the metric
     * @return the same Builder object
     */
    public Builder optMetric(SingleValuedMetric metric) {
        this.optMetric = metric;
        return this;
    }

    /**
     * Set the metric that should be used for determining the best model
     * (alias of {@link #optimizationMetric(SingleValuedMetric)}).
     *
     * @param metric the metric
     * @return the same Builder object
     */
    public Builder evaluationMetric(SingleValuedMetric metric) {
        this.optMetric = metric;
        return this;
    }

    /**
     * Allows the parameters to be evaluated with additional metrics, besides the
     * metric used for picking the best parameters. The given list is copied.
     *
     * @param metrics a list of metrics, must not be {@code null}
     * @return the same Builder object
     */
    public Builder secondaryMetrics(List metrics) {
        this.secondaryMetrics = new ArrayList<>(metrics);
        return this;
    }

    /**
     * Get the currently set secondary metrics.
     *
     * @return the secondary metrics, or {@code null} if none are set
     */
    public List secondaryMetrics() {
        return this.secondaryMetrics;
    }

    /**
     * Allows the parameters to be evaluated with additional metrics, besides the
     * metric used for picking the best parameters.
     *
     * @param metrics the metrics
     * @return the same Builder object
     */
    public Builder secondaryMetrics(SingleValuedMetric... metrics) {
        this.secondaryMetrics = new ArrayList<>(Arrays.asList(metrics));
        return this;
    }

    /**
     * Set the desired confidence of the internal cross-validation (not always
     * applicable).
     *
     * @param confidence the confidence, should be in range [0..1]
     * @return the same Builder object
     * @throws IllegalArgumentException If the confidence is outside [0..1] or NaN
     */
    public Builder confidence(double confidence) {
        // Negated range check so that NaN is rejected as well
        if (!(confidence >= 0 && confidence <= 1))
            throw new IllegalArgumentException("Confidence must be within the range [0..1]");
        this.confidence = confidence;
        return this;
    }

    /**
     * Set the tolerance for the validity of the model.
     *
     * @param tol allowed tolerance for validity of the model, in range
     *            [{@link GridSearch#MIN_ALLOWED_TOLERANCE}..{@link GridSearch#MAX_ALLOWED_TOLERANCE}]
     * @return the same Builder object
     * @throws IllegalArgumentException If {@code tol} is outside the allowed range or NaN
     */
    public Builder tolerance(double tol) {
        // Validate the argument, not the currently stored value; the negated form also rejects NaN
        if (!(tol >= MIN_ALLOWED_TOLERANCE && tol <= MAX_ALLOWED_TOLERANCE))
            throw new IllegalArgumentException(String.format("Parameter tolerance must be in range [%s..%s]",
                    MIN_ALLOWED_TOLERANCE, MAX_ALLOWED_TOLERANCE));
        this.tolerance = tol;
        return this;
    }

    /**
     * Set the maximum number of grid search results (default: 10).
     *
     * @param max the max number of results
     * @return the same Builder object
     */
    public Builder maxNumResults(int max) {
        this.maxNumGSresults = max;
        return this;
    }

    /**
     * Set an optional {@link Writer} used for custom output of the results.
     *
     * @param output the writer, or {@code null} for none
     * @return the same Builder object
     */
    public Builder loggingWriter(Writer output) {
        this.customWriter = output;
        return this;
    }

    /**
     * Toggle computation of mean and standard deviation (default: {@code true}).
     *
     * @param compute {@code true} to compute mean and SD
     * @return the same Builder object
     */
    public Builder computeMeanSD(boolean compute) {
        this.computeMeanAndSD = compute;
        return this;
    }

    /**
     * Register a {@link ProgressCallback}, replacing any previously registered callback.
     *
     * @param callback the callback
     * @return the same Builder object
     */
    public Builder register(ProgressCallback callback){
        this.callback = callback;
        return this;
    }

    /**
     * Register a {@link ProgressMonitor}, replacing any previously registered monitor.
     *
     * @param monitor the monitor
     * @return the same Builder object
     */
    public Builder register(ProgressMonitor monitor){
        this.monitor = monitor;
        return this;
    }

    /**
     * Create a {@link GridSearch} from the current settings.
     *
     * @return a new GridSearch
     * @throws IllegalArgumentException if no testing strategy is set
     */
    public GridSearch build() {
        return new GridSearch(this);
    }
}
/**
 * Get the metric used for picking the best parameters (alias of
 * {@link #getOptimizationMetric()}).
 *
 * @return the metric set in the {@link Builder}, or {@code null} if none was set
 */
public SingleValuedMetric getEvaluationMetric() {
    return explicitMetric;
}

/**
 * Get the metric used for picking the best parameters.
 *
 * @return the metric set in the {@link Builder}, or {@code null} if none was set
 */
public SingleValuedMetric getOptimizationMetric() {
    return explicitMetric;
}

/**
 * Get the additional metrics that the parameters are evaluated with, besides the
 * optimization metric.
 *
 * @return a read-only view of the secondary metrics, or {@code null} if none were set
 */
public List getSecondaryMetrics() {
    // Return a read-only view so callers cannot modify the settings of this instance
    return secondaryMetrics == null ? null : Collections.unmodifiableList(secondaryMetrics);
}

/**
 * Get the testing strategy used for evaluating the parameter combinations.
 *
 * @return the testing strategy, never {@code null}
 */
public TestingStrategy getTestingStrategy() {
    return this.testStrategy;
}
/**
 * Confidence level used in the internal cross validation, as set through
 * {@link Builder#confidence(double)}.
 *
 * @return the confidence
 */
public double getConfidence() {
    return this.confidence;
}

/**
 * Tolerance accepted for the validity of the model, as set through
 * {@link Builder#tolerance(double)}.
 *
 * @return the tolerance
 */
public double getTolerance() {
    return this.tolerance;
}
private static void verifyGridParameters(Map> grid, Configurable predictor) {
List paramList = predictor.getConfigParameters();
List allowedParamNames = new ArrayList<>();
for (ConfigParameter p : paramList) {
allowedParamNames.addAll(p.getNames());
}
LOGGER.debug("All possible parameters: {}, given parameters: {}", allowedParamNames, grid.keySet());
Set nonOkParams = new HashSet<>();
for (String givenParamName : grid.keySet()) {
if (!CollectionUtils.containsIgnoreCase(allowedParamNames, givenParamName)) {
nonOkParams.add(givenParamName);
}
}
if (!nonOkParams.isEmpty()) {
LOGGER.debug("Found extra parameters that is not valid: {}", nonOkParams);
throw new IllegalArgumentException("Following parameters are not used/recognized: " + nonOkParams);
}
}
// Warning: the best parameters lie on the edge of the grid (affected parameter names are appended)
private static final String WARNING_MESSAGE =
    "WARNING: Optimal parameters found at border of the grid, true optimal parameters might be outside the search grid. Parameters affected: ";
// Warning: the search was stopped before all grid points were evaluated
private static final String WARNING_EXECUTION_STOPPED =
    "WARNING: execution was manually terminated, all given parameters may not have been tested";
// Warning: none of the evaluated grid points produced a valid model
private static final String WARNING_NO_VALID_RESULTS =
    "WARNING: no parameter combinations produced valid models";
private String getWarning(Map optimalParams, Map> grid, boolean stoppedEarly) {
if (stoppedEarly)
return WARNING_EXECUTION_STOPPED;
StringBuilder warningBuilder = new StringBuilder();
for (String p : optimalParams.keySet()) {
// Cannot be on the boarder unless 3 points given
if (grid.get(p).size() < 3)
continue;
// If not numeric value - skip (how to do it for enum/sampling etc?
if (!(optimalParams.get(p) instanceof Number))
continue;
Pair, List