org.nd4j.linalg.util.ComplexNDArrayUtil Maven / Gradle / Ivy
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*
*/
package org.nd4j.linalg.util;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.Indices;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays;
/**
 * Static helper operations on {@link IComplexNDArray} instances:
 * element-wise exponentiation, centering, truncation and zero padding.
 *
 * @author Adam Gibson
 */
public class ComplexNDArrayUtil {

    /**
     * Returns the element-wise exponential of a complex ndarray without
     * touching the input: the input is duplicated and the copy is handed to
     * {@link #expi(IComplexNDArray)}.
     *
     * @param toExp the ndarray to exponentiate
     * @return the exponential of every element, shaped like {@code toExp}
     */
    public static IComplexNDArray exp(IComplexNDArray toExp) {
        return expi(toExp.dup());
    }

    /**
     * Returns the exponential of a complex ndarray, computed element by
     * element on {@code toExp.ravel()} and reshaped back to
     * {@code toExp.shape()}.
     *
     * NOTE(review): whether {@code toExp} itself is mutated depends on
     * whether {@code ravel()} returns a view or a copy — confirm; callers
     * should use the returned array.
     *
     * @param toExp the ndarray to exponentiate
     * @return the exponential of the specified ndarray, with the same shape
     */
    public static IComplexNDArray expi(IComplexNDArray toExp) {
        IComplexNDArray flattened = toExp.ravel();
        for (int i = 0; i < flattened.length(); i++) {
            IComplexNDArray unused = null; // placeholder removed below
            IComplexNumber n = flattened.getComplex(i);
            flattened.put(i, Nd4j.scalar(ComplexUtil.exp(n)));
        }
        return flattened.reshape(toExp.shape());
    }

    /**
     * Extracts the centered sub-array of the requested shape.
     *
     * Entries of {@code shape} below 1 are treated as 1. The caller's
     * {@code shape} array is not modified. If {@code arr} has fewer elements
     * than {@code prod(shape)} it is returned unchanged.
     *
     * @param arr   the array to center
     * @param shape the shape of the centered portion to extract
     * @return the center portion of the array, or {@code arr} itself when it
     *         is smaller than the requested shape
     */
    public static IComplexNDArray center(IComplexNDArray arr, int[] shape) {
        if (arr.length() < ArrayUtil.prod(shape)) {
            return arr;
        }
        // Clamp on a private copy: the original clamped the caller's array in place.
        int[] targetShape = ArrayUtil.copy(shape);
        for (int i = 0; i < targetShape.length; i++) {
            if (targetShape[i] < 1) {
                targetShape[i] = 1;
            }
        }
        INDArray shapeMatrix = ArrayUtil.toNDArray(targetShape);
        INDArray currShape = ArrayUtil.toNDArray(arr.shape());
        // start = floor((currentShape - targetShape) / 2), end = start + targetShape
        INDArray startIndex = Transforms.floor(currShape.sub(shapeMatrix).divi(Nd4j.scalar(2)));
        INDArray endIndex = startIndex.add(shapeMatrix);
        INDArrayIndex[] indexes = Indices.createFromStartAndEnd(startIndex, endIndex);
        if (shapeMatrix.length() > 1) {
            return arr.get(indexes);
        }
        // One-dimensional target: copy the centered run of elements explicitly.
        IComplexNDArray ret = Nd4j.createComplex(new int[]{(int) shapeMatrix.getDouble(0)});
        int start = (int) startIndex.getDouble(0);
        int end = (int) endIndex.getDouble(0);
        int count = 0;
        for (int i = start; i < end; i++) {
            ret.putScalar(count++, arr.getComplex(i));
        }
        return ret;
    }

    /**
     * Truncates an ndarray to {@code n} entries along {@code dimension}.
     *
     * If the array already has {@code n} or fewer entries along that
     * dimension, it is returned unchanged. For vectors, {@code dimension} is
     * ignored and the first {@code n} elements are copied into a new
     * {@code 1 x n} row vector.
     *
     * @param nd        the ndarray to truncate
     * @param n         the number of entries to keep
     * @param dimension the dimension to truncate (ignored for vectors)
     * @return the truncated ndarray (a new array), or {@code nd} itself when
     *         no truncation is needed
     */
    public static IComplexNDArray truncate(IComplexNDArray nd, int n, int dimension) {
        if (nd.isVector()) {
            // Nothing to drop; previously n > length read past the end of nd.
            if (n >= nd.length()) {
                return nd;
            }
            IComplexNDArray truncated = Nd4j.createComplex(new int[]{1, n});
            for (int i = 0; i < n; i++) {
                truncated.putScalar(i, nd.getComplex(i));
            }
            return truncated;
        }
        if (nd.size(dimension) > n) {
            int[] shape = ArrayUtil.copy(nd.shape());
            shape[dimension] = n;
            // Select the leading block of size `shape` (the same covering-shape
            // indexing padWithZeros relies on), i.e. the first n entries along
            // `dimension` and everything along the other dimensions. The previous
            // linear-view copy took the first prod(shape) elements in linear order,
            // which is not this block for an arbitrary dimension. dup() keeps the
            // result independent of nd, as the previous fresh allocation was.
            return nd.get(NDArrayIndex.createCoveringShape(shape)).dup();
        }
        return nd;
    }

    /**
     * Pads an ndarray with zeros up to the target shape.
     *
     * The contents of {@code nd} are placed at the leading corner of a new
     * zero-filled array of shape {@code targetShape}. If the shapes are equal,
     * or {@code nd} already has at least as many elements as the target, it is
     * returned unchanged.
     *
     * NOTE(review): the "no padding" check compares total element counts
     * only, not per-dimension sizes.
     *
     * @param nd          the ndarray to pad
     * @param targetShape the new shape
     * @return the padded ndarray, or {@code nd} when no padding is required
     */
    public static IComplexNDArray padWithZeros(IComplexNDArray nd, int[] targetShape) {
        if (Arrays.equals(nd.shape(), targetShape)) {
            return nd;
        }
        // No padding required: nd is already at least as large as the target.
        if (ArrayUtil.prod(nd.shape()) >= ArrayUtil.prod(targetShape)) {
            return nd;
        }
        IComplexNDArray ret = Nd4j.createComplex(targetShape);
        INDArrayIndex[] targetShapeIndex = NDArrayIndex.createCoveringShape(nd.shape());
        ret.put(targetShapeIndex, nd);
        return ret;
    }
}