All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.util.Shape Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Encapsulates shape-related logic for n-dimensional arrays.
 *
 * <p>Shapes are plain {@code int[]} arrays of dimension sizes. A 0-dimensional shape
 * ({@code []}) and the 1-dimensional shape {@code [1]} are both treated as scalars, and
 * singleton (size 1) dimensions are ignored by several of the comparisons below.
 *
 *  @author Adam Gibson
 */
public class Shape {

    /**
     * Gets rid of any singleton (size 1) dimensions of the given shape.
     *
     * <p>The stride is not consulted; the result is identical to {@link #squeeze(int[])}.
     * The parameter is kept for API compatibility.
     *
     * @param shape  the shape to squeeze
     * @param stride the stride of the array (ignored)
     * @return a new array with every dimension of {@code shape} that is not 1, in order
     */
    public static int[] squeeze(int[] shape, int[] stride) {
        return squeeze(shape);
    }

    /**
     * Looks up the sizes of the requested axes.
     *
     * @param axes  the axis indices to look up; each must be a valid index into {@code shape}
     * @param shape the shape to read sizes from
     * @return an array whose element {@code i} is {@code shape[axes[i]]}
     * @throws ArrayIndexOutOfBoundsException if an axis is outside {@code [0, shape.length)}
     */
    public static int[] sizeForAxes(int[] axes, int[] shape) {
        int[] ret = new int[axes.length];
        for (int i = 0; i < axes.length; i++) {
            ret[i] = shape[axes[i]];
        }
        return ret;
    }

    /**
     * Returns whether the given shape is a vector.
     *
     * <p>Any 1-dimensional shape is a vector. A 2-dimensional shape is a vector when one of
     * its dimensions equals the total element count, i.e. the other dimension is 1 (or one
     * of the dimensions is 0). Shapes of any other rank are not vectors.
     *
     * @param shape the shape to test
     * @return whether the given shape is a vector
     */
    public static boolean isVector(int[] shape) {
        if (shape.length == 1) {
            // The single dimension always equals the element count.
            return true;
        }
        if (shape.length != 2) {
            return false;
        }
        // Multiply in long: an int product can wrap around and spuriously equal one of the
        // dimensions (e.g. 65536 * 65537 wraps to 65536), misreporting a matrix as a vector.
        long len = (long) shape[0] * shape[1];
        return shape[0] == len || shape[1] == len;
    }

    /**
     * Returns whether the passed in shape is a matrix, i.e. 2-dimensional and not a vector.
     *
     * @param shape the shape to test
     * @return true if the shape is a matrix, false otherwise
     */
    public static boolean isMatrix(int[] shape) {
        return shape.length == 2 && !isVector(shape);
    }

    /**
     * Gets rid of any singleton (size 1) dimensions of the given shape.
     *
     * @param shape the shape to squeeze
     * @return a new array with every dimension of {@code shape} that is not 1, in order;
     *         empty if every dimension is 1
     */
    public static int[] squeeze(int[] shape) {
        // First pass sizes the result exactly, second pass fills it.
        int kept = 0;
        for (int dim : shape) {
            if (dim != 1) {
                kept++;
            }
        }
        int[] ret = new int[kept];
        int next = 0;
        for (int dim : shape) {
            if (dim != 1) {
                ret[next++] = dim;
            }
        }
        return ret;
    }

    /**
     * Returns the leading non-singleton dimension of a shape.
     *
     * <p>This is {@code shape[1]} when the shape has more than one dimension and
     * {@code shape[0] == 1}, otherwise {@code shape[0]}. Despite the name, zero-sized
     * dimensions are not treated specially.
     *
     * @param shape the shape to inspect; must have at least one dimension
     * @return the leading non-singleton dimension, as described above
     * @throws ArrayIndexOutOfBoundsException if {@code shape} is empty
     */
    public static int nonZeroDimension(int[] shape) {
        // Check the length first so a 1-D shape never has shape[1] considered.
        if (shape.length > 1 && shape[0] == 1) {
            return shape[1];
        }
        return shape[0];
    }

    /**
     * Returns whether 2 shapes are equal, taking vector and scalar semantics into account
     * as well as plain array equality.
     *
     * <ul>
     *   <li>Two column vectors ({@code [n, 1]}) must match exactly.</li>
     *   <li>Two row vectors ({@code [1, n]} or {@code [n]}) are compared with singleton
     *       dimensions removed, so {@code [1, n]} equals {@code [n]}.</li>
     *   <li>Otherwise the shapes are equal if both are scalar shapes or if the arrays are
     *       identical.</li>
     * </ul>
     *
     * @param shape1 the first shape for comparison
     * @param shape2 the second shape for comparison
     * @return whether the shapes are equivalent
     */
    public static boolean shapeEquals(int[] shape1, int[] shape2) {
        if (isColumnVectorShape(shape1) && isColumnVectorShape(shape2)) {
            return Arrays.equals(shape1, shape2);
        }

        if (isRowVectorShape(shape1) && isRowVectorShape(shape2)) {
            return Arrays.equals(squeeze(shape1), squeeze(shape2));
        }

        return scalarEquals(shape1, shape2) || Arrays.equals(shape1, shape2);
    }

    /**
     * Returns true if the given shapes are both scalar shapes. A scalar shape is either
     * 0-dimensional ({@code []}) or the 1-dimensional shape {@code [1]}.
     *
     * @param shape1 the first shape for comparison
     * @param shape2 the second shape for comparison
     * @return whether the 2 shapes are equal based on scalar rules
     */
    public static boolean scalarEquals(int[] shape1, int[] shape2) {
        return isScalarShape(shape1) && isScalarShape(shape2);
    }

    /** True for a 0-dimensional shape or the 1-dimensional shape {@code [1]}. */
    private static boolean isScalarShape(int[] shape) {
        return shape.length == 0 || (shape.length == 1 && shape[0] == 1);
    }

    /**
     * Returns whether the shape is a row vector shape: 2-dimensional with a leading
     * dimension of 1 ({@code [1, n]}), or any 1-dimensional shape ({@code [n]}).
     *
     * @param shape the shape to test
     * @return true if the shape is a row vector shape
     */
    public static boolean isRowVectorShape(int[] shape) {
        return (shape.length == 2 && shape[0] == 1) || shape.length == 1;
    }

    /**
     * Returns whether the shape is a column vector shape: 2-dimensional with a trailing
     * dimension of 1 ({@code [n, 1]}).
     *
     * @param shape the shape to test
     * @return true if the shape is a column vector shape
     */
    public static boolean isColumnVectorShape(int[] shape) {
        return shape.length == 2 && shape[1] == 1;
    }

    /**
     * Returns true when the two shapes are equal once all singleton (size 1) dimensions
     * are removed from both, or when the squeezed shapes are both scalar shapes.
     *
     * @param test1 the first shape to test
     * @param test2 the second shape to test
     * @return true if the shapes are equal with the singleton dimensions omitted
     */
    public static boolean squeezeEquals(int[] test1, int[] test2) {
        int[] s1 = squeeze(test1);
        int[] s2 = squeeze(test2);
        return scalarEquals(s1, s2) || Arrays.equals(s1, s2);
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy