All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.util.ArrayUtil Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.util;

import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import lombok.val;
import org.apache.commons.lang3.RandomUtils;
import org.nd4j.base.Preconditions;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.util.*;

/**
 * @author Adam Gibson
 */
public class ArrayUtil {


    private ArrayUtil() {}


    /**
     * Returns true if any array elements are negative.
     * If the array is null, it returns false
     * @param arr the array to test
     * @return
     */
    public static boolean containsAnyNegative(int[] arr) {
        if(arr == null)
            return false;

        for(int i = 0; i < arr.length; i++) {
            if(arr[i] < 0)
                return true;
        }
        return false;
    }

    public static boolean containsAnyNegative(long[] arr) {
        if(arr == null)
            return false;

        for(int i = 0; i < arr.length; i++) {
            if(arr[i] < 0)
                return true;
        }
        return false;
    }

    /**
     *
     * @param arrs
     * @param check
     * @return
     */
    public static boolean anyLargerThan(int[] arrs, int check) {
        for(int i = 0; i < arrs.length; i++) {
            if(arrs[i] > check)
                return true;
        }

        return false;
    }


    /**
     *
     * @param arrs
     * @param check
     * @return
     */
    public static boolean anyLessThan(int[] arrs, int check) {
        for(int i = 0; i < arrs.length; i++) {
            if(arrs[i] < check)
                return true;
        }

        return false;
    }


    /**
     * Convert a int array to a string array
     * @param arr the array to convert
     * @return the equivalent string array
     */
    public static String[] convertToString(int[] arr) {
        Preconditions.checkNotNull(arr);
        String[] ret = new String[arr.length];
        for(int i = 0; i < arr.length; i++) {
            ret[i] = String.valueOf(arr[i]);
        }

        return ret;
    }


    /**
     * Proper comparison contains for list of int
     * arrays
     * @param list the to search
     * @param target the target int array
     * @return whether the given target
     * array is contained in the list
     */
    public static boolean listOfIntsContains(List list,int[] target) {
        for(int[] arr : list)
            if(Arrays.equals(target,arr))
                return true;
        return false;
    }

    /**
     * Repeat a value n times
     * @param n the number of times to repeat
     * @param toReplicate the value to repeat
     * @return an array of length n filled with the
     * given value
     */
    public static int[] nTimes(int n, int toReplicate) {
        int[] ret = new int[n];
        Arrays.fill(ret, toReplicate);
        return ret;
    }

    public static long[] nTimes(long n, long toReplicate) {
        // FIXME: int cast
        val ret = new long[(int) n];
        Arrays.fill(ret, toReplicate);
        return ret;
    }

    /**
     * Returns true if all of the elements in the
     * given int array are unique
     * @param toTest the array to test
     * @return true if all o fthe items
     * are unique false otherwise
     */
    public static boolean allUnique(int[] toTest) {
        Set set = new HashSet<>();
        for (int i : toTest) {
            if (!set.contains(i))
                set.add(i);
            else
                return false;
        }

        return true;
    }

    /**
     * Credit to mikio braun from jblas
     * 

* Create a random permutation of the numbers 0, ..., size - 1. *

* see Algorithm P, D.E. Knuth: The Art of Computer Programming, Vol. 2, p. 145 */ public static int[] randomPermutation(int size) { Random r = new Random(); int[] result = new int[size]; for (int j = 0; j < size; j++) { result[j] = j + 1; } for (int j = size - 1; j > 0; j--) { int k = r.nextInt(j); int temp = result[j]; result[j] = result[k]; result[k] = temp; } return result; } public static short toHalf(float data) { return fromFloat(data); } public static short toHalf(double data) { return fromFloat((float) data); } public static short[] toHalfs(float[] data) { short[] ret = new short[data.length]; for (int i = 0; i < ret.length; i++) { ret[i] = fromFloat(data[i]); } return ret; } public static short[] toHalfs(int[] data) { short[] ret = new short[data.length]; for (int i = 0; i < ret.length; i++) { ret[i] = fromFloat((float) data[i]); } return ret; } public static short[] toHalfs(long[] data) { short[] ret = new short[data.length]; for (int i = 0; i < ret.length; i++) { ret[i] = fromFloat((float) data[i]); } return ret; } public static short[] toHalfs(double[] data) { short[] ret = new short[data.length]; for (int i = 0; i < ret.length; i++) { ret[i] = fromFloat((float) data[i]); } return ret; } public static short fromFloat(float v) { if (Float.isNaN(v)) return (short) 0x7fff; if (v == Float.POSITIVE_INFINITY) return (short) 0x7c00; if (v == Float.NEGATIVE_INFINITY) return (short) 0xfc00; if (v == 0.0f) return (short) 0x0000; if (v == -0.0f) return (short) 0x8000; if (v > 65504.0f) return 0x7bff; // max value supported by half float if (v < -65504.0f) return (short) (0x7bff | 0x8000); if (v > 0.0f && v < 5.96046E-8f) return 0x0001; if (v < 0.0f && v > -5.96046E-8f) return (short) 0x8001; final int f = Float.floatToIntBits(v); return (short) (((f >> 16) & 0x8000) | ((((f & 0x7f800000) - 0x38000000) >> 13) & 0x7c00) | ((f >> 13) & 0x03ff)); } public static int[] toInts(float[] data) { int[] ret = new int[data.length]; for (int i = 0; i < ret.length; i++) ret[i] = (int) data[i]; return ret; } public static int[] toInts(double[] data) { int[] ret = new int[data.length]; for (int i = 0; i < ret.length; i++) ret[i] = (int) data[i]; return ret; } public static int[] toInts(long[] array) { int[] retVal = new int[array.length]; for (int i = 0; i < array.length; i++) { retVal[i] = (int) array[i]; } return retVal; } public static int[] mod(int[] input,int mod) { int[] ret = new int[input.length]; for(int i = 0; i < ret.length; i++) { ret[i] = input[i] % mod; } return ret; } /** * Calculate the offset for a given stride array * @param stride the stride to use * @param i the offset to calculate for * @return the offset for the given * stride */ public static int offsetFor(int[] stride, int i) { int ret = 0; for (int j = 0; j < stride.length; j++) ret += (i * stride[j]); return ret; } /** * Sum of an int array * @param add the elements * to calculate the sum for * @return the sum of this array */ public static int sum(List add) { if (add.size() < 1) return 0; int ret = 0; for (int i = 0; i < add.size(); i++) ret += add.get(i); return ret; } /** * Sum of an int array * @param add the elements * to calculate the sum for * @return the sum of this array */ public static int sum(int[] add) { if (add.length < 1) return 0; int ret = 0; for (int i = 0; i < add.length; i++) ret += add[i]; return ret; } public static long sumLong(long... add) { if (add.length < 1) return 0; int ret = 0; for (int i = 0; i < add.length; i++) ret += add[i]; return ret; } /** * Product of an int array * @param mult the elements * to calculate the sum for * @return the product of this array */ public static int prod(List mult) { if (mult.size() < 1) return 0; int ret = 1; for (int i = 0; i < mult.size(); i++) ret *= mult.get(i); return ret; } /** * Product of an int array * @param mult the elements * to calculate the sum for * @return the product of this array */ public static int prod(long... mult) { if (mult.length < 1) return 0; int ret = 1; for (int i = 0; i < mult.length; i++) ret *= mult[i]; return ret; } /** * Product of an int array * @param mult the elements * to calculate the sum for * @return the product of this array */ public static int prod(int... mult) { if (mult.length < 1) return 0; int ret = 1; for (int i = 0; i < mult.length; i++) ret *= mult[i]; return ret; } /** * Product of an int array * @param mult the elements * to calculate the sum for * @return the product of this array */ public static long prodLong(List mult) { if (mult.size() < 1) return 0; long ret = 1; for (int i = 0; i < mult.size(); i++) ret *= mult.get(i).longValue(); return ret; } /** * Product of an int array * @param mult the elements * to calculate the sum for * @return the product of this array */ public static long prodLong(int... mult) { if (mult.length < 1) return 0; long ret = 1; for (int i = 0; i < mult.length; i++) ret *= mult[i]; return ret; } public static long prodLong(long... mult) { if (mult.length < 1) return 0; long ret = 1; for (int i = 0; i < mult.length; i++) ret *= mult[i]; return ret; } public static boolean equals(float[] data, double[] data2) { if (data.length != data2.length) return false; for (int i = 0; i < data.length; i++) { double equals = Math.abs(data2[i] - data[i]); if (equals > 1e-6) return false; } return true; } public static int[] consArray(int a, int[] as) { int len = as.length; int[] nas = new int[len + 1]; nas[0] = a; System.arraycopy(as, 0, nas, 1, len); return nas; } /** * Returns true if any of the elements are zero * @param as * @return */ public static boolean isZero(int[] as) { for (int i = 0; i < as.length; i++) { if (as[i] == 0) return true; } return false; } public static boolean isZero(long[] as) { for (int i = 0; i < as.length; i++) { if (as[i] == 0L) return true; } return false; } public static boolean anyMore(int[] target, int[] test) { assert target.length == test.length : "Unable to compare: different sizes"; for (int i = 0; i < target.length; i++) { if (target[i] > test[i]) return true; } return false; } public static boolean anyLess(int[] target, int[] test) { assert target.length == test.length : "Unable to compare: different sizes"; for (int i = 0; i < target.length; i++) { if (target[i] < test[i]) return true; } return false; } public static boolean lessThan(int[] target, int[] test) { assert target.length == test.length : "Unable to compare: different sizes"; for (int i = 0; i < target.length; i++) { if (target[i] < test[i]) return true; if (target[i] > test[i]) return false; } return false; } public static boolean greaterThan(int[] target, int[] test) { assert target.length == test.length : "Unable to compare: different sizes"; for (int i = 0; i < target.length; i++) { if (target[i] > test[i]) return true; if (target[i] < test[i]) return false; } return false; } /** * Compute the offset * based on teh shape strides and offsets * @param shape the shape to compute * @param offsets the offsets to compute * @param strides the strides to compute * @return the offset for the given shape,offset,and strides */ public static int calcOffset(List shape, List offsets, List strides) { if (shape.size() != offsets.size() || shape.size() != strides.size()) throw new IllegalArgumentException("Shapes,strides, and offsets must be the same size"); int ret = 0; for (int i = 0; i < offsets.size(); i++) { //we should only do this in the general case, not on vectors //the reason for this is we force everything including scalars //to be 2d if (shape.get(i) == 1 && offsets.size() > 2 && i > 0) continue; ret += offsets.get(i) * strides.get(i); } return ret; } /** * Compute the offset * based on teh shape strides and offsets * @param shape the shape to compute * @param offsets the offsets to compute * @param strides the strides to compute * @return the offset for the given shape,offset,and strides */ public static int calcOffset(int[] shape, int[] offsets, int[] strides) { if (shape.length != offsets.length || shape.length != strides.length) throw new IllegalArgumentException("Shapes,strides, and offsets must be the same size"); int ret = 0; for (int i = 0; i < offsets.length; i++) { if (shape[i] == 1) continue; ret += offsets[i] * strides[i]; } return ret; } /** * Compute the offset * based on teh shape strides and offsets * @param shape the shape to compute * @param offsets the offsets to compute * @param strides the strides to compute * @return the offset for the given shape,offset,and strides */ public static long calcOffsetLong(List shape, List offsets, List strides) { if (shape.size() != offsets.size() || shape.size() != strides.size()) throw new IllegalArgumentException("Shapes,strides, and offsets must be the same size"); long ret = 0; for (int i = 0; i < offsets.size(); i++) { //we should only do this in the general case, not on vectors //the reason for this is we force everything including scalars //to be 2d if (shape.get(i) == 1 && offsets.size() > 2 && i > 0) continue; ret += (long) offsets.get(i) * strides.get(i); } return ret; } public static long calcOffsetLong2(List shape, List offsets, List strides) { if (shape.size() != offsets.size() || shape.size() != strides.size()) throw new IllegalArgumentException("Shapes,strides, and offsets must be the same size"); long ret = 0; for (int i = 0; i < offsets.size(); i++) { //we should only do this in the general case, not on vectors //the reason for this is we force everything including scalars //to be 2d if (shape.get(i) == 1 && offsets.size() > 2 && i > 0) continue; ret += (long) offsets.get(i) * strides.get(i); } return ret; } /** * Compute the offset * based on teh shape strides and offsets * @param shape the shape to compute * @param offsets the offsets to compute * @param strides the strides to compute * @return the offset for the given shape,offset,and strides */ public static long calcOffsetLong(int[] shape, int[] offsets, int[] strides) { if (shape.length != offsets.length || shape.length != strides.length) throw new IllegalArgumentException("Shapes,strides, and offsets must be the same size"); long ret = 0; for (int i = 0; i < offsets.length; i++) { if (shape[i] == 1) continue; ret += (long) offsets[i] * strides[i]; } return ret; } /** * * @param xs * @param ys * @return */ public static int dotProduct(List xs, List ys) { int result = 0; int n = xs.size(); if (ys.size() != n) throw new IllegalArgumentException("Different array sizes"); for (int i = 0; i < n; i++) { result += xs.get(i) * ys.get(i); } return result; } /** * * @param xs * @param ys * @return */ public static int dotProduct(int[] xs, int[] ys) { int result = 0; int n = xs.length; if (ys.length != n) throw new IllegalArgumentException("Different array sizes"); for (int i = 0; i < n; i++) { result += xs[i] * ys[i]; } return result; } /** * * @param xs * @param ys * @return */ public static long dotProductLong(List xs, List ys) { long result = 0; int n = xs.size(); if (ys.size() != n) throw new IllegalArgumentException("Different array sizes"); for (int i = 0; i < n; i++) { result += (long) xs.get(i) * ys.get(i); } return result; } /** * * @param xs * @param ys * @return */ public static long dotProductLong2(List xs, List ys) { long result = 0; int n = xs.size(); if (ys.size() != n) throw new IllegalArgumentException("Different array sizes"); for (int i = 0; i < n; i++) { result += (long) xs.get(i) * ys.get(i); } return result; } /** * * @param xs * @param ys * @return */ public static long dotProductLong(int[] xs, int[] ys) { long result = 0; int n = xs.length; if (ys.length != n) throw new IllegalArgumentException("Different array sizes"); for (int i = 0; i < n; i++) { result += (long) xs[i] * ys[i]; } return result; } public static int[] empty() { return new int[0]; } public static int[] of(int... arr) { return arr; } public static int[] copy(int[] copy) { int[] ret = new int[copy.length]; System.arraycopy(copy, 0, ret, 0, ret.length); return ret; } public static long[] copy(long[] copy) { long[] ret = new long[copy.length]; System.arraycopy(copy, 0, ret, 0, ret.length); return ret; } public static double[] doubleCopyOf(float[] data) { double[] ret = new double[data.length]; for (int i = 0; i < ret.length; i++) ret[i] = data[i]; return ret; } public static float[] floatCopyOf(double[] data) { if (data.length == 0) return new float[1]; float[] ret = new float[data.length]; for (int i = 0; i < ret.length; i++) ret[i] = (float) data[i]; return ret; } /** * Returns a subset of an array from 0 to "to" * * @param data the data to getFromOrigin a subset of * @param to the end point of the data * @return the subset of the data specified */ public static double[] range(double[] data, int to) { return range(data, to, 1); } /** * Returns a subset of an array from 0 to "to" * using the specified stride * * @param data the data to getFromOrigin a subset of * @param to the end point of the data * @param stride the stride to go through the array * @return the subset of the data specified */ public static double[] range(double[] data, int to, int stride) { return range(data, to, stride, 1); } /** * Returns a subset of an array from 0 to "to" * using the specified stride * * @param data the data to getFromOrigin a subset of * @param to the end point of the data * @param stride the stride to go through the array * @param numElementsEachStride the number of elements to collect at each stride * @return the subset of the data specified */ public static double[] range(double[] data, int to, int stride, int numElementsEachStride) { double[] ret = new double[to / stride]; if (ret.length < 1) ret = new double[1]; int count = 0; for (int i = 0; i < data.length; i += stride) { for (int j = 0; j < numElementsEachStride; j++) { if (i + j >= data.length || count >= ret.length) break; ret[count++] = data[i + j]; } } return ret; } public static List toList(int... ints){ if(ints == null){ return null; } List ret = new ArrayList<>(); for (int anInt : ints) { ret.add(anInt); } return ret; } public static int[] toArray(List list) { int[] ret = new int[list.size()]; for (int i = 0; i < list.size(); i++) ret[i] = list.get(i); return ret; } public static long[] toArrayLong(List list) { long[] ret = new long[list.size()]; for (int i = 0; i < list.size(); i++) ret[i] = list.get(i); return ret; } public static double[] toArrayDouble(List list) { double[] ret = new double[list.size()]; for (int i = 0; i < list.size(); i++) ret[i] = list.get(i); return ret; } /** * Generate an int array ranging from * from to to. * if from is > to this method will * count backwards * * @param from the from * @param to the end point of the data * @param increment the amount to increment by * @return the int array with a length equal to absoluteValue(from - to) */ public static int[] range(int from, int to, int increment) { int diff = Math.abs(from - to); int[] ret = new int[diff / increment]; if (ret.length < 1) ret = new int[1]; if (from < to) { int count = 0; for (int i = from; i < to; i += increment) { if (count >= ret.length) break; ret[count++] = i; } } else if (from > to) { int count = 0; for (int i = from - 1; i >= to; i -= increment) { if (count >= ret.length) break; ret[count++] = i; } } return ret; } public static long[] range(long from, long to, long increment) { long diff = Math.abs(from - to); long[] ret = new long[(int) (diff / increment)]; if (ret.length < 1) ret = new long[1]; if (from < to) { int count = 0; for (long i = from; i < to; i += increment) { if (count >= ret.length) break; ret[count++] = i; } } else if (from > to) { int count = 0; for (int i = (int) from - 1; i >= to; i -= increment) { if (count >= ret.length) break; ret[count++] = i; } } return ret; } /** * Generate an int array ranging from * from to to. * if from is > to this method will * count backwards * * @param from the from * @param to the end point of the data * @return the int array with a length equal to absoluteValue(from - to) */ public static int[] range(int from, int to) { if (from == to) return new int[0]; return range(from, to, 1); } public static long[] range(long from, long to) { if (from == to) return new long[0]; return range(from, to, 1); } public static double[] toDoubles(int[] ints) { double[] ret = new double[ints.length]; for (int i = 0; i < ints.length; i++) ret[i] = (double) ints[i]; return ret; } public static double[] toDoubles(long[] ints) { double[] ret = new double[ints.length]; for (int i = 0; i < ints.length; i++) ret[i] = (double) ints[i]; return ret; } public static double[] toDoubles(float[] ints) { double[] ret = new double[ints.length]; for (int i = 0; i < ints.length; i++) ret[i] = (double) ints[i]; return ret; } public static float[] toFloats(int[][] ints) { return toFloats(Ints.concat(ints)); } public static double[] toDoubles(int[][] ints) { return toDoubles(Ints.concat(ints)); } public static float[] toFloats(int[] ints) { float[] ret = new float[ints.length]; for (int i = 0; i < ints.length; i++) ret[i] = (float) ints[i]; return ret; } public static float[] toFloats(long[] ints) { float[] ret = new float[ints.length]; for (int i = 0; i < ints.length; i++) ret[i] = (float) ints[i]; return ret; } public static float[] toFloats(double[] ints) { float[] ret = new float[ints.length]; for (int i = 0; i < ints.length; i++) ret[i] = (float) ints[i]; return ret; } /** * Return a copy of this array with the * given index omitted * * @param data the data to copy * @param index the index of the item to remove * @param newValue the newValue to replace * @return the new array with the omitted * item */ public static int[] replace(int[] data, int index, int newValue) { int[] copy = copy(data); copy[index] = newValue; return copy; } /** * Return a copy of this array with only the * given index(es) remaining * * @param data the data to copy * @param index the index of the item to remove * @return the new array with the omitted * item */ public static int[] keep(int[] data, int... index) { if (index.length == data.length) return data; int[] ret = new int[index.length]; int count = 0; for (int i = 0; i < data.length; i++) if (Ints.contains(index, i)) ret[count++] = data[i]; return ret; } /** * Return a copy of this array with only the * given index(es) remaining * * @param data the data to copy * @param index the index of the item to remove * @return the new array with the omitted * item */ public static long[] keep(long[] data, int... index) { if (index.length == data.length) return data; long[] ret = new long[index.length]; int count = 0; for (int i = 0; i < data.length; i++) if (Ints.contains(index, i)) ret[count++] = data[i]; return ret; } /** * Return a copy of this array with the * given index omitted * * PLEASE NOTE: index to be omitted must exist in source array. * * @param data the data to copy * @param index the index of the item to remove * @return the new array with the omitted * item */ public static int[] removeIndex(int[] data, int... index) { if (index.length >= data.length) { throw new IllegalStateException("Illegal remove: indexes.length > data.length (index.length=" + index.length + ", data.length=" + data.length + ")"); } int offset = 0; /* workaround for non-existent indexes (such as Integer.MAX_VALUE) for (int i = 0; i < index.length; i ++) { if (index[i] >= data.length || index[i] < 0) offset++; } */ int[] ret = new int[data.length - index.length + offset]; int count = 0; for (int i = 0; i < data.length; i++) if (!Ints.contains(index, i)) { ret[count++] = data[i]; } return ret; } public static long[] removeIndex(long[] data, int... index) { if (index.length >= data.length) { throw new IllegalStateException("Illegal remove: indexes.length > data.length (index.length=" + index.length + ", data.length=" + data.length + ")"); } int offset = 0; /* workaround for non-existent indexes (such as Integer.MAX_VALUE) for (int i = 0; i < index.length; i ++) { if (index[i] >= data.length || index[i] < 0) offset++; } */ long[] ret = new long[data.length - index.length + offset]; int count = 0; for (int i = 0; i < data.length; i++) if (!Ints.contains(index, i)) { ret[count++] = data[i]; } return ret; } /** * Zip 2 arrays in to: * * @param as * @param bs * @return */ public static int[][] zip(int[] as, int[] bs) { int[][] result = new int[as.length][2]; for (int i = 0; i < result.length; i++) { result[i] = new int[] {as[i], bs[i]}; } return result; } /** * Get the tensor matrix multiply shape * @param aShape the shape of the first array * @param bShape the shape of the second array * @param axes the axes to do the multiply * @return the shape for tensor matrix multiply */ public static long[] getTensorMmulShape(long[] aShape, long[] bShape, int[][] axes) { // FIXME: int cast int validationLength = Math.min(axes[0].length, axes[1].length); for (int i = 0; i < validationLength; i++) { if (aShape[axes[0][i]] != bShape[axes[1][i]]) throw new IllegalArgumentException( "Size of the given axes a" + " t each dimension must be the same size."); if (axes[0][i] < 0) axes[0][i] += aShape.length; if (axes[1][i] < 0) axes[1][i] += bShape.length; } List listA = new ArrayList<>(); for (int i = 0; i < aShape.length; i++) { if (!Ints.contains(axes[0], i)) listA.add(i); } List listB = new ArrayList<>(); for (int i = 0; i < bShape.length; i++) { if (!Ints.contains(axes[1], i)) listB.add(i); } int n2 = 1; int aLength = Math.min(aShape.length, axes[0].length); for (int i = 0; i < aLength; i++) { n2 *= aShape[axes[0][i]]; } //if listA and listB are empty these donot initialize. //so initializing with {1} which will then get overriden if not empty long[] oldShapeA; if (listA.size() == 0) { oldShapeA = new long[] {1}; } else { oldShapeA = Longs.toArray(listA); for (int i = 0; i < oldShapeA.length; i++) oldShapeA[i] = aShape[(int) oldShapeA[i]]; } int n3 = 1; int bNax = Math.min(bShape.length, axes[1].length); for (int i = 0; i < bNax; i++) { n3 *= bShape[axes[1][i]]; } long[] oldShapeB; if (listB.size() == 0) { oldShapeB = new long[] {1}; } else { oldShapeB = Longs.toArray(listB); for (int i = 0; i < oldShapeB.length; i++) oldShapeB[i] = bShape[(int) oldShapeB[i]]; } long[] aPlusB = Longs.concat(oldShapeA, oldShapeB); return aPlusB; } /** * Permute the given input * switching the dimensions of the input shape * array with in the order of the specified * dimensions * @param shape the shape to permute * @param dimensions the dimensions * @return */ public static int[] permute(int[] shape, int[] dimensions) { int[] ret = new int[shape.length]; for (int i = 0; i < shape.length; i++) { ret[i] = shape[dimensions[i]]; } return ret; } public static long[] permute(long[] shape, int[] dimensions) { val ret = new long[shape.length]; for (int i = 0; i < shape.length; i++) { ret[i] = shape[dimensions[i]]; } return ret; } /** * Original credit: https://github.com/alberts/array4j/blob/master/src/main/java/net/lunglet/util/ArrayUtils.java * @param a * @return */ public static int[] argsort(int[] a) { return argsort(a, true); } /** * * @param a * @param ascending * @return */ public static int[] argsort(final int[] a, final boolean ascending) { Integer[] indexes = new Integer[a.length]; for (int i = 0; i < indexes.length; i++) { indexes[i] = i; } Arrays.sort(indexes, new Comparator() { @Override public int compare(final Integer i1, final Integer i2) { return (ascending ? 1 : -1) * Ints.compare(a[i1], a[i2]); } }); int[] ret = new int[indexes.length]; for (int i = 0; i < ret.length; i++) ret[i] = indexes[i]; return ret; } /** * Convert all dimensions in the specified * axes array to be positive * based on the specified range of values * @param range * @param axes * @return */ public static int[] convertNegativeIndices(int range, int[] axes) { int[] axesRet = ArrayUtil.range(0, range); int[] newAxes = ArrayUtil.copy(axes); for (int i = 0; i < axes.length; i++) { newAxes[i] = axes[axesRet[i]]; } return newAxes; } /** * Generate an array from 0 to length * and generate take a subset * @param length the length to generate to * @param from the begin of the interval to take * @param to the end of the interval to take * @return the generated array */ public static int[] copyOfRangeFrom(int length, int from, int to) { return Arrays.copyOfRange(ArrayUtil.range(0, length), from, to); } //Credit: http://stackoverflow.com/questions/15533854/converting-byte-array-to-double-array /** * * @param doubleArray * @return */ public static byte[] toByteArray(double[] doubleArray) { int times = Double.SIZE / Byte.SIZE; byte[] bytes = new byte[doubleArray.length * times]; for (int i = 0; i < doubleArray.length; i++) { ByteBuffer.wrap(bytes, i * times, times).putDouble(doubleArray[i]); } return bytes; } /** * * @param byteArray * @return */ public static double[] toDoubleArray(byte[] byteArray) { int times = Double.SIZE / Byte.SIZE; double[] doubles = new double[byteArray.length / times]; for (int i = 0; i < doubles.length; i++) { doubles[i] = ByteBuffer.wrap(byteArray, i * times, times).getDouble(); } return doubles; } /** * * @param doubleArray * @return */ public static byte[] toByteArray(float[] doubleArray) { int times = Float.SIZE / Byte.SIZE; byte[] bytes = new byte[doubleArray.length * times]; for (int i = 0; i < doubleArray.length; i++) { ByteBuffer.wrap(bytes, i * times, times).putFloat(doubleArray[i]); } return bytes; } public static long[] toLongArray(int[] intArray) { long[] ret = new long[intArray.length]; for (int i = 0; i < intArray.length; i++) { ret[i] = intArray[i]; } return ret; } /** * * @param byteArray * @return */ public static float[] toFloatArray(byte[] byteArray) { int times = Float.SIZE / Byte.SIZE; float[] doubles = new float[byteArray.length / times]; for (int i = 0; i < doubles.length; i++) { doubles[i] = ByteBuffer.wrap(byteArray, i * times, times).getFloat(); } return doubles; } /** * * @param intArray * @return */ public static byte[] toByteArray(int[] intArray) { int times = Integer.SIZE / Byte.SIZE; byte[] bytes = new byte[intArray.length * times]; for (int i = 0; i < intArray.length; i++) { ByteBuffer.wrap(bytes, i * times, times).putInt(intArray[i]); } return bytes; } /** * * @param byteArray * @return */ public static int[] toIntArray(byte[] byteArray) { int times = Integer.SIZE / Byte.SIZE; int[] ints = new int[byteArray.length / times]; for (int i = 0; i < ints.length; i++) { ints[i] = ByteBuffer.wrap(byteArray, i * times, times).getInt(); } return ints; } /** * Return a copy of this array with the * given index omitted * * @param data the data to copy * @param index the index of the item to remove * @return the new array with the omitted * item */ public static int[] removeIndex(int[] data, int index) { if (data == null) return null; if (index >= data.length) throw new IllegalArgumentException("Unable to remove index " + index + " was >= data.length"); if (data.length < 1) return data; if (index < 0) return data; int len = data.length; int[] result = new int[len - 1]; System.arraycopy(data, 0, result, 0, index); System.arraycopy(data, index + 1, result, index, len - index - 1); return result; } public static long[] removeIndex(long[] data, int index) { if (data == null) return null; if (index >= data.length) throw new IllegalArgumentException("Unable to remove index " + index + " was >= data.length"); if (data.length < 1) return data; if (index < 0) return data; int len = data.length; long[] result = new long[len - 1]; System.arraycopy(data, 0, result, 0, index); System.arraycopy(data, index + 1, result, index, len - index - 1); return result; } /** * Create a copy of the given array * starting at the given index with the given length. * * The intent here is for striding. * * For example in slicing, you want the major stride to be first. * You achieve this by taking the last index * of the matrix's stride and putting * this as the first stride of the new ndarray * for slicing. * * All of the elements except the copied elements are * initialized as the given value * @param valueStarting the starting value * @param copy the array to copy * @param idxFrom the index to start at in the from array * @param idxAt the index to start at in the return array * @param length the length of the array to create * @return the given array */ public static int[] valueStartingAt(int valueStarting, int[] copy, int idxFrom, int idxAt, int length) { int[] ret = new int[length]; Arrays.fill(ret, valueStarting); for (int i = 0; i < length; i++) { if (i + idxFrom >= copy.length || i + idxAt >= ret.length) break; ret[i + idxAt] = copy[i + idxFrom]; } return ret; } /** * Returns the array with the item in index * removed, if the array is empty it will return the array itself * * @param data the data to remove data from * @param index the index of the item to remove * @return a copy of the array with the removed item, * or the array itself if empty */ public static Integer[] removeIndex(Integer[] data, int index) { if (data == null) return null; if (data.length < 1) return data; int len = data.length; Integer[] result = new Integer[len - 1]; System.arraycopy(data, 0, result, 0, index); System.arraycopy(data, index + 1, result, index, len - index - 1); return result; } /** * Computes the standard packed array strides for a given shape. * * @param shape the shape of a matrix: * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ public static int[] calcStridesFortran(int[] shape, int startNum) { if (shape.length == 2 && (shape[0] == 1 || shape[1] == 1)) { int[] ret = new int[2]; Arrays.fill(ret, startNum); return ret; } int dimensions = shape.length; int[] stride = new int[dimensions]; int st = startNum; for (int j = 0; j < stride.length; j++) { stride[j] = st; st *= shape[j]; } return stride; } /** * Computes the standard packed array strides for a given shape. * * @param shape the shape of a matrix: * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ public static long[] calcStridesFortran(long[] shape, int startNum) { if (shape.length == 2 && (shape[0] == 1 || shape[1] == 1)) { long[] ret = new long[2]; Arrays.fill(ret, startNum); return ret; } int dimensions = shape.length; long[] stride = new long[dimensions]; int st = startNum; for (int j = 0; j < stride.length; j++) { stride[j] = st; st *= shape[j]; } return stride; } /** * Computes the standard packed array strides for a given shape. * * @param shape the shape of a matrix: * @return the strides for a matrix of n dimensions */ public static int[] calcStridesFortran(int[] shape) { return calcStridesFortran(shape, 1); } public static long[] calcStridesFortran(long[] shape) { return calcStridesFortran(shape, 1); } /** * Computes the standard packed array strides for a given shape. * * @param shape the shape of a matrix: * @param startValue the startValue for the strides * @return the strides for a matrix of n dimensions */ public static int[] calcStrides(int[] shape, int startValue) { if (shape.length == 2 && (shape[0] == 1 || shape[1] == 1)) { int[] ret = new int[2]; Arrays.fill(ret, startValue); return ret; } int dimensions = shape.length; int[] stride = new int[dimensions]; int st = startValue; for (int j = dimensions - 1; j >= 0; j--) { stride[j] = st; st *= shape[j]; } return stride; } /** * Computes the standard packed array strides for a given shape. * * @param shape the shape of a matrix: * @param startValue the startValue for the strides * @return the strides for a matrix of n dimensions */ public static long[] calcStrides(long[] shape, int startValue) { if (shape.length == 2 && (shape[0] == 1 || shape[1] == 1)) { long[] ret = new long[2]; Arrays.fill(ret, startValue); return ret; } int dimensions = shape.length; long[] stride = new long[dimensions]; int st = startValue; for (int j = dimensions - 1; j >= 0; j--) { stride[j] = st; st *= shape[j]; } return stride; } /** * Returns true if the given * two arrays are reverse copies of each other * @param first * @param second * @return */ public static boolean isInverse(int[] first, int[] second) { int backWardCount = second.length - 1; for (int i = 0; i < first.length; i++) { if (first[i] != second[backWardCount--]) return false; } return true; } public static int[] plus(int[] ints, int mult) { int[] ret = new int[ints.length]; for (int i = 0; i < ints.length; i++) ret[i] = ints[i] + mult; return ret; } public static int[] plus(int[] ints, int[] mult) { if (ints.length != mult.length) throw new IllegalArgumentException("Both arrays must have the same length"); int[] ret = new int[ints.length]; for (int i = 0; i < ints.length; i++) ret[i] = ints[i] + mult[i]; return ret; } public static int[] times(int[] ints, int mult) { int[] ret = new int[ints.length]; for (int i = 0; i < ints.length; i++) ret[i] = ints[i] * mult; return ret; } public static int[] times(int[] ints, int[] mult) { assert ints.length == mult.length : "Ints and mult must be the same length"; int[] ret = new int[ints.length]; for (int i = 0; i < ints.length; i++) ret[i] = ints[i] * mult[i]; return ret; } /** * For use with row vectors to ensure consistent strides * with varying offsets * * @param arr the array to get the stride for * @return the stride */ public static int nonOneStride(int[] arr) { for (int i = 0; i < arr.length; i++) if (arr[i] != 1) return arr[i]; return 1; } /** * Computes the standard packed array strides for a given shape. * * @param shape the shape of a matrix: * @return the strides for a matrix of n dimensions */ public static int[] calcStrides(int[] shape) { return calcStrides(shape, 1); } public static long[] calcStrides(long[] shape) { return calcStrides(shape, 1); } /** * Create a backwards copy of the given array * * @param e the array to createComplex a reverse clone of * @return the reversed copy */ public static int[] reverseCopy(int[] e) { if (e.length < 1) return e; int[] copy = new int[e.length]; for (int i = 0; i <= e.length / 2; i++) { int temp = e[i]; copy[i] = e[e.length - i - 1]; copy[e.length - i - 1] = temp; } return copy; } public static long[] reverseCopy(long[] e) { if (e.length < 1) return e; long[] copy = new long[e.length]; for (int i = 0; i <= e.length / 2; i++) { long temp = e[i]; copy[i] = e[e.length - i - 1]; copy[e.length - i - 1] = temp; } return copy; } public static double[] read(int length, DataInputStream dis) throws IOException { double[] ret = new double[length]; for (int i = 0; i < length; i++) ret[i] = dis.readDouble(); return ret; } public static void write(double[] data, DataOutputStream dos) throws IOException { for (int i = 0; i < data.length; i++) dos.writeDouble(data[i]); } public static double[] readDouble(int length, DataInputStream dis) throws IOException { double[] ret = new double[length]; for (int i = 0; i < length; i++) ret[i] = dis.readDouble(); return ret; } public static float[] readFloat(int length, DataInputStream dis) throws IOException { float[] ret = new float[length]; for (int i = 0; i < length; i++) ret[i] = dis.readFloat(); return ret; } public static void write(float[] data, DataOutputStream dos) throws IOException { for (int i = 0; i < data.length; i++) dos.writeFloat(data[i]); } public static void assertSquare(double[]... d) { if (d.length > 2) { for (int i = 0; i < d.length; i++) { assertSquare(d[i]); } } else { int firstLength = d[0].length; for (int i = 1; i < d.length; i++) { assert d[i].length == firstLength; } } } /** * Multiply the given array * by the given scalar * @param arr the array to multily * @param mult the scalar to multiply by */ public static void multiplyBy(int[] arr, int mult) { for (int i = 0; i < arr.length; i++) arr[i] *= mult; } /** * Reverse the passed in array in place * * @param e the array to reverse */ public static void reverse(int[] e) { for (int i = 0; i <= e.length / 2; i++) { int temp = e[i]; e[i] = e[e.length - i - 1]; e[e.length - i - 1] = temp; } } public static void reverse(long[] e) { for (int i = 0; i <= e.length / 2; i++) { long temp = e[i]; e[i] = e[e.length - i - 1]; e[e.length - i - 1] = temp; } } public static List zerosMatrix(long... dimensions) { List ret = new ArrayList<>(); for (int i = 0; i < dimensions.length; i++) { ret.add(new double[(int) dimensions[i]]); } return ret; } public static List zerosMatrix(int... dimensions) { List ret = new ArrayList<>(); for (int i = 0; i < dimensions.length; i++) { ret.add(new double[dimensions[i]]); } return ret; } public static float[] reverseCopy(float[] e) { float[] copy = new float[e.length]; for (int i = 0; i <= e.length / 2; i++) { float temp = e[i]; copy[i] = e[e.length - i - 1]; copy[e.length - i - 1] = temp; } return copy; } public static E[] reverseCopy(E[] e) { E[] copy = (E[]) new Object[e.length]; for (int i = 0; i <= e.length / 2; i++) { E temp = e[i]; copy[i] = e[e.length - i - 1]; copy[e.length - i - 1] = temp; } return copy; } public static void reverse(E[] e) { for (int i = 0; i <= e.length / 2; i++) { E temp = e[i]; e[i] = e[e.length - i - 1]; e[e.length - i - 1] = temp; } } public static float[] flatten(float[][] arr) { float[] ret = new float[arr.length * arr[0].length]; int count = 0; for (int i = 0; i < arr.length; i++) for (int j = 0; j < arr[i].length; j++) ret[count++] = arr[i][j]; return ret; } public static float[] flatten(float[][][] arr) { float[] ret = new float[arr.length * arr[0].length * arr[0][0].length]; int count = 0; for (int i = 0; i < arr.length; i++) for (int j = 0; j < arr[0].length; j++) for (int k = 0; k < arr[0][0].length; k++) { ret[count++] = arr[i][j][k]; } return ret; } public static double[] flatten(double[][][] arr) { double[] ret = new double[arr.length * arr[0].length * arr[0][0].length]; int count = 0; for (int i = 0; i < arr.length; i++) for (int j = 0; j < arr[0].length; j++) for (int k = 0; k < arr[0][0].length; k++) { ret[count++] = arr[i][j][k]; } return ret; } public static int[] flatten(int[][][] arr) { int[] ret = new int[arr.length * arr[0].length * arr[0][0].length]; int count = 0; for (int i = 0; i < arr.length; i++) for (int j = 0; j < arr[0].length; j++) for (int k = 0; k < arr[0][0].length; k++) { ret[count++] = arr[i][j][k]; } return ret; } public static float[] flatten(float[][][][] arr) { float[] ret = new float[arr.length * arr[0].length * arr[0][0].length * arr[0][0][0].length]; int count = 0; for (int i = 0; i < arr.length; i++) for (int j = 0; j < arr[0].length; j++) for (int k = 0; k < arr[0][0].length; k++) for (int m = 0; m < arr[0][0][0].length; m++) ret[count++] = arr[i][j][k][m]; return ret; } public static double[] flatten(double[][][][] arr) { double[] ret = new double[arr.length * arr[0].length * arr[0][0].length * arr[0][0][0].length]; int count = 0; for (int i = 0; i < arr.length; i++) for (int j = 0; j < arr[0].length; j++) for (int k = 0; k < arr[0][0].length; k++) for (int m = 0; m < arr[0][0][0].length; m++) ret[count++] = arr[i][j][k][m]; return ret; } public static int[] flatten(int[][][][] arr) { int[] ret = new int[arr.length * arr[0].length * arr[0][0].length * arr[0][0][0].length]; int count = 0; for (int i = 0; i < arr.length; i++) for (int j = 0; j < arr[0].length; j++) for (int k = 0; k < arr[0][0].length; k++) for (int m = 0; m < arr[0][0][0].length; m++) ret[count++] = arr[i][j][k][m]; return ret; } public static int[] flatten(int[][] arr) { int[] ret = new int[arr.length * arr[0].length]; int count = 0; for (int i = 0; i < arr.length; i++) for (int j = 0; j < arr[i].length; j++) ret[count++] = arr[i][j]; return ret; } public static long[] flatten(long[][] arr) { long[] ret = new long[arr.length * arr[0].length]; int count = 0; for (int i = 0; i < arr.length; i++) for (int j = 0; j < arr[i].length; j++) ret[count++] = arr[i][j]; return ret; } /** * Convert a 2darray in to a flat * array (row wise) * @param arr the array to flatten * @return a flattened representation of the array */ public static double[] flatten(double[][] arr) { double[] ret = new double[arr.length * arr[0].length]; int count = 0; for (int i = 0; i < arr.length; i++) for (int j = 0; j < arr[i].length; j++) ret[count++] = arr[i][j]; return ret; } /** * Convert a 2darray in to a flat * array (row wise) * @param arr the array to flatten * @return a flattened representation of the array */ public static double[] flattenF(double[][] arr) { double[] ret = new double[arr.length * arr[0].length]; int count = 0; for (int j = 0; j < arr[0].length; j++) for (int i = 0; i < arr.length; i++) ret[count++] = arr[i][j]; return ret; } public static float[] flattenF(float[][] arr) { float[] ret = new float[arr.length * arr[0].length]; int count = 0; for (int j = 0; j < arr[0].length; j++) for (int i = 0; i < arr.length; i++) ret[count++] = arr[i][j]; return ret; } public static int[] flattenF(int[][] arr) { int[] ret = new int[arr.length * arr[0].length]; int count = 0; for (int j = 0; j < arr[0].length; j++) for (int i = 0; i < arr.length; i++) ret[count++] = arr[i][j]; return ret; } public static long[] flattenF(long[][] arr) { long[] ret = new long[arr.length * arr[0].length]; int count = 0; for (int j = 0; j < arr[0].length; j++) for (int i = 0; i < arr.length; i++) ret[count++] = arr[i][j]; return ret; } /** * Cast an int array to a double array * @param arr the array to cast * @return the elements of this * array cast as an int */ public static double[][] toDouble(int[][] arr) { double[][] ret = new double[arr.length][arr[0].length]; for (int i = 0; i < arr.length; i++) { for (int j = 0; j < arr[i].length; j++) ret[i][j] = arr[i][j]; } return ret; } /** * Combines a applyTransformToDestination of int arrays in to one flat int array * * @param nums the int arrays to combineDouble * @return one combined int array */ public static float[] combineFloat(List nums) { int length = 0; for (int i = 0; i < nums.size(); i++) length += nums.get(i).length; float[] ret = new float[length]; int count = 0; for (float[] i : nums) { for (int j = 0; j < i.length; j++) { ret[count++] = i[j]; } } return ret; } /** * Combines a apply of int arrays in to one flat int array * * @param nums the int arrays to combineDouble * @return one combined int array */ public static float[] combine(List nums) { int length = 0; for (int i = 0; i < nums.size(); i++) length += nums.get(i).length; float[] ret = new float[length]; int count = 0; for (float[] i : nums) { for (int j = 0; j < i.length; j++) { ret[count++] = i[j]; } } return ret; } /** * Combines a apply of int arrays in to one flat int array * * @param nums the int arrays to combineDouble * @return one combined int array */ public static double[] combineDouble(List nums) { int length = 0; for (int i = 0; i < nums.size(); i++) length += nums.get(i).length; double[] ret = new double[length]; int count = 0; for (double[] i : nums) { for (int j = 0; j < i.length; j++) { ret[count++] = i[j]; } } return ret; } /** * Combines a apply of int arrays in to one flat int array * * @param ints the int arrays to combineDouble * @return one combined int array */ public static double[] combine(float[]... ints) { int length = 0; for (int i = 0; i < ints.length; i++) length += ints[i].length; double[] ret = new double[length]; int count = 0; for (float[] i : ints) { for (int j = 0; j < i.length; j++) { ret[count++] = i[j]; } } return ret; } /** * Combines a apply of int arrays in to one flat int array * * @param ints the int arrays to combineDouble * @return one combined int array */ public static int[] combine(int[]... ints) { int length = 0; for (int i = 0; i < ints.length; i++) length += ints[i].length; int[] ret = new int[length]; int count = 0; for (int[] i : ints) { for (int j = 0; j < i.length; j++) { ret[count++] = i[j]; } } return ret; } /** * Combines a apply of long arrays in to one flat long array * * @param ints the int arrays to combineDouble * @return one combined int array */ public static long[] combine(long[]... ints) { int length = 0; for (int i = 0; i < ints.length; i++) length += ints[i].length; long[] ret = new long[length]; int count = 0; for (long[] i : ints) { for (int j = 0; j < i.length; j++) { ret[count++] = i[j]; } } return ret; } public static E[] combine(E[]... arrs) { int length = 0; for (int i = 0; i < arrs.length; i++) length += arrs[i].length; E[] ret = (E[]) Array.newInstance(arrs[0][0].getClass(), length); int count = 0; for (E[] i : arrs) { for (int j = 0; j < i.length; j++) { ret[count++] = i[j]; } } return ret; } public static int[] toOutcomeArray(int outcome, int numOutcomes) { int[] nums = new int[numOutcomes]; nums[outcome] = 1; return nums; } public static double[] toDouble(int[] data) { double[] ret = new double[data.length]; for (int i = 0; i < ret.length; i++) ret[i] = data[i]; return ret; } public static double[] toDouble(long[] data) { double[] ret = new double[data.length]; for (int i = 0; i < ret.length; i++) ret[i] = data[i]; return ret; } public static float[] copy(float[] data) { float[] result = new float[data.length]; System.arraycopy(data, 0, result, 0, data.length); return result; } public static double[] copy(double[] data) { double[] result = new double[data.length]; System.arraycopy(data, 0, result, 0, data.length); return result; } /** Convert an arbitrary-dimensional rectangular double array to flat vector.
* Can pass double[], double[][], double[][][], etc. */ public static double[] flattenDoubleArray(Object doubleArray) { if (doubleArray instanceof double[]) return (double[]) doubleArray; LinkedList stack = new LinkedList<>(); stack.push(doubleArray); int[] shape = arrayShape(doubleArray); int length = ArrayUtil.prod(shape); double[] flat = new double[length]; int count = 0; while (!stack.isEmpty()) { Object current = stack.pop(); if (current instanceof double[]) { double[] arr = (double[]) current; for (int i = 0; i < arr.length; i++) flat[count++] = arr[i]; } else if (current instanceof Object[]) { Object[] o = (Object[]) current; for (int i = o.length - 1; i >= 0; i--) stack.push(o[i]); } else throw new IllegalArgumentException("Base array is not double[]"); } if (count != flat.length) throw new IllegalArgumentException("Fewer elements than expected. Array is ragged?"); return flat; } /** Convert an arbitrary-dimensional rectangular float array to flat vector.
* Can pass float[], float[][], float[][][], etc. */ public static float[] flattenFloatArray(Object floatArray) { if (floatArray instanceof float[]) return (float[]) floatArray; LinkedList stack = new LinkedList<>(); stack.push(floatArray); int[] shape = arrayShape(floatArray); int length = ArrayUtil.prod(shape); float[] flat = new float[length]; int count = 0; while (!stack.isEmpty()) { Object current = stack.pop(); if (current instanceof float[]) { float[] arr = (float[]) current; for (int i = 0; i < arr.length; i++) flat[count++] = arr[i]; } else if (current instanceof Object[]) { Object[] o = (Object[]) current; for (int i = o.length - 1; i >= 0; i--) stack.push(o[i]); } else throw new IllegalArgumentException("Base array is not float[]"); } if (count != flat.length) throw new IllegalArgumentException("Fewer elements than expected. Array is ragged?"); return flat; } /** Calculate the shape of an arbitrary multi-dimensional array. Assumes:
* (a) array is rectangular (not ragged) and first elements (i.e., array[0][0][0]...) are non-null
* (b) First elements have > 0 length. So array[0].length > 0, array[0][0].length > 0, etc.
* Can pass any Java array opType: double[], Object[][][], float[][], etc.
* Length of returned array is number of dimensions; returned[i] is size of ith dimension. */ public static int[] arrayShape(Object array) { int nDimensions = 0; Class c = array.getClass().getComponentType(); while (c != null) { nDimensions++; c = c.getComponentType(); } int[] shape = new int[nDimensions]; Object current = array; for (int i = 0; i < shape.length - 1; i++) { shape[i] = ((Object[]) current).length; current = ((Object[]) current)[0]; } if (current instanceof Object[]) { shape[shape.length - 1] = ((Object[]) current).length; } else if (current instanceof double[]) { shape[shape.length - 1] = ((double[]) current).length; } else if (current instanceof float[]) { shape[shape.length - 1] = ((float[]) current).length; } else if (current instanceof long[]) { shape[shape.length - 1] = ((long[]) current).length; } else if (current instanceof int[]) { shape[shape.length - 1] = ((int[]) current).length; } else if (current instanceof byte[]) { shape[shape.length - 1] = ((byte[]) current).length; } else if (current instanceof char[]) { shape[shape.length - 1] = ((char[]) current).length; } else if (current instanceof boolean[]) { shape[shape.length - 1] = ((boolean[]) current).length; } else if (current instanceof short[]) { shape[shape.length - 1] = ((short[]) current).length; } else throw new IllegalStateException("Unknown array opType"); //Should never happen return shape; } /** Returns the maximum value in the array */ public static int max(int[] in) { int max = Integer.MIN_VALUE; for (int i = 0; i < in.length; i++) if (in[i] > max) max = in[i]; return max; } /** Returns the minimum value in the array */ public static int min(int[] in) { int min = Integer.MAX_VALUE; for (int i = 0; i < in.length; i++) if (in[i] < min) min = in[i]; return min; } /** Returns the index of the maximum value in the array. * If two entries have same maximum value, index of the first one is returned. */ public static int argMax(int[] in) { int maxIdx = 0; for (int i = 1; i < in.length; i++) if (in[i] > in[maxIdx]) maxIdx = i; return maxIdx; } /** Returns the index of the minimum value in the array. * If two entries have same minimum value, index of the first one is returned. */ public static int argMin(int[] in) { int minIdx = 0; for (int i = 1; i < in.length; i++) if (in[i] < in[minIdx]) minIdx = i; return minIdx; } /** Returns the index of the maximum value in the array. * If two entries have same maximum value, index of the first one is returned. */ public static int argMax(long[] in) { int maxIdx = 0; for (int i = 1; i < in.length; i++) if (in[i] > in[maxIdx]) maxIdx = i; return maxIdx; } /** Returns the index of the minimum value in the array. * If two entries have same minimum value, index of the first one is returned. */ public static int argMin(long[] in) { int minIdx = 0; for (int i = 1; i < in.length; i++) if (in[i] < in[minIdx]) minIdx = i; return minIdx; } /** * * @return */ public static int[] buildHalfVector(Random rng, int length) { int[] result = new int[length]; List indexes = new ArrayList<>(); // we add indexes from second half only for (int i = result.length - 1; i >= result.length / 2; i--) { indexes.add(i); } Collections.shuffle(indexes, rng); for (int i = 0; i < result.length; i++) { if (i < result.length / 2) { result[i] = indexes.get(0); indexes.remove(0); } else result[i] = -1; } return result; } public static int[] buildInterleavedVector(Random rng, int length) { int[] result = new int[length]; List indexes = new ArrayList<>(); List odds = new ArrayList<>(); // we add odd indexes only to list for (int i = 1; i < result.length; i += 2) { indexes.add(i); odds.add(i - 1); } Collections.shuffle(indexes, rng); // now all even elements will be interleaved with odd elements for (int i = 0; i < result.length; i++) { if (i % 2 == 0 && indexes.size() >= 1) { int idx = indexes.get(0); indexes.remove(0); result[i] = idx; } else result[i] = -1; } // for odd tad numbers, we add special random clause for last element if (length % 2 != 0) { int rndClause = odds.get(rng.nextInt(odds.size())); int tmp = result[rndClause]; result[rndClause] = result[result.length - 1]; result[result.length - 1] = tmp; } return result; } public static long[] buildInterleavedVector(Random rng, long length) { // FIXME: int cast long[] result = new long[(int) length]; List indexes = new ArrayList<>(); List odds = new ArrayList<>(); // we add odd indexes only to list for (int i = 1; i < result.length; i += 2) { indexes.add(i); odds.add(i - 1); } Collections.shuffle(indexes, rng); // now all even elements will be interleaved with odd elements for (int i = 0; i < result.length; i++) { if (i % 2 == 0 && indexes.size() >= 1) { int idx = indexes.get(0); indexes.remove(0); result[i] = idx; } else result[i] = -1; } // for odd tad numbers, we add special random clause for last element if (length % 2 != 0) { int rndClause = odds.get(rng.nextInt(odds.size())); long tmp = result[rndClause]; result[rndClause] = result[result.length - 1]; result[result.length - 1] = tmp; } return result; } protected static void swap(List objects, int idxA, int idxB) { T tmpA = objects.get(idxA); T tmpB = objects.get(idxB); objects.set(idxA, tmpB); objects.set(idxB, tmpA); } public static void shuffleWithMap(List objects, int[] map) { for (int i = 0; i < map.length; i++) { if (map[i] >= 0) { swap(objects, i, map[i]); } } } public static int argMinOfMax(int[] first, int[] second) { int minIdx = 0; int maxAtMinIdx = Math.max(first[0], second[0]); for (int i = 1; i < first.length; i++) { int maxAtIndex = Math.max(first[i], second[i]); if (maxAtMinIdx > maxAtIndex) { maxAtMinIdx = maxAtIndex; minIdx = i; } } return minIdx; } public static long argMinOfMax(long[] first, long[] second) { long minIdx = 0; long maxAtMinIdx = Math.max(first[0], second[0]); for (int i = 1; i < first.length; i++) { long maxAtIndex = Math.max(first[i], second[i]); if (maxAtMinIdx > maxAtIndex) { maxAtMinIdx = maxAtIndex; minIdx = i; } } return minIdx; } public static int argMinOfMax(int[]... arrays) { int minIdx = 0; int maxAtMinIdx = Integer.MAX_VALUE; for (int i = 0; i < arrays[0].length; i++) { int maxAtIndex = Integer.MIN_VALUE; for (int j = 0; j < arrays.length; j++) { maxAtIndex = Math.max(maxAtIndex, arrays[j][i]); } if (maxAtMinIdx > maxAtIndex) { maxAtMinIdx = maxAtIndex; minIdx = i; } } return minIdx; } public static long argMinOfMax(long[]... arrays) { int minIdx = 0; long maxAtMinIdx = Long.MAX_VALUE; for (int i = 0; i < arrays[0].length; i++) { long maxAtIndex = Long.MIN_VALUE; for (int j = 0; j < arrays.length; j++) { maxAtIndex = Math.max(maxAtIndex, arrays[j][i]); } if (maxAtMinIdx > maxAtIndex) { maxAtMinIdx = maxAtIndex; minIdx = i; } } return minIdx; } public static int argMinOfSum(int[] first, int[] second) { int minIdx = 0; int sumAtMinIdx = first[0] + second[0]; for (int i = 1; i < first.length; i++) { int sumAtIndex = first[i] + second[i]; if (sumAtMinIdx > sumAtIndex) { sumAtMinIdx = sumAtIndex; minIdx = i; } } return minIdx; } public static > Map sortMapByValue(Map map) { List> list = new LinkedList<>(map.entrySet()); Collections.sort(list, new Comparator>() { @Override public int compare(Map.Entry o1, Map.Entry o2) { return (o1.getValue()).compareTo(o2.getValue()); } }); Map result = new LinkedHashMap<>(); for (Map.Entry entry : list) { result.put(entry.getKey(), entry.getValue()); } return result; } public static T getRandomElement(List list) { if (list.size() < 1) return null; return list.get(RandomUtils.nextInt(0, list.size())); } /** * Convert an int * @param bool * @return */ public static int fromBoolean(boolean bool) { return bool ? 1 : 0; } }