All downloads are free. Search and download functionality uses the official Maven repository.

com.arosbio.commons.CollectionUtils Maven / Gradle / Ivy

Go to download

Conformal AI package, including all data I/O, transformations, machine learning models and predictor classes, but without the chemistry-dependent code.

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright (C) Aros Bio AB.
 *
 * CPSign is an Open Source Software that is dual licensed to allow you to choose a license that best suits your requirements:
 *
 * 1) GPLv3 (GNU General Public License Version 3) with Additional Terms, including an attribution clause as well as a limitation to use the software for commercial purposes.
 *
 * 2) CPSign Proprietary License that allows you to use CPSign for commercial activities, such as in a revenue-generating operation or environment, or integrate CPSign in your proprietary software without worrying about disclosing the source code of your proprietary software, which is required if you choose to use the software under GPLv3 license. See arosbio.com/cpsign/commercial-license for details.
 */
package com.arosbio.commons;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;

import com.google.common.collect.Range;

public class CollectionUtils {
	
	/**
	 * IndexedValues are sorted descending depending on their value. 
	 * Used e.g. by the FeatureSelectors for keeping the indices with most variance/importance 
	 * @author staffan
	 *
	 */
	public static class IndexedValue implements Cloneable, Comparable {
		public final int index;
		public final double value;
	
		public IndexedValue(int index, double value) {
			this.index = index;
			this.value = value;
		}
	
		public IndexedValue withValue(double newValue){
			return new IndexedValue(index, newValue);
		}
	
		@Override
		public int compareTo(IndexedValue o) {
			int cmp = Double.compare(this.value, o.value);
			return cmp != 0 ? cmp :  this.index - o.index ;
		}
	
		public String toString() {
			return String.format("%d:%s",index,value);
		}
		
		public IndexedValue clone() {
			return new IndexedValue(index, value);
		}
	
	}

	/**
	 * Counts the number of {@code true} values in a collection.
	 * @param values the values to count; {@code null} elements are ignored
	 * @return the number of elements equal to {@link Boolean#TRUE}
	 */
	public static int countTrueValues(Collection<Boolean> values) {
		int sum = 0;
		for (Boolean b : values) {
			// Boolean.TRUE.equals avoids the NPE that unboxing a null element would cause
			if (Boolean.TRUE.equals(b))
				sum++;
		}
		return sum;
	}

	/**
	 * Uses {@link #listRange(double, double, double)} with the default {@code step} of 1, so
	 * {@code start} must not be larger than {@code end} (unless they are practically equal).
	 * @param start the start value
	 * @param end the end value (inclusive)
	 * @return a List of values between start and end
	 */
	public static List<Double> listRange(double start, double end){
		return listRange(start,end, 1d);
	}

	/**
	 * Enumerates {@code start, start+step, start+2*step, ...} for as long as the values have not
	 * passed {@code end}; {@code end} itself is included when it is reached (as judged by
	 * {@code MathUtils.equals}). Every value is computed as {@code start + i*step} rather than by
	 * repeated addition, so rounding errors do not accumulate.
	 * @param start the first value
	 * @param end the last value (inclusive)
	 * @param step the increment; must be positive if {@code start < end} and negative if {@code start > end}
	 * @return a new mutable list; a single-element list if {@code start} and {@code end} are practically equal
	 * @throws IllegalArgumentException if {@code |step| <= 0.0001}, if {@code step} points away from
	 *         {@code end}, or if the range would span more than 1000 steps
	 */
	public static List<Double> listRange(double start, double end, double step){
		// Special treat if start and end (practically) equals
		if (Math.abs(start-end)<0.00000001){
			List<Double> single = new ArrayList<>(1);
			single.add(start);
			return single;
		}
		
		// validate input
		if (Math.abs(step) <= 0.0001){
			throw new IllegalArgumentException("step parameter cannot be 0");
		}
		if (start > end && step >= 0){
			throw new IllegalArgumentException(String.format("Invalid range enumeration {start=%s,end=%s,step=%s}", start,end,step));
		} else if (start < end && step <= 0){
			throw new IllegalArgumentException(String.format("Invalid range enumeration {start=%s,end=%s,step=%s}", start,end,step));
		}

		final boolean ascending = start < end;
		// Dividing by the signed step makes the number of steps positive for both directions,
		// so the size limit is enforced for descending ranges as well
		final double numSteps = (end - start) / step;
		if (numSteps > 1000) {
			throw new IllegalArgumentException("Not allowed to create a list range with more than 1000 entries");
		}

		List<Double> result = new ArrayList<>((int) numSteps + 1);
		for (int i = 0; i < 1001; i++) {
			double nextValue = start + i*step;
			boolean beforeEnd = ascending ? nextValue < end : nextValue > end;
			if (!beforeEnd && !MathUtils.equals(nextValue, end))
				break;
			result.add(nextValue);
		}
		return result;
	}
	
	/**
	 * Enumerates exponents using {@link #listRange(double, double, double)} and returns
	 * {@code base} raised to each of them, e.g. {@code listRange(-1, 1, 1, 10)} gives
	 * {@code [10^-1, 10^0, 10^1]}.
	 * @param start the first exponent
	 * @param stop the last exponent (inclusive)
	 * @param step the exponent increment
	 * @param base the base
	 * @return a new list with {@code Math.pow(base, e)} for each enumerated exponent {@code e}
	 * @throws IllegalArgumentException under the same conditions as {@link #listRange(double, double, double)}
	 */
	public static List<Double> listRange(double start, double stop, double step, double base){
		return listRange(start,stop,step)
				.stream()
				.map(e -> Math.pow(base, e))
				.collect(Collectors.toList());
	}

	/**
	 * Uses {@link #listRange(int, int, int)} with the default {@code step} of 1, so
	 * {@code start} must not be larger than {@code end}.
	 * @param start the start value
	 * @param end the end value (inclusive)
	 * @return a List of the integers from start to end
	 */
	public static List<Integer> listRange(int start, int end){
		return listRange(start, end, 1);
	}

	/**
	 * Enumerates {@code start, start+step, start+2*step, ...} for as long as the values have not
	 * passed {@code end} (inclusive).
	 * @param start the first value
	 * @param end the last value (inclusive, if reached by the step)
	 * @param step the increment; must be positive if {@code start < end} and negative if {@code start > end}
	 * @return a new mutable list; a single-element list if {@code start == end}
	 * @throws IllegalArgumentException if {@code step} is 0, points away from {@code end}, or the
	 *         range would span more than 1000 steps
	 */
	public static List<Integer> listRange(int start, int end, int step){
		// Special treat if start and end equals
		if (start==end){
			List<Integer> single = new ArrayList<>(1);
			single.add(start);
			return single;
		}
		// validate input
		if (step == 0){
			throw new IllegalArgumentException("step parameter cannot be 0");
		}
		if (start > end && step >= 0){
			throw new IllegalArgumentException(String.format("Invalid range enumeration {start=%s,end=%s,step=%s}", start,end,step));
		} else if (start < end && step <= 0){
			throw new IllegalArgumentException(String.format("Invalid range enumeration {start=%s,end=%s,step=%s}", start,end,step));
		}

		// Computed in double arithmetic since 'end - start' can overflow an int; dividing by the
		// signed step makes the count positive for both directions so the limit always applies
		final double numSteps = ((double) end - start) / step;
		if (numSteps > 1000) {
			throw new IllegalArgumentException("Not allowed to create a list range with more than 1000 entries");
		}

		final boolean ascending = start < end;
		List<Integer> result = new ArrayList<>((int) numSteps + 1);
		// long loop variable: 'v += step' must not overflow when end is close to Integer.MIN/MAX_VALUE
		for (long v = start; ascending ? v <= end : v >= end; v += step) {
			result.add((int) v);
		}
		return result;
	}

	/**
	 * Splits {@code input} into two lists using a {@link Random} seeded with {@code seed}.
	 * The first list starts out as a shallow copy of all of {@code input}, and the second one is
	 * presized for {@code numInSecond} elements. Presumably {@code numInSecond} randomly picked
	 * elements are then moved from the first list to the second one - TODO confirm.
	 * <p>
	 * NOTE(review): this copy of the file is garbled at the {@code for} loop below. The rest of this
	 * method and the signature of the following null/NaN check are missing and must be restored
	 * from the original source.
	 */
	public static  Pair,List> splitRandomly(List input, int numInSecond, long seed){
		List first = new ArrayList<>(input), // Shallow copy all to the first
			second = new ArrayList<>(numInSecond);
		
		Random rng = new Random(seed);
		for (int i=0; i input) {
		// Returns true at the first element that is null or NaN, false if there is none
		for (Double n : input) {
			if (n == null || n.isNaN()) {
				return true;
			}
		}
		return false;
	}
	
	/**
	 * Checks whether a collection contains at least one {@code null} element.
	 * <p>
	 * Iterates rather than calling {@code input.contains(null)}, because collections that
	 * do not permit null elements may throw a {@link NullPointerException} for that query.
	 * @param <T> the element type
	 * @param input the collection to check
	 * @return {@code true} if any element is {@code null}
	 */
	public static <T> boolean containsNull(Collection<T> input) {
		for (T t : input) {
			if (t == null)
				return true;
		}
		return false;
	}
	
	/**
	 * Keeps only the finite values of {@code input}. {@code null} and {@code NaN} are removed,
	 * and so are positive and negative infinity, since the check is {@link Double#isFinite(double)}.
	 * @param input the values to filter
	 * @return a new mutable list with the finite values, in the iteration order of {@code input}
	 */
	public static List<Double> filterNullOrNaN(Collection<Double> input) {
		List<Double> res = new ArrayList<>(input.size());
		for (Double n : input) {
			if (n != null && Double.isFinite(n))
				res.add(n);
		}
		return res;
	}
	
	/**
	 * Converts a list to a primitive {@code double[]}; a {@code null} or empty list gives an empty array.
	 * <p>
	 * NOTE(review): this copy of the file is garbled at the {@code for} loops below. The loop bodies,
	 * the end of this method, the corresponding {@code int[]} conversion and the header of the
	 * frequency-counting method are partly missing and must be restored from the original source.
	 */
	public static double[] toArray(List input) {
		if (input == null || input.isEmpty())
			return new double[0];
		double[] arr = new double[input.size()];
		for (int i=0; i input) {
		if (input == null || input.isEmpty())
			return new int[0];
		int[] arr = new int[input.size()];
		for (int i=0; i Map countFrequencies(Collection input){
		// Count the occurrences of each distinct element (keys follow the elements' equals/hashCode)
		Map freqs = new HashMap<>();

		for (T t : input) {
			freqs.put(t, freqs.getOrDefault(t, 0)+1);
		}

		return freqs;
	}
	
	/**
	 * Counts the values that are strictly smaller than {@code threshold}.
	 * @param vals the values to check (must not contain {@code null}, which cannot be unboxed)
	 * @param threshold the exclusive upper limit
	 * @return the number of elements {@code < threshold}
	 */
	public static int countValuesSmallerThan(Collection<Integer> vals, int threshold) {
		int count=0;
		for (int v : vals) {
			if (v < threshold)
				count++;
		}
		return count;
	}

	public static > List getUnique(Collection input){
		Set asSet = new LinkedHashSet<>(input); // Keep ordering using linked hash set
		List asList = new ArrayList<>(asSet);
		return asList;
	}

	public static > List getUniqueAndSorted(Collection input){
		return input.stream()
			.distinct()
			.sorted()
			.collect(Collectors.toCollection(ArrayList::new));
	}

	public static >
	boolean isSorted(Iterable iterable) {
		Iterator iter = iterable.iterator();
		if (!iter.hasNext()) {
			return true;
		}
		T t = iter.next();
		while (iter.hasNext()) {
			T t2 = iter.next();
			if (t.compareTo(t2) > 0) {
				return false;
			}
			t = t2;
		}
		return true;
	}

	public static > List sort(Collection c){
		List list = new ArrayList<>(c);
		Collections.sort(list);
		return list;
	}
	
	/**
	 * Copies the elements into a new {@code List<Object>} (a widening of the element type).
	 * @param <T> the element type of {@code list}
	 * @param list the elements to copy
	 * @return a new mutable list containing the same elements, in the same order
	 */
	public static <T> List<Object> toObjectList(List<T> list){
		return new ArrayList<Object>(list);
	}

	/**
	 * Copies the elements into a new list whose element type is {@code T2}.
	 * <p>
	 * The cast is <b>unchecked</b>: generics are erased, so no type check happens here, and an
	 * incompatible element only shows up as a {@link ClassCastException} once it is used as a {@code T2}.
	 * @param <T> the element type of {@code l}
	 * @param <T2> the element type of the result
	 * @param l the elements to copy
	 * @return a new mutable list with the same elements, in the same order
	 */
	@SuppressWarnings("unchecked")
	public static <T, T2> List<T2> toList(List<T> l) {
		List<T2> r = new ArrayList<>(l.size());
		for (T e: l) {
			r.add((T2) e);
		}
		return r;
	}

	/**
	 * Copies the elements into a new {@code List<Number>}.
	 * @param <T> the element type of {@code list}
	 * @param list the elements to copy, all expected to be {@link Number}s
	 * @return a new mutable list with the same elements, in the same order
	 * @throws ClassCastException if an element is not a {@link Number}
	 */
	public static <T> List<Number> toNumberList(List<T> list){
		List<Number> resList = new ArrayList<>(list.size());
		for (T elem: list) {
			// Checked cast (T erases to Object), fails fast for non-Number elements
			resList.add((Number) elem);
		}
		return resList;
	}

	/**
	 * Converts a primitive {@code double[]} into a list (presized to {@code arr.length}).
	 * <p>
	 * NOTE(review): this copy of the file is garbled at the {@code for} loops below. The loop bodies,
	 * the {@code int[]} overload and the header of the summing method are partly missing and must be
	 * restored from the original source.
	 */
	public static List arrToList(double[] arr){
		List lst =  new ArrayList<>(arr.length);
		for(int i=0; i arrToList(int[] arr){
		List lst = new ArrayList<>(arr.length);
		for(int i=0; i c) {
		// Plain summation of the unboxed elements; a null element causes a NullPointerException
		double sum = 0;
		for (Double d : c) {
			sum+= d;
		}
		return sum;
	}
	
	/**
	 * Sums the integers of a collection, using plain {@code int} arithmetic (wraps around on overflow).
	 * @param c the values to sum (must not contain {@code null}, which cannot be unboxed)
	 * @return the sum, 0 for an empty collection
	 */
	public static int sumInts(Collection<Integer> c) {
		int sum = 0;
		for (int val : c) {
			sum+= val;
		}
		return sum;
	}

	/**
	 * Computes the set difference {@code s1 \ s2}: the elements of {@code s1} that are not in {@code s2}.
	 * Neither argument is modified.
	 * @param <T> the element type
	 * @param s1 the first set
	 * @param s2 the set whose elements are excluded
	 * @return a new mutable {@link HashSet} with the elements unique to {@code s1}
	 */
	public static <T> Set<T> getUniqueFromFirstSet(Set<T> s1, Set<T> s2){
		Set<T> uniques = new HashSet<>(s1);
		// removeIf + s2.contains (rather than removeAll) so membership is always decided by s2,
		// regardless of which of the two sets is larger
		uniques.removeIf(s2::contains);
		return uniques;
	}


	public static  List> partitionStatic(List input, int numberPerList) {
		if (input == null || input.isEmpty())
			throw new IllegalArgumentException("empty list cannot be partitioned");
		if (numberPerList <= 0)
			throw new IllegalArgumentException("Number per list cannot be 0 or less");

		List> result = new ArrayList<>();
		int currIndex = 0;
		while (currIndex + numberPerList < input.size()) {
			result.add(input.subList(currIndex, currIndex+numberPerList));
			currIndex+=numberPerList;
		}
		// Add the rest (if any)
		if (currIndex < input.size())
			result.add(input.subList(currIndex, input.size()));

		return result;
	}

	/**
	 * Splits {@code input} into {@code folds} parts, using the boundary indices computed by
	 * {@code getFoldSplits} (presumably consecutive sub-lists between neighbouring boundaries - TODO confirm).
	 * <p>
	 * NOTE(review): this copy of the file is garbled at the {@code for} loops below. The rest of this
	 * method and the header of {@code getFoldSplits} are missing and must be restored from the original source.
	 * @throws IllegalArgumentException if {@code input} is {@code null}/empty, {@code folds <= 0}, or
	 *         {@code folds} exceeds the number of records
	 */
	public static  List> partition(List input, int folds) {
		if (input == null || input.isEmpty())
			throw new IllegalArgumentException("empty list cannot be partitioned");
		if (folds <= 0)
			throw new IllegalArgumentException("Number of folds cannot be 0 or less");

		List indexes = getFoldSplits(input, folds);
		List> partitions = new ArrayList<>();

		for (int i=1; i List getFoldSplits(List list, int folds){
		// Decide the start and end-indexes 
		int defaultFoldSize = (int) Math.floor(((double)list.size())/folds);
		if (defaultFoldSize < 1)
			throw new IllegalArgumentException("Using " + folds + " folds on a list of size " + list.size() + " give less than 1 record per list");
		int recordsLeftToAssign = list.size() - folds*defaultFoldSize;

		// The boundary list starts with 0 and ends with list.size(); records that do not divide evenly
		// (recordsLeftToAssign) are handed out one at a time, see the decrement in the loop below
		int currentSplitIndex = 0;
		List indexSplits = new ArrayList<>();
		indexSplits.add(currentSplitIndex);
		for(int i=0; i0) {
				currentSplitIndex++;
				recordsLeftToAssign--;
			}
			indexSplits.add(currentSplitIndex);
		}
		indexSplits.add(list.size());

		return indexSplits;
	}

	

	/**
	 * Splits {@code list} into consecutive, non-overlapping parts, which are {@link List#subList(int, int)}
	 * views backed by {@code list}. Each part gets {@code floor(list.size()/splits)} elements, and the
	 * remainder is presumably handed out one extra element per part. The last part takes whatever is left.
	 * <p>
	 * NOTE(review): the {@code for} loop header below is garbled in this copy of the file and must be
	 * restored from the original source.
	 * @param list the records to split
	 * @param splits the number of parts, at least 2
	 * @param allowEmptySets if {@code false}, asking for more parts than there are records is an error
	 * @return the parts, in order
	 * @throws IllegalArgumentException if {@code splits < 2}, or if empty sets are not allowed and
	 *         {@code splits > list.size()}
	 */
	public static  List> getDisjunctSets(List list, int splits, boolean allowEmptySets){
		if (splits < 2)
			throw new IllegalArgumentException("Number of folds must be >=2");
		if (! allowEmptySets && splits > list.size())
			throw new IllegalArgumentException("Cannot create " + splits + " out of " + list.size() + " records");

		int defaultFoldSize = (int) Math.floor(((double)list.size())/splits);
		int recordsLeftToAssign = list.size() - splits*defaultFoldSize;

		List> sets = new ArrayList<>();

		int start = 0, end = defaultFoldSize;
		for (int i=0; i 0) {
				end++;
				recordsLeftToAssign--;
			}
			sets.add(list.subList(start, end));
			start = end;
		}
		// The last fold is simply the remaining indices
		sets.add(list.subList(start, list.size()));

		return sets;

	}

	/**
	 * Presumably returns the indices of {@code lst} ordered by the size of their elements (ascending
	 * or descending according to {@code ascending}) - TODO confirm, since the body is garbled in this
	 * copy of the file. The visible code presizes a list for {@code lst.size()} entries and finishes
	 * by mapping each intermediate item to its {@code index}.
	 */
	public static > List getSortedIndicesBySize(List lst, boolean ascending){
		List indices = new ArrayList<>(lst.size());

		for (int i=0; i i.index).collect(Collectors.toList());
	}


	/**
	 * Reorders {@code input} according to {@code indices}: element {@code i} of the result is
	 * {@code input.get(indices.get(i))}.
	 * <p>
	 * {@code indices} is expected to be a permutation of {@code 0..input.size()-1}. Only the length is
	 * validated, so repeated indices produce repeated elements (and other elements are left out).
	 * @param <T> the element type
	 * @param input the list to reorder (not modified)
	 * @param indices the positions in {@code input}, in the wanted order
	 * @return a new mutable list with the reordered elements
	 * @throws IllegalArgumentException if the two lists differ in length
	 * @throws NullPointerException if either list is {@code null} or {@code indices} contains {@code null}
	 * @throws IndexOutOfBoundsException if an index is outside of {@code input}
	 */
	public static <T> List<T> sortBy(List<T> input, List<Integer> indices) 
		throws IllegalArgumentException, NullPointerException, IndexOutOfBoundsException {
		Objects.requireNonNull(input);
		Objects.requireNonNull(indices);
		if (input.size() != indices.size())
			throw new IllegalArgumentException("List to sort and indices must have same length");
		List<T> sorted = new ArrayList<>(indices.size());
		
		for (int index : indices){
			sorted.add(input.get(index));
		}

		return sorted;
	}

	/**
	 * Picks the elements of {@code input} at the given positions, in the order of {@code indices}.
	 * Indices may repeat and need not cover all of {@code input}.
	 * @param <T> the element type
	 * @param input the list to pick from (not modified)
	 * @param indices the positions to pick
	 * @return a new mutable list with {@code indices.size()} elements
	 * @throws NullPointerException if either list is {@code null} or {@code indices} contains {@code null}
	 * @throws IndexOutOfBoundsException if an index is outside of {@code input}
	 */
	public static <T> List<T> getIndices(List<T> input, List<Integer> indices) 
		throws IllegalArgumentException, NullPointerException, IndexOutOfBoundsException {
		Objects.requireNonNull(input);
		Objects.requireNonNull(indices);

		List<T> picked = new ArrayList<>(indices.size());
		for (int index : indices){
			picked.add(input.get(index));
		}

		return picked;
	}


	/**
	 * Returns a new map that contains only the entries whose key is a {@link String}. Entries with
	 * other key types, including a {@code null} key, are dropped.
	 * @param map the map to filter (not modified)
	 * @return a new mutable {@link HashMap} with the String-keyed entries
	 */
	public static Map<String, Object> toStringKeys(Map<?, ?> map){
		Map<String, Object> asStr = new HashMap<>();
		for (Map.Entry<?, ?> ent: map.entrySet()) {
			if (ent.getKey() instanceof String)
				asStr.put((String) ent.getKey(), ent.getValue());
		}
		return asStr;
	}

	/**
	 * Same lookup as {@link #getArbitratyDepth(Map, Object)}, but returns {@code defaultValue}
	 * when it yields {@code null} (key absent, or mapped to {@code null}).
	 * @param map a map, possibly with multiple levels
	 * @param key key to look for ({@code null} allowed)
	 * @param defaultValue value returned when no non-null value is found
	 * @return the value associated with the key, or {@code defaultValue}
	 */
	public static Object getArbitratyDepth(Map<?, ?> map, Object key, Object defaultValue) {
		Object res = getArbitratyDepth(map, key);
		return (res!=null? res : defaultValue);
	}
	/**
	 * Get the value stored for a given key, or {@code null} if it is not present. Values that are
	 * themselves {@link Map}s are searched recursively. All keys of a map are checked before descending
	 * into its nested maps, which are then searched one after another in iteration order.
	 * @param map a map, possibly with multiple levels
	 * @param key key to look for ({@code null} allowed)
	 * @return the value associated with the key, or {@code null}
	 */
	public static Object getArbitratyDepth(Map<?, ?> map, Object key) {
		List<Map<?, ?>> nextLevels = new ArrayList<>(); 
		for (Map.Entry<?, ?> entry: map.entrySet()) {
			// Objects.equals: maps such as HashMap allow a null key, which would NPE on getKey().equals(..)
			if (Objects.equals(entry.getKey(), key))
				return entry.getValue();
			else if (entry.getValue() instanceof Map)
				nextLevels.add((Map<?, ?>) entry.getValue());
		}

		// Check next levels
		for (Map<?, ?> m : nextLevels) {
			Object value = getArbitratyDepth(m, key);
			if (value != null)
				return value;
		}

		return null;
	}
	
	/**
	 * Checks whether {@code key} is present in {@code map} or in any nested map, i.e. in any value
	 * that is itself a {@link Map}, searched recursively.
	 * @param map a map, possibly with multiple levels
	 * @param key key to look for ({@code null} allowed)
	 * @return {@code true} if the key is found at any level
	 */
	public static boolean containsKeyArbitraryDepth(Map<?, ?> map, Object key) {
		for (Map.Entry<?, ?> entry : map.entrySet()) {
			// Objects.equals: maps such as HashMap allow a null key, which would NPE on getKey().equals(..)
			if (Objects.equals(entry.getKey(), key)) {
				return true;
			} else if (entry.getValue() instanceof Map
					&& containsKeyArbitraryDepth((Map<?, ?>) entry.getValue(), key)) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Deletes {@code key} from {@code map} and, recursively, from every nested map
	 * (every value that is itself a {@link Map}). The maps are modified in place.
	 * @param map a map, possibly with multiple levels
	 * @param key the key to delete at all levels
	 */
	@SuppressWarnings("unchecked")
	public static void removeAtArbitraryDepth(Map map, Object key) {
		// this level first ...
		map.remove(key);
		// ... then every nested level
		for (Iterator<?> it = map.values().iterator(); it.hasNext(); ) {
			Object nested = it.next();
			if (nested instanceof Map) {
				removeAtArbitraryDepth((Map) nested, key);
			}
		}
	}

	/**
	 * Copies all elements of an {@link Iterable} into a new mutable list, in iteration order.
	 * @param <T> the element type
	 * @param iter the elements
	 * @return a new {@link ArrayList} with the elements
	 */
	public static <T> List<T> toList(Iterable<T> iter){
		List<T> result = new ArrayList<>();
		iter.forEach(result::add);
		return result;
	}

	/**
	 * Drains the remaining elements of an {@link Iterator} into a new mutable list;
	 * the iterator is exhausted afterwards.
	 * @param <T> the element type
	 * @param iter the iterator to consume
	 * @return a new {@link ArrayList} with the remaining elements
	 */
	public static <T> List<T> toList(Iterator<T> iter){
		List<T> result = new ArrayList<>();
		iter.forEachRemaining(result::add);
		return result;
	}

	/**
	 * Performs a shallow replication: the same object reference is placed at every position.
	 * <p>
	 * NOTE(review): this copy of the file is garbled at the {@code for} loop below. The rest of this
	 * method and the header of the following iterator-length method are missing and must be restored
	 * from the original source.
	 * @param obj object to replicate
	 * @param num number of repeats
	 * @param <T> the type of {@code obj}
	 * @return a list with {@code obj} replicated {@code num} times
	 */
	@SuppressWarnings("unchecked")
	public static  List rep(T obj, int num){
		List list = new ArrayList<>();

		if (obj instanceof Double)
			for (int i=0; i iter) {
		// Counts by consuming the iterator, so it is exhausted afterwards
		int len = 0;
		while(iter.hasNext()) {
			iter.next();
			len++;
		}
		return len;
	}

	/**
	 * Returns a copy of {@code map} without the entries whose value is {@code null}, or whose value's
	 * {@code toString()} equals {@code "null"} (ignoring case).
	 * <p>
	 * A {@code null} or empty {@code map} is returned as-is (the same instance); otherwise a new
	 * {@link HashMap} is returned and {@code map} is not modified.
	 * @param <K> the key type
	 * @param <V> the value type
	 * @param map the map to filter
	 * @return the filtered map
	 */
	public static <K, V> Map<K, V> dropNullValues(Map<K, V> map){
		if (map == null || map.isEmpty())
			return map;
		
		Map<K, V> noNull = new HashMap<>(map.size());

		for (Map.Entry<K, V> kv : map.entrySet()){
			V value = kv.getValue();
			// A value whose textual form is "null" (any case) is treated as missing as well
			if (value != null && ! "null".equalsIgnoreCase(value.toString())) {
				noNull.put(kv.getKey(), value);
			}
		}

		return noNull;
	}
	
	/**
	 * Accumulates the numbers (as doubles) into a new {@link SummaryStatistics}.
	 * @param col the numbers (must not contain {@code null})
	 * @return the summary statistics of {@code col}
	 */
	public static SummaryStatistics getStatistics(Collection<? extends Number> col) {
		SummaryStatistics ss = new SummaryStatistics();
		for (Number n: col) {
			ss.addValue(n.doubleValue());
		}
		return ss;
	}
	
	public static >  boolean rangeHasNoBounds(Range range) {
		return ! range.hasLowerBound() && ! range.hasUpperBound();
	}
	
	/**
	 * Checks whether the array contains {@code key}, compared with {@link String#equalsIgnoreCase(String)}.
	 * {@code null} elements are skipped, and a {@code null} key never matches.
	 * @param list the strings to search
	 * @param key the string to look for
	 * @return {@code true} if a case-insensitive match is found
	 */
	public static boolean containsIgnoreCase(String[] list, String key) {
		return containsIgnoreCase(Arrays.asList(list), key);
	}
	
	/**
	 * Checks whether the collection contains {@code key}, compared with {@link String#equalsIgnoreCase(String)}.
	 * {@code null} elements are skipped, and a {@code null} key never matches.
	 * @param list the strings to search
	 * @param key the string to look for
	 * @return {@code true} if a case-insensitive match is found
	 */
	public static boolean containsIgnoreCase(Collection<String> list, String key) {
		for (String s : list) {
			// skip null elements instead of throwing an NPE on s.equalsIgnoreCase(..)
			if (s != null && s.equalsIgnoreCase(key))
				return true;
		}
		return false;
	}

}