All downloads are free. Search and download functionality uses the official Maven repository.

com.arosbio.commons.MathUtils Maven / Gradle / Ivy

Go to download

Conformal AI package, including all data I/O, transformations, machine learning models and predictor classes, but excluding chemistry-dependent code.

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright (C) Aros Bio AB.
 *
 * CPSign is an Open Source Software that is dual licensed to allow you to choose a license that best suits your requirements:
 *
 * 1) GPLv3 (GNU General Public License Version 3) with Additional Terms, including an attribution clause as well as a limitation to use the software for commercial purposes.
 *
 * 2) CPSign Proprietary License that allows you to use CPSign for commercial activities, such as in a revenue-generating operation or environment, or integrate CPSign in your proprietary software without worrying about disclosing the source code of your proprietary software, which is required if you choose to use the software under GPLv3 license. See arosbio.com/cpsign/commercial-license for details.
 */
package com.arosbio.commons;

import java.math.BigDecimal;
import java.math.MathContext;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.stat.descriptive.rank.Percentile;

import com.google.common.collect.Range;

public class MathUtils {

	public static double DEFAULT_EPS = 0.00001;
	
	public static boolean equals(double v1, double v2) {
		return equals(v1, v2, DEFAULT_EPS);
	}

	public static boolean equals(double v1, double v2, double eps) {
		// If both values are finite valued
		if (Double.isFinite(v1) && Double.isFinite(v2)){
			return Math.abs(v1-v2) < eps;
		} else if (Double.isFinite(v1) != Double.isFinite(v2)){
			// if one is finite and the other is not - clearly not the same
			return false;
		}
		// Here if both are infinity or NaN valued
		if (Double.isNaN(v1) && Double.isNaN(v2))
			return true;
		if (Double.isInfinite(v1) && Double.isInfinite(v2)) {
			return (v1>0 && v2>0) ||(v1<0 && v2<0); 
		}
		// Means one is NaN and the other is +/- inf
		return false;
	}

	@SuppressWarnings("null")
	public static boolean equalsDouble(Double d1, Double d2){
		if (d1 == null && d2 == null)
			return true;
		if (d1 == null && d2 != null)
			return false;
		if (d1 != null && d2 == null)
			return false;
		return equals((double)d1, (double) d2, DEFAULT_EPS);
	}

	/**
	 * Calculate the factorial of n.
	 *
	 * @param n the number to calculate the factorial of.
	 * @return n! - the factorial of n.
	 */
	public static int fact(int n) {

		// Base Case: 
		//    If n <= 1 then n! = 1.
		if (n <= 1) {
			return 1;
		}
		// Recursive Case:  
		//    If n > 1 then n! = n * (n-1)!
		else {
			return n * fact(n-1);
		}
	}

	public static boolean containsNanOrInf(double... values){
		// if empty
		if (values.length ==0)
			return false;
		for (int i=0;i roundTo3significantFigures(Pair val){
		return ImmutablePair.of(roundTo3significantFigures(val.getLeft()), roundTo3significantFigures(val.getRight()));
	}

	public static Range roundTo3significantFigures(Range val){
		return Range.range(roundTo3significantFigures(val.lowerEndpoint()), val.lowerBoundType(),
				roundTo3significantFigures(val.upperEndpoint()), val.upperBoundType());
	}

	public static double roundTo3significantFigures(double val){
		return roundToNSignificantFigures(val, 3);
	}

	public static double roundToNSignificantFigures(double val, int n){
		try {
			BigDecimal bd = new BigDecimal(val);
			bd = bd.round(new MathContext(n));
			return bd.doubleValue();
		} catch (NumberFormatException e) {
			return val;
		}
	}

	public static double median(List values){
		return median(CollectionUtils.toArray(values));
	}

	public static double median(double... values){
		if (values==null || values.length==0)
			return Double.NaN;
		if (values.length==1)
			return values[0];

		double median = new Percentile().evaluate(values, 50);
		
		if (Double.isNaN(median)) {
			// Check if all values are positive/negative infinity
			int numPosInf = 0, numNegInf = 0;

			for (Double d : values) {
				if (!Double.isInfinite(d)) {
					break;
				}
				if (d >0)
					numPosInf++;
				else if (d < 0)
					numNegInf++;
			}
			if (numPosInf > 0 && numNegInf == 0)
				return Double.POSITIVE_INFINITY;
			else if (numNegInf > 0 && numPosInf == 0)
				return Double.NEGATIVE_INFINITY;
		}
		
		return median;
	}

	public static > int findMaxIndex(List values){
		int maxInd = 0;
		for (int i=1; i 0 ?  maxInd : i;
		}
		return maxInd;
	}

	public static  List filterNull(List list) {
		List filtered = new ArrayList<>();
		for(T elem: list){
			if(elem != null)
				filtered.add(elem);
		}
		return filtered;
	}

	public static int sumInts(Collection values) {
		return values.stream().mapToInt(i->i).sum();
	}

	public static double sumDoubles(Collection values) {
		return values.stream().mapToDouble(i->i).sum();
	}

	/**
	 * Calculate the average of a collection of values, using an iterative
	 * approach
	 * @param  the type 
	 * @param values values to average
	 * @return the average
	 * @see Iterative mean reference
	 */
	public static  double mean(Collection values) {
		Objects.requireNonNull(values, "cannot calculate mean on null");
		if (values.isEmpty())
			return Double.NaN;

		double avg = 0d;

		int t = 1;
		for (T v : values){
			double dv = v.doubleValue();
			if (Double.isInfinite(dv) || Double.isNaN(dv)) {
				return calcMeanWithInfsOrNaN(values.stream().mapToDouble(d -> d.doubleValue()).toArray());
			}
			avg += (v.doubleValue() - avg) / t;
			t++;
		}
		
		return avg;
	}

	// Utility method in case there are any Infs or NaN values in the input
	private static double calcMeanWithInfsOrNaN(double[] values){
		boolean containsPosInf = false;
		boolean containsNegInf = false;
		for (double v : values){
			if (Double.isNaN(v)){
				return Double.NaN;
			} else if (Double.isInfinite(v)){
				if (v>0)
					containsPosInf = true;
				else
					containsNegInf = true;
			}
		}
		if (containsNegInf && containsPosInf){
			return Double.NaN;
		} else if (containsNegInf){
			return Double.NEGATIVE_INFINITY;
		} 
		return Double.POSITIVE_INFINITY;

	}

	public static double mean(double... values) {
		Objects.requireNonNull(values, "cannot calculate mean on null");
		if (values.length == 0)
			return Double.NaN;
		
		double avg = 0d;
		int t = 1;
		for (double v : values){
			if (Double.isNaN(v) || Double.isInfinite(v)){
				return calcMeanWithInfsOrNaN(values);
			}
			avg += (v - avg) / t;
			t++;
		}

		return avg;
	}


	public static > T min(Collection values) {
		if (values.isEmpty())
			return null;
		Iterator iterator = values.iterator();
		T currMin = iterator.next();
		while(iterator.hasNext()) {
			T val = iterator.next();
			if (val.compareTo(currMin) < 0)
				currMin = val;
		}
		return currMin;
	}

	public static > T max(Collection values) {
		if (values.isEmpty())
			return null;
		Iterator iterator = values.iterator();
		T currMax = iterator.next();
		while(iterator.hasNext()) {
			T val = iterator.next();
			if (val.compareTo(currMax) > 0)
				currMax = val;
		}
		return currMax;
	}


	public static int max(int[] values) throws IllegalArgumentException {
		if (values==null || values.length==0)
			throw new IllegalArgumentException("No values given");
		int tmpMax = values[0];
		for (int i=1; itmpMax)
				tmpMax = values[i];
		}
		return tmpMax;
	}

	public static double truncate(double val, double min, double max) {
		return Math.max(min, Math.min(max, val));
	}


	public static double geometricMean(Collection values){
		Optional avgLog = avgLogs(values);
		if (avgLog.isEmpty())
			return 0;
		return Math.exp(avgLog.get());
	}

	/**
	 * Helper method for {@link #geometricMean(Collection)}, to make a single
	 * pass over all values and compute the average of the logarithm of the values.
	 * Calculates the average using an iterative approach to avoid overflow/underflow issues.
	 * Note that computing when one value is ==0 the log is not defined, then {@code Optional.empty()} is returned
	 * @param values values to average
	 * @return the average of the log of each value
	 */
	protected static Optional avgLogs(Collection values){
		Objects.requireNonNull(values, "cannot calculate mean on null");
		if (values.isEmpty())
			return Optional.empty();

		double avg = 0d;

		int t = 1;
		for (double v : values){
			if (v == 0)
				return Optional.empty();
			avg += (Math.log(v) - avg) / t;
			t++;
		}
		return Optional.of(avg);
	}

	public static int multiplyAllTogetherInt(Collection values){
		int res = 1;
		for (int v : values){
			res *= v;
		}
		return res;
	}

	public static Map roundValues(Map inputMap){
		Map resultMap = new HashMap<>();

		for(Entry entry: inputMap.entrySet()){
			resultMap.put(entry.getKey(), roundTo3significantFigures(entry.getValue()));
		}

		return resultMap;
	}

	public static  Map roundAllValues(Map inputMap){
		Map resultMap = new LinkedHashMap<>();

		for (Map.Entry entry: inputMap.entrySet()){
			V val = entry.getValue();
			if (val instanceof Double || val instanceof Float)
				resultMap.put(entry.getKey(), roundTo3significantFigures((Double)val));
			else if (val instanceof Map)
				resultMap.put(entry.getKey(), roundAllValues((Map)val));
			else if (val instanceof List) {
				resultMap.put(entry.getKey(), roundAllValues((List)val));
			} else
				resultMap.put(entry.getKey(), val);
		}

		return resultMap;
	}
	
	public static Map roundAll(Map input){
		Map result = new LinkedHashMap<>();;
		for (Map.Entry kv: input.entrySet()){
			result.put(kv.getKey(), roundTo3significantFigures(kv.getValue()));
		}
		
		return result;
	}

	public static  List roundAllValues(List list){
		List resList = new ArrayList<>();
		for (T val: list) {
			if (val instanceof Double)
				resList.add(roundTo3significantFigures((Double)val));
			else if (val instanceof Float) 
				resList.add(roundTo3significantFigures((Double)val));
			else if (val instanceof Map)
				resList.add(roundAllValues((Map)val));
			else if (val instanceof List)
				resList.add(roundAllValues((List)val));
			else
				resList.add(val);
		}

		return resList;
	}


	public static  Map normalizeMap(Map colorMap, double rangeStart, double rangeEnd){
		if(rangeStart >= rangeEnd)
			throw new IllegalArgumentException("The lower range cannot be larger or equal to the upper range");
		double posInterval = (rangeStart>0? rangeEnd-rangeStart : Math.abs(rangeEnd));
		double negInterval = (rangeEnd < 0? Math.abs(rangeEnd) + Math.abs(rangeStart): Math.abs(rangeStart));

		for(Entry entry: colorMap.entrySet()){
			double value = entry.getValue();
			double newVal;

			if(value > rangeEnd)
				newVal = 1.0; // Cap upper
			else if(value < rangeStart)
				newVal= -1.0; // Cap lower
			else if(value>= 0 && rangeStart>0)
				newVal = (value-rangeStart)/posInterval;
			else if(value>= 0)
				newVal= value/posInterval;
			else if(value < 0 && rangeEnd <0)
				newVal = (value-rangeEnd)/negInterval;
			else
				newVal = value/negInterval;
			colorMap.put(entry.getKey(), newVal);
		}

		return colorMap;
	}

	public static  boolean equals(Map m, Map m2, double delta){
		if(m.size() != m2.size())
			return false;
		for(K key: m.keySet()){
			if(!m2.containsKey(key))
				return false;
			if(Math.abs(m.get(key)-m2.get(key))>delta)
				return false;
		}
		return true;
	}

	public static boolean keepFalse(boolean a, boolean b){
		return ! (!a || !b);
	}

}