All Downloads are FREE. Search and download functionalities are using the official Maven repository.

deepboof.misc.TensorOps Maven / Gradle / Ivy

There is a newer version: 0.5.3
Show newest version
/*
 * Copyright (c) 2016, Peter Abeles. All Rights Reserved.
 *
 * This file is part of DeepBoof
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package deepboof.misc;

import deepboof.Tensor;
import deepboof.tensors.Tensor_F32;
import deepboof.tensors.Tensor_F64;

import java.io.File;
import java.util.ArrayList;
import java.util.List;

/**
 * @author Peter Abeles
 */
public class TensorOps {

	/**
	 * Convenience function for wrapping passed in elements into a list
	 */
	public static  List WT(T ...elements ) {
		List list = new ArrayList<>();

		for (int i = 0; i < elements.length; i++) {
			list.add(elements[i]);
		}

		return list;
	}

	/**
	 * Convenience function for making it easy to create an array of ints
	 */
	public static int[] WI( int ... elements ) {
		return elements;
	}

	/**
	 * Convenience function for making it easy to create an array of ints
	 *
	 * @return [ a , elements[0] , ..., elements[N-1] ]
	 */
	public static int[] WI( int a , int[] elements ) {
		int out[] = new int[1+elements.length];
		out[0] = a;
		System.arraycopy(elements,0,out,1,elements.length);

		return out;
	}

	/**
	 * Convenience function for making it easy to create an array of ints
	 *
	 * @return [ elements[0] , ..., elements[N-1], a ]
	 */
	public static int[] WI( int[] elements , int a ) {
		int out[] = new int[1+elements.length];
		System.arraycopy(elements,0,out,0,elements.length);
		out[elements.length] = a;

		return out;
	}

	/**
	 * Truncates the head (element 0) from the array
	 */
	public static int[] TH(int[] elements ) {
		int[] out = new int[ elements.length-1 ];
		System.arraycopy(elements,1,out,0,out.length);
		return out;
	}

	public static List WI( int a , List list ) {
		List output = new ArrayList();

		for( int[] elements : list ) {
			output.add( WI(a,elements));
		}

		return output;
	}

	/**
	 * Adds a dimension to the input tensor.  Returns a new tensor which references
	 * the same data as the original.
	 *
	 * Input Shape = [5 4 3]
	 * Output Shape = [1 5 4 3]
	 */
	public static T AD( T input ) {
		T out;
		if( input instanceof Tensor_F64 ) {
			out = (T)new Tensor_F64();
		} else {
			throw new RuntimeException("Unsupported type");
		}
		out.shape = WI(1,input.shape);
		out.setData(input.getData());
		out.computeStrides();
		return out;

	}

	/**
	 * Returns the total length of all the tensors in the list summed together
	 *
	 * @param shapes List of tensor shapes
	 * @return Sum of tensor lengths
	 */
	public static int sumTensorLength( List shapes ) {
		int total = 0;

		for( int i =0; i < shapes.size(); i++ ) {
			total += tensorLength(shapes.get(i));
		}

		return total;
	}

	/**
	 * Returns the total length of one tensor
	 *
	 * @param shape shape of a tensor
	 * @return length of the Tensor's data arrayn
	 */
	public static int tensorLength( int... shape ) {
		if( shape.length == 0 )
			return 0;
		int N = 1;
		for (int i = 0; i < shape.length; i++) {
			N *= shape[i];
		}
		return N;
	}

	/**
	 * Compares a list of tensors shape's against each other.
	 * This simply invokes {@link #checkShape(String, int, int[], int[], boolean)}
	 *
	 * @param which String describing which variables are being checked
	 * @param expected List of expected tensors
	 * @param actual List of actual tensors.  Axis 0 is optionally ignored here, see ignoreAxis0
	 * @param ignoreAxis0 true to ignore axis 0
	 */
	public static void checkShape(String which, List expected , List> actual , boolean ignoreAxis0 )
	{
		if( expected.size() != actual.size() )
			throw new IllegalArgumentException(
					which+": Unexpected number of tensors. "+expected.size()+" vs "+actual.size());

		for (int i = 0; i < expected.size(); i++) {
			int[] e = expected.get(i);
			int[] a = actual.get(i).getShape();

			checkShape(which, i, e, a, ignoreAxis0);
		}
	}

	/**
	 * Checks to see if the two tensors have the same shape
	 * @param a tensor
	 * @param b tensor
	 */
	public static void checkShape( Tensor_F64 a , Tensor_F64 b ) {
		if( a.shape.length != b.shape.length ) {
			throw new IllegalArgumentException("Dimension of tensors do not match. "+a.shape.length+" "+b.shape.length);
		}
		for (int i = 0; i < a.shape.length; i++) {
			int da = a.shape[i];
			int db = b.shape[i];

			if( da != db ) {
				throw new IllegalArgumentException("dimension "+i+"  does not match.  "+da+"  "+db);
			}
		}
	}

	/**
	 * Checks to see if the two tensors have the same shape
	 * @param a tensor
	 * @param b tensor
	 */
	public static void checkShape( Tensor_F32 a , Tensor_F32 b ) {
		if( a.shape.length != b.shape.length ) {
			throw new IllegalArgumentException("Dimension of tensors do not match. "+a.shape.length+" "+b.shape.length);
		}
		for (int i = 0; i < a.shape.length; i++) {
			int da = a.shape[i];
			int db = b.shape[i];

			if( da != db ) {
				throw new IllegalArgumentException("dimension "+i+"  does not match.  "+da+"  "+db);
			}
		}
	}

	/**
	 * Checks to see if the two tensors have the same shape, with the option to ignore the first axis for the 'actual'
	 * shape.   The first axis is typically the mini-batch, but the expected value might not include the mini-batch
	 * since the size of the mini-batch is determined later on.
     * Throws an {@link IllegalArgumentException} if they don't match.
	 *
	 * @param which String describing which variable is being checked
	 * @param tensor Index of the tensor in a tensor list.  Used to provide a more detailed error message.
	 *               If < 0 then this is ignored
	 * @param expected Expected shape.
	 * @param actual Actual shape.  Axis 0 is optionally ignored.
	 * @param ignoreAxis0 If true it will ignore the first dimension in expected
	 */
	public static void checkShape(String which, int tensor, int[] expected, int[] actual, boolean ignoreAxis0 ) {

		if( ignoreAxis0 ) {
			if (expected.length + 1 != actual.length) {
				String header = tensor >= 0 ? which + ":  Tensor[" + tensor + "] " : which + ": ";
				throw new IllegalArgumentException(header + " dimension doesn't match, expected = "
						+ (expected.length + 1) + " found = " + actual.length);
			} else {
				for (int i = 0; i < expected.length; i++) {
					if (expected[i] != actual[i+1]) {
						String header = tensor >= 0 ? which + ":  Tensor[" + tensor + "] " : which + ": ";

						throw new IllegalArgumentException(header + " shapes don't match, expected = "
								+ toStringShape(expected) + ", found = " + toStringShapeA(actual));
					}
				}
			}
		} else {
			if (expected.length != actual.length) {
				String header = tensor >= 0 ? which + ":  Tensor[" + tensor + "] " : which + ": ";
				throw new IllegalArgumentException(header + " dimension doesn't match, expected = "
						+ expected.length + " found = " + actual.length);
			} else {
				for (int i = 0; i < expected.length; i++) {
					if (expected[i] != actual[i]) {
						String header = tensor >= 0 ? which + ":  Tensor[" + tensor + "] " : which + ": ";

						throw new IllegalArgumentException(header + " shapes don't match, expected = "
								+ toStringShape(expected) + ", found = " + toStringShape(actual));
					}
				}
			}
		}
	}

	public static String toStringShapeA( int []shape ) {
		String out = "( * , ";
		for (int i = 1; i < shape.length; i++) {
			out += shape[i] +" , ";
		}
		return out + ")";
	}

	public static String toStringShape( int []shape ) {
		String out = "( ";
		for (int i = 0; i < shape.length; i++) {
			out += shape[i] +" , ";
		}
		return out + ")";
	}

	/**
	 * 

Computes the number of elements for an inner portion of the tensor starting at * the specified index and going outside

* * Example:
* Tensor shape = (d[0], ... , d[K-1]). Then if start dimen is 2, the output will * be the product of d[2] to d[K-1]. */ public static int outerLength(int[] shape , int startDimen ) { if( startDimen >= shape.length ) return 0; int D = 1; for (int i = startDimen; i < shape.length; i++) { D *= shape[i]; } return D; } public static File pathToRoot() { File active = new File(".").getAbsoluteFile(); while( active != null ) { boolean foundModules = false; boolean foundExamples = false; boolean foundSettings = false; File[] children = active.listFiles(); if( children == null ) break; for( File d : children ) { if( d.isDirectory() && d.getName().endsWith("modules")) { foundModules = true; } if( d.isDirectory() && d.getName().endsWith("examples")) { foundExamples = true; } if( d.isFile() && d.getName().equals("settings.gradle")) { foundSettings = true; } } if( foundModules && foundExamples && foundSettings ) { return active; } else { active = active.getParentFile(); } } throw new RuntimeException("Cant find the project root directory"); } /** * Computes the sum of all the elements in the tensor * @param tensor Tensor */ public static double elementSum( Tensor tensor ) { if( tensor instanceof Tensor_F64 ) { return TensorOps_F64.elementSum( (Tensor_F64)tensor ); } else if( tensor instanceof Tensor_F32 ) { return TensorOps_F32.elementSum( (Tensor_F32)tensor ); } else { throw new IllegalArgumentException("Support not added yet for this tensor type"); } } public static void fill( Tensor t , double value ) { if( t instanceof Tensor_F64 ) { TensorOps_F64.fill( (Tensor_F64)t, value ); } else if( t instanceof Tensor_F32 ) { TensorOps_F32.fill( (Tensor_F32)t, (float)value ); } else { throw new IllegalArgumentException("Support not added yet for this tensor type"); } } public static void boundSpatial( int bounds[] , int rows , int cols ) { if( bounds[0] < 0 ) bounds[0] = 0; if( bounds[1] < 0 ) bounds[1] = 0; if( bounds[2] > rows ) bounds[2] = rows; if( bounds[3] > cols ) bounds[3] = cols; } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy