All downloads are free. Search and download functionality uses the official Maven repository.

com.arosbio.ml.vap.avap.AVAPClassifier Maven / Gradle / Ivy

Go to download

Conformal AI package, including all data I/O, transformations, machine learning models and predictor classes, but without the chemistry-dependent code.

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright (C) Aros Bio AB.
 *
 * CPSign is an Open Source Software that is dual licensed to allow you to choose a license that best suits your requirements:
 *
 * 1) GPLv3 (GNU General Public License Version 3) with Additional Terms, including an attribution clause as well as a limitation to use the software for commercial purposes.
 *
 * 2) CPSign Proprietary License that allows you to use CPSign for commercial activities, such as in a revenue-generating operation or environment, or integrate CPSign in your proprietary software without worrying about disclosing the source code of your proprietary software, which is required if you choose to use the software under GPLv3 license. See arosbio.com/cpsign/commercial-license for details.
 */
package com.arosbio.ml.vap.avap;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.InvalidKeyException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.NoSuchElementException;
import java.util.Set;

import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.arosbio.commons.GlobalConfig.Defaults.PredictorType;
import com.arosbio.commons.MathUtils;
import com.arosbio.commons.TypeUtils;
import com.arosbio.commons.config.ImplementationConfig;
import com.arosbio.data.DataUtils;
import com.arosbio.data.Dataset;
import com.arosbio.data.FeatureVector;
import com.arosbio.data.SparseFeature;
import com.arosbio.encryption.EncryptionSpecification;
import com.arosbio.io.DataIOUtils;
import com.arosbio.io.DataSink;
import com.arosbio.io.DataSource;
import com.arosbio.ml.ClassificationUtils;
import com.arosbio.ml.PredictorBase;
import com.arosbio.ml.algorithms.ScoringClassifier;
import com.arosbio.ml.algorithms.impl.DefaultMLParameterSettings;
import com.arosbio.ml.algorithms.svm.SVC;
import com.arosbio.ml.interfaces.ClassificationPredictor;
import com.arosbio.ml.io.MetaFileUtils;
import com.arosbio.ml.io.impl.PropertyNameSettings;
import com.arosbio.ml.metrics.SingleValuedMetric;
import com.arosbio.ml.metrics.classification.LogLoss;
import com.arosbio.ml.sampling.SamplingStrategy;
import com.arosbio.ml.sampling.SamplingStrategyUtils;
import com.arosbio.ml.sampling.TrainSplit;
import com.arosbio.ml.sampling.TrainSplitGenerator;
import com.arosbio.ml.vap.ivap.IVAPClassifier;

public final class AVAPClassifier extends PredictorBase implements AVAP, ClassificationPredictor {

	private static final Logger LOGGER = LoggerFactory.getLogger(AVAPClassifier.class);
	private static final String CVAP_DIRECTORY_NAME = "cvap";
	private static final String CVAP_META_FILE_NAME = "meta.json";
	private static final String IVAP_BASE_FILE_NAME = "model";
	private static final String FEATURE_SCALING_FILE_NAME = "model.scale";
	/** Human readable name of this predictor type */
	public static final String PREDICTOR_TYPE = "CVAP Classification";

	/** Trained IVAPs, keyed by the index of the train split each one was trained on */
	private Map<Integer, IVAPClassifier> predictors = new HashMap<>();
	/** Underlying scoring algorithm; a clone of it is used for every trained IVAP */
	private ScoringClassifier scoringAlgorithm;
	/** Strategy that generates the train splits (one IVAP is trained per split) */
	private SamplingStrategy strategy;

	/* 
	 * =================================================
	 * 			CONSTRUCTORS
	 * =================================================
	 */

	/**
	 * Creates a predictor with neither a scoring algorithm nor a sampling strategy set.
	 */
	public AVAPClassifier(){
		super();
	}

	/**
	 * Creates a predictor from a scoring algorithm and a sampling strategy.
	 * @param mlImpl the scoring algorithm; it is cloned for every IVAP that gets trained
	 * @param strategy the strategy producing the train splits
	 */
	public AVAPClassifier(ScoringClassifier mlImpl, SamplingStrategy strategy){
		this();
		this.strategy = strategy;
		this.scoringAlgorithm = mlImpl;
	}

	/**
	 * Creates a predictor from a scoring algorithm, a sampling strategy and an explicit seed.
	 * @param mlImpl the scoring algorithm; it is cloned for every IVAP that gets trained
	 * @param strategy the strategy producing the train splits
	 * @param seed seed passed to the sampling strategy when generating splits
	 */
	public AVAPClassifier(ScoringClassifier mlImpl, SamplingStrategy strategy, long seed){
		this(mlImpl, strategy);
		this.seed = seed;
	}

	/**
	 * Creates a deep copy of this predictor. The sampling strategy, the scoring algorithm
	 * and every stored IVAP are cloned, and the seed is copied. Components that are not set
	 * ({@code null}, e.g. after the no-arg constructor) stay {@code null} in the copy.
	 * @return a new, independent {@link AVAPClassifier}
	 */
	@Override
	public AVAPClassifier clone(){
		AVAPClassifier clone = new AVAPClassifier();
		// The no-arg constructor leaves these unset - don't NPE when cloning such an instance
		clone.strategy = (strategy != null) ? strategy.clone() : null;
		clone.scoringAlgorithm = (scoringAlgorithm != null) ? scoringAlgorithm.clone() : null;
		clone.seed = seed;

		// Copy all IVAPs, keeping their split indices
		if (predictors != null) {
			for (Entry<Integer, IVAPClassifier> model : predictors.entrySet()) {
				clone.predictors.put(model.getKey(), model.getValue().clone());
			}
		}

		return clone;
	}

	/* 
	 * =================================================
	 * 			GETTERS AND SETTERS
	 * =================================================
	 */

	/**
	 * @return a new {@link LogLoss} instance, used as this predictor's default optimization metric
	 */
	@Override
	public SingleValuedMetric getDefaultOptimizationMetric() {
		final SingleValuedMetric defaultMetric = new LogLoss();
		return defaultMetric;
	}

	/**
	 * @return the constant {@link #PREDICTOR_TYPE}
	 */
	@Override
	public String getPredictorType() {
		return AVAPClassifier.PREDICTOR_TYPE;
	}

	/**
	 * Returns how many IVAP models are currently stored. This can differ from
	 * {@link SamplingStrategy#getNumSamples()}, e.g. when only some split indices
	 * have been trained through {@link #train(Dataset, int)}.
	 * @return the number of trained models, or 0 if there are none
	 */
	public int getNumTrainedPredictors(){
		return (predictors == null) ? 0 : predictors.size();
	}

	/**
	 * Returns the labels of the stored IVAPs, read from the first IVAP in the
	 * map's iteration order.
	 * @return the labels, or a new empty set when no IVAP is stored
	 */
	public Set getLabels(){
		if (predictors != null && !predictors.isEmpty()) {
			IVAPClassifier firstModel = predictors.values().iterator().next();
			return firstModel.getLabels();
		}
		return new HashSet<>();
	}


	/**
	 * Returns the number of classes this predictor handles. The {@code predict} method
	 * always works with exactly two labels.
	 * @return 2 when fully trained, otherwise -1
	 */
	public int getNumClasses() {
		if (!isTrained())
			return -1;
		return 2;
	}

	/**
	 * @return the sampling strategy used to generate the train splits (may be {@code null} if not set)
	 */
	@Override
	public SamplingStrategy getStrategy(){
		return this.strategy;
	}

	/**
	 * @return the underlying scoring algorithm (may be {@code null} if not set)
	 */
	public ScoringClassifier getScoringAlgorithm() {
		return this.scoringAlgorithm;
	}

	/**
	 * @return {@code true} if at least one IVAP has been trained
	 */
	@Override
	public boolean isPartiallyTrained() {
		return predictors != null && predictors.size() > 0;
	}

	/**
	 * Checks whether one IVAP has been trained for every split of the sampling strategy.
	 * @return {@code true} if fully trained; {@code false} otherwise, including when no
	 * 	sampling strategy has been set
	 */
	@Override
	public boolean isTrained() {
		// strategy is null after the no-arg constructor, while predictors is already non-null
		return predictors != null
				&& strategy != null
				&& predictors.size() == strategy.getNumSamples();
	}

	/**
	 * Checks whether this predictor currently stores any trained IVAPs.
	 * @return {@code true} if at least one IVAP is stored, {@code false} otherwise
	 * 	(including when the model map is {@code null})
	 */
	public boolean holdsResources(){
		return predictors != null && !predictors.isEmpty();
	}

	/**
	 * Releases the resources of every stored IVAP and then drops all of them.
	 * @return {@code false} if there was nothing to release; otherwise the IVAP results
	 * 	combined through {@code MathUtils.keepFalse} (presumably {@code false} as soon as
	 * 	any single release failed)
	 */
	public boolean releaseResources(){
		if (predictors == null || predictors.isEmpty())
			return false;

		// Every IVAP is asked to release, regardless of the outcome of previous ones
		boolean allReleased = true;
		for (IVAPClassifier ivap : predictors.values())
			allReleased = MathUtils.keepFalse(allReleased, ivap.releaseResources());

		// Forget the released models
		predictors.clear();
		return allReleased;
	}

	/**
	 * Get which models have been trained (used for parallel training).
	 * @return a new set with the split indices of the trained models. Indices are 0-based,
	 * 	i.e. in the range [0, total-number-of-models - 1]. Empty if no model has been trained
	 */
	public Set getModelsTrained(){
		if (predictors == null)
			return new HashSet<>();
		return new HashSet<>(predictors.keySet());
	}

	/**
	 * Returns the stored IVAPs keyed by split index. Note that this is the internal map
	 * itself (not a copy), so modifications to it affect this predictor.
	 * @return the live map of trained models
	 */
	public Map getModels(){
		return this.predictors;
	}


	/**
	 * Collects the properties describing this predictor: the sampling strategy properties,
	 * the seed, the scoring algorithm properties and the predictor type information.
	 * @return a new mutable map with all properties
	 */
	@Override
	public Map getProperties() {
		// Strategy properties go in first, so the predictor-level entries below take precedence
		Map props = new HashMap<>(strategy.getProperties());
		props.put(PropertyNameSettings.ML_SEED_VALUE_KEY, seed);
		props.put(PropertyNameSettings.PREDICTOR_ML_ALG_INFO_KEY, scoringAlgorithm.getProperties());
		props.put(PropertyNameSettings.ML_TYPE_KEY, PredictorType.VAP_CLASSIFICATION.getId());
		props.put(PropertyNameSettings.ML_TYPE_NAME_KEY, PredictorType.VAP_CLASSIFICATION.getName());
		props.put(PropertyNameSettings.IS_CLASSIFICATION_KEY, true);
		return props;
	}

	/**
	 * Lists the configurable parameters: the sampling strategy itself, followed by the
	 * parameters of the current strategy and scoring algorithm (when they are set).
	 * @return a new list of config parameters
	 */
	@Override
	public List getConfigParameters() {
		List configs = new ArrayList<>();
		// The sampling strategy is always exposed as a parameter
		configs.add(new ImplementationConfig.Builder<>(Arrays.asList(CONFIG_SAMPLING_STRATEGY_PARAM_NAME), SamplingStrategy.class).build());
		// ... plus whatever the currently set components expose
		if (strategy != null) {
			configs.addAll(strategy.getConfigParameters());
		}
		if (scoringAlgorithm != null) {
			configs.addAll(scoringAlgorithm.getConfigParameters());
		}
		return configs;
	}

	/**
	 * Applies config parameters. The sampling strategy parameter is consumed here. The full
	 * parameter map is then always passed on to the scoring algorithm.
	 * @param params parameters keyed by name
	 * @throws IllegalArgumentException if the sampling strategy parameter is not a {@link SamplingStrategy}
	 */
	@Override
	public void setConfigParameters(Map params) throws IllegalArgumentException {
		if (params.containsKey(CONFIG_SAMPLING_STRATEGY_PARAM_NAME)) {
			Object strategyValue = params.get(CONFIG_SAMPLING_STRATEGY_PARAM_NAME);
			if (!(strategyValue instanceof SamplingStrategy))
				throw new IllegalArgumentException("Parameter " + CONFIG_SAMPLING_STRATEGY_PARAM_NAME + " cannot take value: " + strategyValue);
			this.strategy = (SamplingStrategy) strategyValue;
		}

		// Everything (including the strategy key) is forwarded to the underlying classifier
		scoringAlgorithm.setConfigParameters(params);
	}

	/**
	 * Returns the number of observations used by the first stored IVAP. All IVAPs are trained
	 * on splits of the same dataset; this assumes they use the same count (not verified here).
	 * @return the number of observations, or 0 when not fully trained
	 */
	@Override
	public int getNumObservationsUsed() {
		if (isTrained()) {
			IVAPClassifier firstModel = predictors.values().iterator().next();
			return firstModel.getNumObservationsUsed();
		}
		return 0;
	}

	/* 
	 * =================================================
	 * 			TRAIN
	 * =================================================
	 */
	/**
	 * Trains the complete CVAP. One IVAP is trained for each train split generated by the
	 * sampling strategy (using this predictor's seed). Previously stored IVAPs are discarded.
	 * @param problem the dataset to train on
	 * @throws IllegalArgumentException Too small dataset or invalid arguments 
	 */
	@Override
	public void train(Dataset problem) throws IllegalArgumentException {

		TrainSplitGenerator splitGenerator = strategy.getIterator(problem, seed);

		predictors = new HashMap<>();

		final int numModels = strategy.getNumSamples();
		LOGGER.debug("Training CVAP Predictor with {} models", numModels);

		for (int index = 0; splitGenerator.hasNext(); index++) {
			IVAPClassifier model = new IVAPClassifier(scoringAlgorithm.clone());
			TrainSplit split = splitGenerator.next();
			model.train(split);
			split.clear(); // release the split's memory explicitly
			predictors.put(index, model);
			LOGGER.debug(" - Trained model {}/{}", (index + 1), numModels);
		}

	}

	/**
	 * Trains a single IVAP on the train split with the given index and stores it under that
	 * index, replacing any IVAP already stored there. Intended for training the models
	 * separately, e.g. in parallel.
	 * @param problem the full dataset; the split is generated from it by the sampling strategy and seed
	 * @param index the split index, in the range [0, number-of-samples - 1]
	 * @throws IllegalArgumentException if the index is invalid, or for a too small dataset or invalid arguments
	 */
	public void train(Dataset problem, int index) throws IllegalArgumentException {

		if (predictors == null)
			predictors = new HashMap<>();
		SamplingStrategyUtils.validateTrainSplitIndex(strategy, index);
		TrainSplitGenerator generator = strategy.getIterator(problem, seed);

		IVAPClassifier ivap = new IVAPClassifier(scoringAlgorithm.clone());
		TrainSplit split;
		try {
			split = generator.get(index);
		} catch (NoSuchElementException e) {
			LOGGER.debug("Tried to get a non-existing index split",e);
			// keep the original exception as cause for easier debugging
			throw new IllegalArgumentException("Cannot train index " + index + ", only allowed indexes are [0,"+(strategy.getNumSamples()-1)+"]", e);
		}

		try {
			ivap.train(split);
		} finally {
			split.clear(); // explicitly clear all memory, also when training fails
		}
		predictors.put(index, ivap);

	}

	/* 
	 * =================================================
	 * 			PREDICT
	 * =================================================
	 */

	/** Throws an {@link IllegalStateException} unless {@link #isTrained()} holds. */
	private void assertIsTrained(){
		if (isTrained())
			return;
		throw new IllegalStateException("Predictor not trained");
	}

	/**
	 * Predicts {@code example} by combining the [p0, p1] intervals of all IVAPs. The
	 * probability of the first label is {@code GM(p1) / (GM(1-p0) + GM(p1))}, where GM is the
	 * geometric mean over all IVAPs. The second label gets the complement.
	 * @param example the object to predict
	 * @return the combined prediction, including all p0/p1 values and the mean and
	 * 	median interval width (p1 - p0)
	 * @throws IllegalStateException if the predictor is not fully trained
	 */
	public CVAPPrediction predict(FeatureVector example) 
			throws IllegalStateException {
		assertIsTrained();

		final int numModels = predictors.size();
		final List<Double> p0s = new ArrayList<>(numModels);
		final List<Double> p1s = new ArrayList<>(numModels);
		final List<Double> intervalWidths = new ArrayList<>(numModels);
		final List<Double> oneMinusP0 = new ArrayList<>(numModels);

		Integer label0 = null;
		Integer label1 = null;
		boolean labelsChosen = false;

		for (IVAPClassifier ivap : predictors.values()){
			Map<Integer, Pair<Double, Double>> interval = ivap.predict(example);

			if (!labelsChosen){
				// Both labels are fixed from the first IVAP; all IVAPs are then read for label0
				Iterator<Integer> labelIterator = interval.keySet().iterator();
				label0 = labelIterator.next();
				label1 = labelIterator.next();
				labelsChosen = true;
			}

			Pair<Double, Double> p0p1 = interval.get(label0);
			final double p0 = p0p1.getLeft();
			final double p1 = p0p1.getRight();
			p0s.add(p0);
			p1s.add(p1);
			intervalWidths.add(p1 - p0);
			oneMinusP0.add(1 - p0);
		}

		// Combine all IVAPs
		final double gmP1 = MathUtils.geometricMean(p1s);
		final double gm1mP0 = MathUtils.geometricMean(oneMinusP0);
		final double probability = gmP1 / (gm1mP0 + gmP1);
		if (Double.isNaN(gmP1) || Double.isNaN(gm1mP0) || Double.isNaN(probability)) {
			LOGGER.debug("CVAP probability calculation: gmP1={}, gm1mP0={}, prob={}",
					gmP1,gm1mP0,probability);
			if (gmP1 == 0 && gm1mP0 == 0)
				LOGGER.debug("gmP1 and gm1mP0 are both == 0!!, the full lists p1s={} and oneMinusP0={}",p1s,oneMinusP0);
		}

		Map<Integer, Double> probabilities = new HashMap<>();
		probabilities.put(label0, probability);
		probabilities.put(label1, 1 - probability);

		return new CVAPPrediction(p0s, p1s, label0, label1, probabilities,
				MathUtils.mean(intervalWidths), MathUtils.median(intervalWidths));
	}

	/**
	 * Calculates the gradient for {@code example} with the default step size
	 * ({@link DefaultMLParameterSettings#DEFAULT_STEPSIZE}), for the most likely label.
	 * @param example the object to compute the gradient for
	 * @return the gradient, averaged over all IVAPs
	 * @throws IllegalStateException if the predictor is not fully trained
	 */
	@Override
	public List calculateGradient(FeatureVector example)
			throws IllegalStateException {
		return calculateGradient(example, DefaultMLParameterSettings.DEFAULT_STEPSIZE);
	}

	/**
	 * Calculates the gradient for {@code example} with the default step size, for a given label.
	 * @param example the object to compute the gradient for
	 * @param label the label to compute the gradient for
	 * @return the gradient, averaged over all IVAPs
	 * @throws IllegalStateException if the predictor is not fully trained
	 */
	public List calculateGradient(FeatureVector example, int label)
			throws IllegalStateException {
		return calculateGradient(example, DefaultMLParameterSettings.DEFAULT_STEPSIZE, label);
	}

	/**
	 * Calculates the gradient for {@code example} for the most likely label, as given by
	 * {@link #predict(FeatureVector)}.
	 * @param example the object to compute the gradient for
	 * @param stepsize the step size
	 * @return the gradient, averaged over all IVAPs
	 * @throws IllegalStateException if the predictor is not fully trained
	 */
	@Override
	public List calculateGradient(FeatureVector example, double stepsize)
			throws IllegalStateException {

		// Find the most likely class - use that for computing gradient
		CVAPPrediction res = predict(example);
		int label = ClassificationUtils.getPredictedClass(res.getProbabilities());
		return calculateGradient(example, stepsize, label, res);
	}

	/**
	 * Calculates the gradient for {@code example} for a given label.
	 * @param example the object to compute the gradient for
	 * @param stepsize the step size
	 * @param label the label to compute the gradient for
	 * @return the gradient, averaged over all IVAPs
	 * @throws IllegalStateException if the predictor is not fully trained
	 */
	public List calculateGradient(FeatureVector example, double stepsize, int label)
			throws IllegalStateException {
		// The label is already known - no need to run a full CVAP prediction first
		return calculateGradient(example, stepsize, label, null);
	}

	/**
	 * Averages the gradients of all IVAPs for the given label.
	 * @param prediction the CVAP prediction of {@code example}, if available; currently not
	 * 	used by the computation and may be {@code null}
	 */
	private List calculateGradient(FeatureVector example, double stepsize, int label, CVAPPrediction prediction)
			throws IllegalStateException {
		assertIsTrained();
		List<List<SparseFeature>> gradients = new ArrayList<>(predictors.size());

		for (IVAPClassifier model : predictors.values()){
			gradients.add(model.calculateGradient(example, stepsize, label));
		}

		return DataUtils.averageIdenticalIndices(gradients);
	}

	/* 
	 * =================================================
	 * 			SAVE / LOAD
	 * =================================================
	 */

	/**
	 * Saves this predictor in a {@code cvap/} directory created under {@code basePath}.
	 * The directory holds a {@code meta.json} file with {@link #getProperties()} and one
	 * entry per stored IVAP, named {@code model.<split index>}.
	 * @param sink the sink to write to
	 * @param basePath base path within the sink
	 * @param spec encryption specification, passed on to each IVAP's save
	 * @throws IOException if writing the meta file or an IVAP fails
	 * @throws InvalidKeyException as propagated from the IVAP saves
	 * @throws IllegalStateException as declared by the saving contract
	 */
	@Override
	public void saveToDataSink(DataSink sink, String basePath, EncryptionSpecification spec) 
			throws IOException, InvalidKeyException, IllegalStateException {
		// create the directory
		String cvapDir = DataIOUtils.createBaseDirectory(sink, basePath, CVAP_DIRECTORY_NAME+"/");
		LOGGER.debug("Saving AVAPClassifier to jar, loc={}", cvapDir);

		// write meta.json - compute the properties once and reuse them for writing and logging
		Map params = getProperties();
		try (OutputStream metaStream = sink.getOutputStream(cvapDir+CVAP_META_FILE_NAME)){
			MetaFileUtils.writePropertiesToStream(metaStream, params);
		} catch (Exception e) {
			LOGGER.debug("Failed saving AVAP properties to stream", e);
			throw new IOException("Failed saving AVAPClassifier", e);
		}
		LOGGER.debug("Written CVAP Properties to jar: {}", params);

		// for each IVAP - write it, named by its split index
		for (Entry<Integer, IVAPClassifier> ivap : predictors.entrySet()){
			LOGGER.debug("Attempting to write IVAP with id={} to dataSink",ivap.getKey());
			ivap.getValue().saveToDataSink(sink, cvapDir+IVAP_BASE_FILE_NAME+'.'+ivap.getKey(), spec);
		}
		LOGGER.debug("Written {} IVAP models to jar",predictors.size());

	}

	/**
	 * Loads the predictor from {@code source} by delegating to
	 * {@link #loadFromDataSource(DataSource, String, EncryptionSpecification)} with a
	 * {@code null} base path.
	 */
	public void loadFromDataSource(DataSource source, EncryptionSpecification spec) throws IOException, IllegalArgumentException, InvalidKeyException {
		final String noBasePath = null;
		loadFromDataSource(source, noBasePath, spec);
	}

	@Override
	public void loadFromDataSource(DataSource source, String basePath, EncryptionSpecification spec) 
			throws IOException, IllegalArgumentException, InvalidKeyException {

		String cvapDir = DataIOUtils.locateBasePath(source, basePath, CVAP_DIRECTORY_NAME+ "/");
		LOGGER.debug("loading AVAP from source, location={}",cvapDir);

		// Load meta.params
		try(
				InputStream metaDataStream = source.getInputStream(cvapDir+CVAP_META_FILE_NAME);
				){
			Map properties = MetaFileUtils.readPropertiesFromStream(metaDataStream);
			LOGGER.debug("cvap properties from meta-file: {}",properties);

			// Sampling strategy
			strategy = SamplingStrategyUtils.fromProperties(properties);
			seed = TypeUtils.asLong(properties.get(PropertyNameSettings.ML_SEED_VALUE_KEY));

		} catch (IOException e){
			LOGGER.debug("Could not read the cvap meta-file",e);
			throw new IOException(e);
		}
		LOGGER.debug("Loaded CVAP meta-file");

		// Load the IVAPS
		int nrModels = strategy.getNumSamples();
		predictors = new HashMap<>(nrModels);
		for (int i=0; i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy