All downloads are free. Search and download functionality uses the official Maven repository.

com.arosbio.ml.cp.nonconf.regression.SignedNormalizedNCM Maven / Gradle / Ivy

Go to download

Conformal AI package, including all data IO, transformations, machine learning models and predictor classes, but excluding chemistry-dependent code.

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright (C) Aros Bio AB.
 *
 * CPSign is an Open Source Software that is dual licensed to allow you to choose a license that best suits your requirements:
 *
 * 1) GPLv3 (GNU General Public License Version 3) with Additional Terms, including an attribution clause as well as a limitation to use the software for commercial purposes.
 *
 * 2) CPSign Proprietary License that allows you to use CPSign for commercial activities, such as in a revenue-generating operation or environment, or integrate CPSign in your proprietary software without worrying about disclosing the source code of your proprietary software, which is required if you choose to use the software under GPLv3 license. See arosbio.com/cpsign/commercial-license for details.
 */
package com.arosbio.ml.cp.nonconf.regression;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.InvalidKeyException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.arosbio.commons.CollectionUtils;
import com.arosbio.commons.FuzzyServiceLoader;
import com.arosbio.commons.TypeUtils;
import com.arosbio.commons.Version;
import com.arosbio.commons.config.NumericConfig;
import com.arosbio.data.DataRecord;
import com.arosbio.data.FeatureVector;
import com.arosbio.encryption.EncryptionSpecification;
import com.arosbio.io.DataSink;
import com.arosbio.io.DataSource;
import com.arosbio.ml.algorithms.MLAlgorithm;
import com.arosbio.ml.algorithms.Regressor;
import com.arosbio.ml.algorithms.impl.AlgorithmUtils;
import com.arosbio.ml.cp.nonconf.utils.NCMUtils;
import com.arosbio.ml.io.MetaFileUtils;
import com.arosbio.ml.io.impl.PropertyNameSettings;
import com.google.common.collect.Range;

/**
 * 

Regression

*

* Slightly different than the {@link NormalizedNCM}, differs only by * letting the error model train on the signed residual instead of taking the absolute value * of the residual as in the {@link NormalizedNCM}. * *

* Computes nonconformity as:
* nonconf = |y - y_hat|/μ_hat , μ_hat = predicted error by error model *

* Error model is trained by:
* μ = y - y_hat * * @author staffan */ public class SignedNormalizedNCM implements NCMRegression { private static final Version VERSION = new Version(1, 1, 0); private static final Logger LOGGER = LoggerFactory.getLogger(SignedNormalizedNCM.class); private final static String BETA_PROPERTY_KEY = "ncm_beta"; private final static String INVALID_BETA_ERR_MSG = "Beta value must be in range [0..+∞)"; public static final String IDENTIFIER="SignedNormalized"; public static final int ID = 4; public final static double DEFAULT_BETA_VALUE = 0.01; public final static String DESCRIPTION = "Nonconformity measure for regression. Residuals found for the calibration " + "set is used for training an error-model that will normalize the nonconformity scores for new objects. " + "Thus making 'easier' objects get smaller prediction intervals and more difficult to get larger ones. " + "The error-model is trained on the signed value of the residuals in contrast with the " + NormalizedNCM.IDENTIFIER + " NCM that " + "trains on the absolute value of the residuals. 
" + "The " + CONFIG_BETA_PARAM_NAMES.get(0) + " parameter 'smoothes' the importance of the normalization and a beta >0 " + "remove the possibility of generating infinitly large scores due to division by 0."; private double beta=DEFAULT_BETA_VALUE; private Regressor model; private Regressor errorModel; //////////////////////////////////// ///////// CONSTRUCTORS //////////////////////////////////// public SignedNormalizedNCM() {} public SignedNormalizedNCM(Regressor model) { this(model, model.clone()); } public SignedNormalizedNCM(Regressor model, Regressor errorModel) { Objects.requireNonNull(model, "Regressor model cannot be null"); this.model = model; if (errorModel == null){ LOGGER.debug("No explicit error model given, using the same settings as the target model"); this.errorModel = model.clone(); } else { this.errorModel = errorModel; } } public SignedNormalizedNCM(Regressor model, Regressor errorModel, double beta) { this(model,errorModel); this.setBeta(beta); } public SignedNormalizedNCM clone() { SignedNormalizedNCM clone = new SignedNormalizedNCM(); if (model != null) clone.model = model.clone(); if (errorModel != null) clone.errorModel = errorModel.clone(); clone.beta = beta; return clone; } //////////////////////////////////// ///////// GETTERS AND SETTERS //////////////////////////////////// public double getBeta() { return beta; } public void setBeta(double beta) { if (beta < 0) throw new IllegalArgumentException(INVALID_BETA_ERR_MSG); this.beta = beta; } @Override public boolean requiresErrorModel() { return true; } @Override public MLAlgorithm getErrorModel() { return errorModel; } @Override public void setErrorModel(MLAlgorithm model) throws IllegalArgumentException { if (model instanceof Regressor) this.errorModel = (Regressor) model; else throw new IllegalArgumentException("NCM error model must be of type Regressor"); } @Override public Regressor getModel() { return model; } @Override public void setModel(Regressor model) { this.model = model; } 
@Override public boolean isFitted() { return model!=null && model.isFitted() && errorModel!=null && errorModel.isFitted(); } @Override public Map getProperties() { Map props = new HashMap<>(model.getProperties()); props.putAll(NCMUtils.addErrorModelPrefix(errorModel.getProperties())); props.put(BETA_PROPERTY_KEY, beta); props.put(PropertyNameSettings.NCM_NAME, IDENTIFIER); props.put(PropertyNameSettings.NCM_ID, ID); props.put(PropertyNameSettings.NCM_VERSION, VERSION.toString()); props.put(PropertyNameSettings.NCM_SCORING_MODEL_ID, model.getID()); props.put(PropertyNameSettings.NCM_SCORING_MODEL_NAME, model.getName()); props.put(PropertyNameSettings.NCM_ERROR_MODEL_ID, errorModel.getID()); props.put(PropertyNameSettings.NCM_ERROR_MODEL_NAME, errorModel.getName()); return props; } @Override public String getName() { return IDENTIFIER; } @Override public int getID() { return ID; } public String toString() { return IDENTIFIER; } @Override public Version getVersion() { return VERSION; } @Override public String getDescription() { return DESCRIPTION; } @Override public List getConfigParameters() { List params = new ArrayList<>(); // Beta parameter params.add(new NumericConfig.Builder(CONFIG_BETA_PARAM_NAMES, DEFAULT_BETA_VALUE).range(Range.closed(0d, 10d)) .description("Smoothing term used for regulating the impact of the error model").build()); if (model==null || errorModel==null) return params; //Scoring params.addAll(model.getConfigParameters()); // Error params.addAll(NCMUtils.addErrorModelPrefix(errorModel.getConfigParameters())); return params; } @Override public void setConfigParameters(Map params) throws IllegalArgumentException { // Scoring params if (model!=null) model.setConfigParameters(params); // Error params if (errorModel!=null) errorModel.setConfigParameters(NCMUtils.stripPrefixAndGetErrorModelParams(params)); // BETA for (Map.Entry p : params.entrySet()) { if (CollectionUtils.containsIgnoreCase(CONFIG_BETA_PARAM_NAMES, p.getKey())) { try { 
setBeta(TypeUtils.asDouble(p.getValue())); } catch (Exception e) { throw new IllegalArgumentException("Invalid value for parameter "+ p.getKey() + ": value must be in range [0..+∞) but was '"+ p.getValue()+'\''); } } } } //////////////////////////////////// ///////// TRAIN / CALCULATE //////////////////////////////////// @Override public void trainNCM(List data) { model.train(data); LOGGER.debug("Trained scoring model"); List errorRecs = new ArrayList<>(); for (DataRecord rec : data) { double y_hat = model.predictValue(rec.getFeatures()); errorRecs.add(new DataRecord(rec.getLabel()-y_hat, rec.getFeatures())); } errorModel.train(errorRecs); LOGGER.debug("Trained error model"); } @Override public double calcNCS(DataRecord example) throws IllegalStateException { double y_hat = model.predictValue(example.getFeatures()); return Math.abs(example.getLabel()-y_hat)/ calcIntervalScaling(example.getFeatures()); // avoid division by 0 } @Override public double predictMidpoint(FeatureVector example) throws IllegalStateException { return model.predictValue(example); } @Override public double calcIntervalScaling(FeatureVector example) throws IllegalStateException { double e_hat = errorModel.predictValue(example); return Math.max(Math.abs(e_hat)+beta, Double.MIN_VALUE); } //////////////////////////////////// ///////// SAVE / LOAD //////////////////////////////////// @Override public void saveToDataSink(DataSink sink, String path, EncryptionSpecification encryptSpec) throws IOException, InvalidKeyException, IllegalStateException { String scoreSavePath = null, scoreParamsSavePath = null; if (model != null) { if (model.isFitted()) { // For TCP - not saving an actual model - but only the parameters scoreSavePath = path + NCMUtils.SCORING_MODEL_FILE; AlgorithmUtils.saveModelToStream(sink, scoreSavePath, model, encryptSpec); } // if (model.getParameters() != null) { // scoreParamsSavePath = path + NCMUtils.SCORING_MODEL_FILE + NCMUtils.MODEL_PARAMS_FILE_ENDING; // 
AlgorithmUtils.saveModelParamsToStream(sink, scoreParamsSavePath, model.getParameters()); // } } String errorSavePath = null, errorParamsSavePath = null; if (errorModel != null) { if (errorModel.isFitted()) { errorSavePath = path + NCMUtils.ERROR_MODEL_FILE; // For TCP - not saving an actual model - but only the parameters AlgorithmUtils.saveModelToStream(sink, errorSavePath, errorModel, encryptSpec); } // if (errorModel.getParameters() != null) { // errorParamsSavePath = path + NCMUtils.ERROR_MODEL_FILE + NCMUtils.MODEL_PARAMS_FILE_ENDING; // AlgorithmUtils.saveModelParamsToStream(sink, errorParamsSavePath, errorModel.getParameters()); // } } try (OutputStream metaStream = sink.getOutputStream(path+NCMUtils.NCM_PARAMS_META_FILE);){ Map props = getProperties(); props.put(PropertyNameSettings.NCM_SCORING_MODEL_LOCATION, scoreSavePath); props.put(PropertyNameSettings.NCM_SCORING_PARAMETERS_LOCATION, scoreParamsSavePath); props.put(PropertyNameSettings.NCM_ERROR_MODEL_LOCATION, errorSavePath); props.put(PropertyNameSettings.NCM_ERROR_PARAMETERS_LOCATION, errorParamsSavePath); props.put(PropertyNameSettings.IS_ENCRYPTED, (encryptSpec != null)); MetaFileUtils.writePropertiesToStream(metaStream, props); } LOGGER.debug("Saved NCM meta parameters, loc={}{}",path,NCMUtils.NCM_PARAMS_META_FILE); } @Override public void loadFromDataSource(DataSource src, String path, EncryptionSpecification encryptSpec) throws IOException, IllegalArgumentException, InvalidKeyException { if (! 
src.hasEntry(path+NCMUtils.NCM_PARAMS_META_FILE)) throw new IOException("No NCM saved at location=" + path); // Check that NCM matches NCMUtils.verifyNCMmatches(this, src, path+NCMUtils.NCM_PARAMS_META_FILE); Map params = null; try (InputStream istream = src.getInputStream(path + NCMUtils.NCM_PARAMS_META_FILE)){ params = MetaFileUtils.readPropertiesFromStream(istream); } if (!(boolean)params.get(PropertyNameSettings.IS_ENCRYPTED)) { encryptSpec = null; // set it to null } // Load beta value if (!params.containsKey(BETA_PROPERTY_KEY)) { throw new IOException("Could not retrieve NCM properties correctly"); } try { beta = TypeUtils.asDouble(params.get(BETA_PROPERTY_KEY)); LOGGER.debug("Loaded beta={} from meta-file",beta); } catch(NumberFormatException e) { throw new IOException("Could not load beta value properly from meta-property file"); } // SCORING MODEL MLAlgorithm scoreModel = FuzzyServiceLoader.load(MLAlgorithm.class, params.get(PropertyNameSettings.NCM_SCORING_MODEL_ID).toString()); if (scoreModel instanceof Regressor) { model = (Regressor) scoreModel; } else { LOGGER.debug("ML algorithm was not of correct type for NCM of type {}", IDENTIFIER); throw new IllegalArgumentException("ML scoring model was not of the correct algorithm type for NCM"); } if (params.get(PropertyNameSettings.NCM_SCORING_MODEL_LOCATION) != null) { AlgorithmUtils.loadAlgorithm(src, params.get(PropertyNameSettings.NCM_SCORING_MODEL_LOCATION).toString(), model, encryptSpec); LOGGER.debug("loaded scoring model"); } // ERROR MODEL MLAlgorithm errorModelTmp = FuzzyServiceLoader.load(MLAlgorithm.class, params.get(PropertyNameSettings.NCM_ERROR_MODEL_ID).toString()); if (errorModelTmp instanceof Regressor) { errorModel = (Regressor) errorModelTmp; } else { LOGGER.debug("ML model was not of correct type for NCM of type {}", IDENTIFIER); throw new IllegalArgumentException("ML error model was not of the correct algorithm type for NCM"); } if 
(params.get(PropertyNameSettings.NCM_ERROR_MODEL_LOCATION) != null) { AlgorithmUtils.loadAlgorithm(src, params.get(PropertyNameSettings.NCM_ERROR_MODEL_LOCATION).toString(), errorModel, encryptSpec); LOGGER.debug("loaded error model"); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy