All downloads are free. Search and download functionality uses the official Maven repository.

com.arosbio.ml.cp.tcp.TCPClassifier Maven / Gradle / Ivy

Go to download

Conformal AI package, including all data I/O, transformations, machine learning models and predictor classes, but without any chemistry-dependent code.

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright (C) Aros Bio AB.
 *
 * CPSign is an Open Source Software that is dual licensed to allow you to choose a license that best suits your requirements:
 *
 * 1) GPLv3 (GNU General Public License Version 3) with Additional Terms, including an attribution clause as well as a limitation to use the software for commercial purposes.
 *
 * 2) CPSign Proprietary License that allows you to use CPSign for commercial activities, such as in a revenue-generating operation or environment, or integrate CPSign in your proprietary software without worrying about disclosing the source code of your proprietary software, which is required if you choose to use the software under GPLv3 license. See arosbio.com/cpsign/commercial-license for details.
 */
package com.arosbio.ml.cp.tcp;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.InvalidKeyException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.arosbio.commons.CollectionUtils;
import com.arosbio.commons.FuzzyServiceLoader;
import com.arosbio.commons.GlobalConfig.Defaults.PredictorType;
import com.arosbio.commons.TypeUtils;
import com.arosbio.commons.config.Configurable;
import com.arosbio.commons.config.ImplementationConfig;
import com.arosbio.commons.mixins.ResourceAllocator;
import com.arosbio.data.DataRecord;
import com.arosbio.data.Dataset;
import com.arosbio.data.Dataset.SubSet;
import com.arosbio.data.FeatureVector;
import com.arosbio.data.FeatureVector.Feature;
import com.arosbio.data.MissingDataException;
import com.arosbio.data.SparseFeature;
import com.arosbio.data.SparseFeatureImpl;
import com.arosbio.encryption.EncryptionSpecification;
import com.arosbio.io.DataIOUtils;
import com.arosbio.io.DataSink;
import com.arosbio.io.DataSource;
import com.arosbio.ml.PredictorBase;
import com.arosbio.ml.TrainingsetValidator;
import com.arosbio.ml.algorithms.impl.DefaultMLParameterSettings;
import com.arosbio.ml.cp.ConformalClassifier;
import com.arosbio.ml.cp.nonconf.NCM;
import com.arosbio.ml.cp.nonconf.calc.PValueCalculator;
import com.arosbio.ml.cp.nonconf.calc.SmoothedPValue;
import com.arosbio.ml.cp.nonconf.classification.NCMMondrianClassification;
import com.arosbio.ml.io.MetaFileUtils;
import com.arosbio.ml.io.impl.PropertyNameSettings;
import com.arosbio.ml.metrics.SingleValuedMetric;
import com.arosbio.ml.metrics.cp.classification.ProportionSingleLabelPredictions;

/**
 * Transductive conformal predictor (TCP) for classification.
 *
 * <p>"Training" only stores and shuffles the data. For every prediction, the NCM is
 * re-trained once per known label. During that re-training, the test example is temporarily
 * appended to the training set under that label. The p-value is then computed from the
 * nonconformity scores of all records that carry that label.
 *
 * <p>Instances are not safe for concurrent use, because predictions temporarily modify the
 * internal training set.
 */
public final class TCPClassifier extends PredictorBase implements TCP, ConformalClassifier {

	public static final String PREDICTOR_TYPE = "TCP Classification";

	private static final Logger LOGGER = LoggerFactory.getLogger(TCPClassifier.class);

	// SAVING
	private static final String TCP_DIRECTORY_NAME = "tcp";
	private static final String TCP_META_INFO = "tcp.meta.json";
	private static final String NCM_BASE = "ncm";

	private NCMMondrianClassification ncm;
	private PValueCalculator pValueCalculator = new SmoothedPValue();
	/** The dataset given to {@link #train(Dataset)}, kept for saving and cloning */
	private Dataset originalData;

	/** All records of {@link #originalData} merged into one shuffled list, used when predicting */
	private SubSet trainingData;
	/** The (integer) class labels found among the training records */
	private Set<Integer> labels;

	/* 
	 * =================================================
	 * 			CONSTRUCTORS
	 * =================================================
	 */

	/** Creates an un-configured TCP classifier; an NCM must be set before use */
	public TCPClassifier() {
		super();
	}

	/**
	 * Creates a TCP classifier using the given NCM.
	 * @param ncm the nonconformity measure to use
	 */
	public TCPClassifier(NCMMondrianClassification ncm) {
		super();
		this.ncm = ncm;
	}

	/**
	 * Creates a TCP classifier using the given NCM and seed.
	 * @param ncm the nonconformity measure to use
	 * @param seed the seed, e.g. used for shuffling the training records
	 */
	public TCPClassifier(NCMMondrianClassification ncm, long seed) {
		this(ncm);
		this.seed = seed;
	}

	/**
	 * Creates a copy of this predictor.
	 *
	 * <p>The NCM, the p-value calculator and the dataset are cloned. If this predictor is
	 * trained, the clone is trained on its own copy of the data, so it can be used for
	 * predictions directly.
	 * @return an independent copy of this predictor
	 */
	@Override
	public TCPClassifier clone() {
		TCPClassifier clone = new TCPClassifier(ncm != null ? ncm.clone() : null);
		clone.seed = seed;
		clone.pValueCalculator = pValueCalculator != null ? pValueCalculator.clone() : null;
		if (originalData != null) {
			clone.originalData = originalData.clone();
			if (trainingData != null) {
				// Re-create the derived state (shuffled records + labels) from the cloned data,
				// in the same way as when loading a saved predictor
				try {
					clone.train(clone.originalData);
				} catch (MissingDataException e) {
					throw new IllegalStateException("Failed re-creating the training state of the cloned TCP", e);
				}
			}
		}
		return clone;
	}

	/* 
	 * =================================================
	 * 			GETTERS AND SETTERS
	 * =================================================
	 */

	@Override
	public SingleValuedMetric getDefaultOptimizationMetric() {
		return new ProportionSingleLabelPredictions();
	}

	/** @return the dataset given at training time, or {@code null} if not trained */
	@Override
	public Dataset getDataset() {
		return originalData;
	}

	public PValueCalculator getPValueCalculator() {
		return pValueCalculator;
	}

	public void setPValueCalculator(PValueCalculator ncsEstimator) {
		this.pValueCalculator = ncsEstimator;
	}

	public NCMMondrianClassification getNCM() {
		return ncm;
	}

	public void setNCM(NCMMondrianClassification ncm) {
		this.ncm = ncm;
	}

	/**
	 * Sets the seed of this predictor. The seed is also propagated to the underlying model
	 * and to the p-value calculator, when they are present.
	 */
	@Override
	public void setSeed(long seed) {
		super.setSeed(seed);
		if (ncm!=null && ncm.getModel() != null)
			ncm.getModel().setSeed(seed);
		if (pValueCalculator != null)
			pValueCalculator.setRNGSeed(seed);
	}

	/**
	 * @return {@code true} if training data has been set successfully and both an NCM and a
	 * p-value calculator are available
	 */
	@Override
	public boolean isTrained() {
		return originalData != null && !originalData.isEmpty()
				&& trainingData != null && labels != null
				&& ncm != null && pValueCalculator != null;
	}

	/** @return the number of records used for predictions, or 0 if not trained */
	@Override
	public int getNumObservationsUsed() {
		return trainingData != null ? trainingData.size() : 0;
	}

	/**
	 * @return the properties needed to recreate this predictor. This includes the NCM
	 * properties (which must hold the NCM ID read back in {@link #loadFromDataSource}), the
	 * predictor type, the seed and the p-value calculator settings.
	 */
	@Override
	public Map<String, Object> getProperties() {
		Map<String, Object> props = new HashMap<>();
		if (ncm != null)
			props.putAll(ncm.getProperties()); // this should include everything to recreate the NCM implementation
		props.put(PropertyNameSettings.ML_TYPE_KEY, PredictorType.TCP_CLASSIFICATION.getId());
		props.put(PropertyNameSettings.ML_TYPE_NAME_KEY, PredictorType.TCP_CLASSIFICATION.getName());
		props.put(PropertyNameSettings.IS_CLASSIFICATION_KEY, true);
		props.put(PropertyNameSettings.ML_SEED_VALUE_KEY, seed);
		if (pValueCalculator != null) {
			props.put(PValueCalculator.PVALUE_CALCULATOR_NAME_KEY, pValueCalculator.getName());
			props.put(PValueCalculator.PVALUE_CALCULATOR_ID_KEY, pValueCalculator.getID());
			if (pValueCalculator.getRNGSeed() != null)
				props.put(PValueCalculator.PVALUE_CALCULATOR_SEED_KEY, pValueCalculator.getRNGSeed());
		}
		return props;
	}

	@Override
	public String getPredictorType() {
		return PREDICTOR_TYPE;
	}

	/** @return the number of labels in the training data, or 0 if not trained */
	@Override
	public int getNumClasses() {
		if (!isTrained())
			return 0;
		return labels.size();
	}

	/** @return a copy of the labels found in the training data (empty if not trained) */
	@Override
	public Set<Integer> getLabels(){
		if (! isTrained())
			return new HashSet<>();
		return new HashSet<>(labels);
	}

	/**
	 * Releases resources held by the underlying model, if that model is a {@link ResourceAllocator}.
	 * @return the result reported by the model, or {@code false} if no resources are held
	 */
	@Override
	public boolean releaseResources() {
		if (!holdsResources())
			return false;
		return ((ResourceAllocator) ncm.getModel()).releaseResources();
	}

	@Override
	public boolean holdsResources() {
		return ncm != null && (ncm.getModel() instanceof ResourceAllocator);
	}

	/** @return the NCM's config parameters (if an NCM is set) plus the p-value calculator parameter */
	@Override
	public List<ConfigParameter> getConfigParameters() {
		List<ConfigParameter> params = new ArrayList<>();
		if (ncm!=null)
			params.addAll(ncm.getConfigParameters());

		// ncm estimator
		params.add(new ImplementationConfig.Builder<>(CONFIG_PVALUE_CALC_PARAM_NAMES, PValueCalculator.class).defaultValue(new SmoothedPValue()).build());
		return params;
	}

	/**
	 * Applies config parameters. The p-value calculator is handled here, and all parameters are
	 * then also passed on to the NCM (when one is set).
	 * @param params parameter name to value mapping
	 * @throws IllegalArgumentException if a p-value calculator argument is invalid
	 */
	@Override
	public void setConfigParameters(Map<String, Object> params) throws IllegalArgumentException {
		for (Map.Entry<String, Object> kv : params.entrySet()) {
			try {
				if (CollectionUtils.containsIgnoreCase(CONFIG_PVALUE_CALC_PARAM_NAMES, kv.getKey())) {
					if (kv.getValue() instanceof PValueCalculator) {
						pValueCalculator = (PValueCalculator) kv.getValue();
					} else {
						pValueCalculator = FuzzyServiceLoader.load(PValueCalculator.class, kv.getValue().toString());
					}
				}
			} catch (Exception e) {
				// Pass the exception as the last argument so SLF4J logs the cause as well
				LOGGER.debug("Got invalid config argument: {}", kv, e);
				throw Configurable.getInvalidArgsExcept(kv.getKey(), kv.getValue()); 
			}
		}
		// pass on to underlying ncm
		if (ncm != null)
			ncm.setConfigParameters(params);
	}


	/* 
	 * =================================================
	 * 			TRAIN
	 * =================================================
	 */

	/**
	 * "Trains" the TCP, which means storing the data for later predictions. The main,
	 * modeling-exclusive and calibration-exclusive records are merged, validated and shuffled
	 * using the current seed.
	 *
	 * <p>The state of this predictor is only updated if all checks pass. A failed call leaves
	 * any previous training intact.
	 * @param data the data to use
	 * @throws IllegalArgumentException if {@code data} is null or empty, or fails validation
	 * @throws MissingDataException if any record contains missing feature values
	 */
	@Override
	public void train(Dataset data) throws MissingDataException, IllegalArgumentException {
		if (data == null || data.isEmpty())
			throw new IllegalArgumentException("No records given");

		// +1 leaves room for the test example appended at prediction time
		List<DataRecord> allRecords = new ArrayList<>(data.getNumRecords()+1);
		allRecords.addAll(data.getDataset());
		if (!data.getModelingExclusiveDataset().isEmpty()){
			LOGGER.warn("TCP predictor trained with model-exclusive data, TCP does not support this at this time - all data is merged and used");
			allRecords.addAll(data.getModelingExclusiveDataset());
		}
		if (! data.getCalibrationExclusiveDataset().isEmpty()){
			LOGGER.warn("TCP predictor trained with calibration-exclusive data, TCP does not support this at this time - all data is merged and used");
			allRecords.addAll(data.getCalibrationExclusiveDataset());
		}

		// Validate to make sure it's big enough
		TrainingsetValidator.getInstance().validateClassification(allRecords);

		SubSet shuffled = new SubSet(allRecords);
		if (shuffled.containsMissingFeatures())
			throw new MissingDataException("Training data contain missing feature values");

		LOGGER.debug("shuffling records with seed: {}",seed);
		shuffled.shuffle(seed);

		// Labels are taken from exactly the records that are used (all three subsets)
		Set<Integer> foundLabels = new HashSet<>();
		for (DataRecord r : allRecords) {
			foundLabels.add((int) r.getLabel());
		}

		// Everything succeeded - commit the new state
		this.originalData = data;
		this.trainingData = shuffled;
		this.labels = foundLabels;

		LOGGER.debug("Finished 'training' TCP predictor - i.e. setting the training data");
	}


	/* 
	 * =================================================
	 * 			PREDICT
	 * =================================================
	 */

	private void assertIsTrained(){
		if (!isTrained())
			throw new IllegalStateException("Predictor not trained");
	}

	/**
	 * Computes one p-value per known label for the given example.
	 * @param example the example to predict
	 * @return mapping label -> p-value
	 * @throws IllegalStateException if the predictor is not trained
	 */
	@Override
	public Map<Integer, Double> predict(final FeatureVector example)
			throws IllegalStateException {
		assertIsTrained();

		Map<Integer, Double> prediction = new HashMap<>();

		int beforeSize = trainingData.size(); 
		for (int label : labels) {
			prediction.put(label, predictPValueForClass(label, example));
		}

		int afterSize = trainingData.size();

		// Sanity check - the temporarily added test record must always have been removed
		if (beforeSize != afterSize) {
			LOGGER.debug("The before and after sizes doesn't match! something is wrong in the code!");
			throw new RuntimeException("Coding error in the TCPClassifier class");
		}

		return prediction;
	}

	/**
	 * Computes the p-value of {@code example} under the assumption that it has {@code label}.
	 *
	 * <p>The example is appended to the training set, the NCM is trained on the full set, and
	 * the example's NCS is compared to those of the other records with the same label. The
	 * appended record is always removed again, even if an exception is thrown.
	 */
	private double predictPValueForClass(int label, FeatureVector example) {
		// Add the record, with the assumed label
		trainingData.add(new DataRecord((double)label, example));
		try {
			// train the NCM
			ncm.trainNCM(trainingData);

			// Predict all alphas (NCS) for the assumed label
			List<Double> ncs = new ArrayList<>(trainingData.size()/2);
			for (DataRecord r : trainingData) {
				if (label == (int)r.getLabel()) {
					// get the NCS for the label of interest
					ncs.add(ncm.calculateNCS(r.getFeatures()).get(label));
				}
			}

			// The test example was appended last, so its NCS is the last one with this label
			double ncsForTestEx = ncs.remove(ncs.size()-1);

			// Fit the ncs estimator
			pValueCalculator.build(ncs);

			// calculate the p-value for the test example
			return pValueCalculator.getPvalue(ncsForTestEx);
		} finally {
			// remove the example from the training examples - also on failure, so the
			// training set is never left containing the test example
			trainingData.remove(trainingData.size()-1);
		}
	}

	@Override
	public List<SparseFeature> calculateGradient(FeatureVector example)
			throws IllegalStateException {
		return calculateGradient(example, DefaultMLParameterSettings.DEFAULT_STEPSIZE);
	}

	/**
	 * Calculates the gradient for the label with the highest p-value. On ties, the first label
	 * in the map's iteration order wins.
	 * @throws IllegalArgumentException if {@code stepsize} is zero or not finite
	 */
	@Override
	public List<SparseFeature> calculateGradient(FeatureVector example, double stepsize)
			throws IllegalStateException {
		assertValidStepsize(stepsize);

		//First do a normal prediction
		Map<Integer, Double> result = predict(example);

		//Pick class with largest pValue
		int selectedClass = 0;
		double highestPValue = -1d;
		for (Map.Entry<Integer, Double> pVales : result.entrySet()){
			if (pVales.getValue()>highestPValue){
				selectedClass=pVales.getKey();
				highestPValue = pVales.getValue();
			}
		}
		return doCalc(example, stepsize, selectedClass, result);
	}

	@Override
	public List<SparseFeature> calculateGradient(FeatureVector example, int label)
			throws IllegalStateException {
		return calculateGradient(example, DefaultMLParameterSettings.DEFAULT_STEPSIZE, label);
	}

	/**
	 * Calculates the gradient of the p-value for the given label.
	 * @throws IllegalArgumentException if {@code stepsize} is zero or not finite, or if
	 * {@code label} is not one of the trained labels
	 */
	@Override
	public List<SparseFeature> calculateGradient(FeatureVector example, double stepsize, int label)
			throws IllegalStateException {
		assertValidStepsize(stepsize);
		return doCalc(example, stepsize, label, predict(example));
	}

	/** Rejects step sizes that would make the difference quotient meaningless (division by zero, NaN or infinity) */
	private static void assertValidStepsize(double stepsize) {
		if (stepsize == 0 || !Double.isFinite(stepsize))
			throw new IllegalArgumentException("Invalid step size for gradient calculation: " + stepsize);
	}

	/**
	 * Approximates the gradient of the p-value for {@code label} by forward differences, one
	 * explicit feature of {@code example} at a time.
	 *
	 * <p>The example is modified temporarily. Each feature is restored to its original value,
	 * even if a prediction throws.
	 */
	private List<SparseFeature> doCalc(FeatureVector example, double stepsize, int label, Map<Integer, Double> pvals){

		// The un-perturbed p-value from the normal prediction
		Double normalPValue = pvals.get(label);
		if (normalPValue == null)
			throw new IllegalArgumentException("Label " + label + " is not one of the labels of the trained predictor: " + pvals.keySet());

		//The gradient to return, same size as the example to predict
		List<SparseFeature> gradient = new ArrayList<>(example.getNumExplicitFeatures());

		for (Feature f : example) {
			int index = f.getIndex();
			double oldValue = f.getValue();

			// Update it to the new
			example.withFeature(index, oldValue+stepsize);
			try {
				// Predict it and store in the gradient
				double diff = (predictPValueForClass(label, example)-normalPValue)/stepsize;
				gradient.add(new SparseFeatureImpl(index, diff));
			} finally {
				// Change it back to what it was!
				example.withFeature(index, oldValue);
			}
		}

		return gradient;
	}

	/**
	 * Saves the training data, the NCM and the meta properties under a {@code tcp/} directory.
	 * @throws IllegalStateException if the predictor is not trained
	 * @throws IOException if writing fails (the original cause is attached)
	 */
	@Override
	public void saveToDataSink(DataSink sink, String path, EncryptionSpecification encryptSpec)
			throws IOException, InvalidKeyException, IllegalStateException {
		assertIsTrained();

		// create the directory
		String tcpDir = DataIOUtils.createBaseDirectory(sink, path, TCP_DIRECTORY_NAME+'/');

		LOGGER.debug("Saving TCP Classifier to sink, location= {}", tcpDir);
		originalData.saveToDataSink(sink, tcpDir, encryptSpec);
		LOGGER.debug("Saved TCP data");

		ncm.saveToDataSink(sink, tcpDir + NCM_BASE, encryptSpec);
		LOGGER.debug("Saved TCP NCM");

		try (OutputStream ncmPropertyStream = sink.getOutputStream(tcpDir+ TCP_META_INFO)){
			MetaFileUtils.writePropertiesToStream(ncmPropertyStream, getProperties());
		} catch(Exception e) {
			LOGGER.debug("Failed saving TCP properties to stream", e);
			throw new IOException("Failed saving TCP", e);
		}
		LOGGER.debug("Saved tcp meta info");
	}

	/**
	 * Loads a predictor saved by {@link #saveToDataSink} and re-trains it on the saved data.
	 *
	 * <p>The saved ML seed is applied before training, since it determines the shuffling. The
	 * explicitly saved p-value calculator seed (if any) is applied last, so that
	 * {@link #setSeed(long)} does not overwrite it.
	 * @throws IOException if data or properties cannot be read, or the saved NCM is invalid
	 */
	@Override
	public void loadFromDataSource(DataSource source, String path, EncryptionSpecification encryptSpec)
			throws IOException, InvalidKeyException {

		String tcpDir = DataIOUtils.locateBasePath(source, path, TCP_DIRECTORY_NAME+'/');

		LOGGER.debug("Trying to load TCP classifier from src, location={}",tcpDir);
		Dataset p = new Dataset();
		p.loadFromDataSource(source, tcpDir, encryptSpec);
		LOGGER.debug("Loaded {} records",p.getNumRecords());

		LOGGER.debug("Trying to load TCP Meta info");
		Map<String, Object> props = null;
		long mlSeed;
		try (InputStream tcpPropStream = source.getInputStream(tcpDir + TCP_META_INFO)){
			props = MetaFileUtils.readPropertiesFromStream(tcpPropStream);
			mlSeed = TypeUtils.asLong(props.get(PropertyNameSettings.ML_SEED_VALUE_KEY));
		} catch(Exception e) {
			LOGGER.debug("Failed loading properties", e);
			throw new IOException("Could not load properties from source", e);
		}
		// Must be set before training - it determines the shuffling of the records
		seed = mlSeed;

		// p-value calculator
		if (props.containsKey(PValueCalculator.PVALUE_CALCULATOR_ID_KEY)) {
			int id = TypeUtils.asInt(props.get(PValueCalculator.PVALUE_CALCULATOR_ID_KEY));
			LOGGER.debug("Retrieving pvalue-calculator based on ID: {}", id);
			pValueCalculator = FuzzyServiceLoader.load(PValueCalculator.class, id);
		} else if (props.containsKey(PValueCalculator.PVALUE_CALCULATOR_NAME_KEY)) {
			String name = props.get(PValueCalculator.PVALUE_CALCULATOR_NAME_KEY).toString();
			LOGGER.debug("Retreiving pvalue-calculator based on name: {}", name);
			pValueCalculator = FuzzyServiceLoader.load(PValueCalculator.class, name);
		} else {
			LOGGER.debug("No pvalue-calculator info saved in model-file, using the default one");
		}
		Long pValueSeed = null;
		if (props.containsKey(PValueCalculator.PVALUE_CALCULATOR_SEED_KEY)) {
			pValueSeed = TypeUtils.asLong(props.get(PValueCalculator.PVALUE_CALCULATOR_SEED_KEY));
		}

		Object ncmID = props.get(PropertyNameSettings.NCM_ID);
		if (ncmID == null) {
			LOGGER.debug("TCP meta did not contain any NCM identifier");
			throw new IOException("Failed initializing the NCM for TCP");
		}
		NCM ncmLoaded = FuzzyServiceLoader.load(NCM.class, ncmID.toString());
		if (!(ncmLoaded instanceof NCMMondrianClassification)) {
			LOGGER.debug("TCP meta pointed to a faulty NCM implementation of non-correct type: {}", ncmLoaded != null ? ncmLoaded.getName() : null); 
			throw new IOException("Failed initializing the NCM for TCP");
		}
		ncm = (NCMMondrianClassification) ncmLoaded;

		ncm.loadFromDataSource(source, tcpDir + NCM_BASE, encryptSpec);

		// Set the properties saved in the model - hopefully this will set the model-specific things correctly!
		ncm.getModel().setConfigParameters(props);

		// Train it
		train(p);

		// set the seed (this also re-seeds the p-value calculator)
		setSeed(mlSeed);

		// Restore the explicitly saved p-value calculator seed last, so it is not overwritten
		if (pValueSeed != null) {
			pValueCalculator.setRNGSeed(pValueSeed);
			LOGGER.debug("Set the p-value calculator seed to: {}", pValueSeed);
		}
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy