org.nd4j.linalg.api.ops.ShapeOp Maven / Gradle / Ivy
package org.nd4j.linalg.api.ops;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import java.util.ArrayList;
import java.util.List;
/**
 * Base class for shape-manipulation ops.
 *
 * <p>Instances report {@link Type#SHAPE} from {@link #opType()}. This base class does not
 * compute an output shape itself: {@link #calculateOutputShape()} always throws.
 *
 * @author Adam Gibson
 */
@Slf4j
public abstract class ShapeOp extends BaseOp {

    /** No-argument constructor; performs no initialisation of its own beyond the superclass default. */
    public ShapeOp() {}

    /**
     * Creates a shape op attached to {@code sameDiff} without registering any input variable.
     *
     * @param sameDiff the SameDiff instance to store on this op
     */
    public ShapeOp(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
    }

    /**
     * Creates a shape op over {@code i_v}, passing the variable's own shape and no extra args
     * to the full constructor.
     *
     * @param sameDiff the SameDiff instance this op belongs to
     * @param i_v      the input variable; must not be null
     * @param inPlace  in-place flag forwarded to the superclass
     * @throws IllegalArgumentException if {@code i_v} is null
     */
    public ShapeOp(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
        // Avoid dereferencing a null input here so the caller gets the documented
        // IllegalArgumentException from the delegate constructor, not a NullPointerException.
        this(sameDiff, i_v, i_v == null ? null : i_v.getShape(), inPlace, null);
    }

    /**
     * Creates a shape op over {@code i_v} and registers it with {@code sameDiff}.
     *
     * <p>The variable's name is stored as {@code xVertexId} and registered via
     * {@code sameDiff.addArgsFor}. If the variable's shape is a placeholder shape (per
     * {@link Shape#isPlaceholderShape}), the variable name is also registered with
     * {@code sameDiff.addPropertyToResolve} so it can be resolved later.
     *
     * @param sameDiff  the SameDiff instance this op belongs to
     * @param i_v       the input variable; must not be null
     * @param shape     currently unused by this constructor
     * @param inPlace   in-place flag forwarded to the superclass
     * @param extraArgs extra arguments forwarded to the superclass; may be null
     * @throws IllegalArgumentException if {@code i_v} is null
     */
    public ShapeOp(SameDiff sameDiff,
                   SDVariable i_v,
                   int[] shape,
                   boolean inPlace,
                   Object[] extraArgs) {
        super(sameDiff, inPlace, extraArgs);
        if (i_v == null) {
            throw new IllegalArgumentException("Input variable must not be null.");
        }
        // NOTE(review): by name, presumably checks that i_v belongs to this op's SameDiff graph.
        f().validateDifferentialFunctionsameDiff(i_v);
        this.xVertexId = i_v.getVarName();
        sameDiff.addArgsFor(new String[]{xVertexId}, this);
        if (Shape.isPlaceholderShape(i_v.getShape())) {
            sameDiff.addPropertyToResolve(this, i_v.getVarName());
        }
    }

    /**
     * An op for one ndarray.
     *
     * @param x the input ndarray
     */
    public ShapeOp(INDArray x) {
        super(x);
    }

    /**
     * Specify an alternative result array.
     *
     * @param x the input
     * @param z the output array
     */
    public ShapeOp(INDArray x, INDArray z) {
        super(x, z);
    }

    /**
     * Specify an alternative output array.
     *
     * @param x the input
     * @param z the output
     * @param n the number of elements to iterate on
     */
    public ShapeOp(INDArray x, INDArray z, long n) {
        super(x, z, n);
    }

    /**
     * Creates an op over two inputs with an explicit output array.
     *
     * @param x the first input
     * @param y the second input
     * @param z the output
     * @param n the number of elements to iterate on
     */
    public ShapeOp(INDArray x, INDArray y, INDArray z, long n) {
        super(x, y, z, n);
    }

    /**
     * Not supported by this base class.
     *
     * <p>NOTE(review): presumably concrete subclasses override this when they can compute an
     * output shape.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List calculateOutputShape() {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the op category for all shape ops.
     *
     * @return {@link Type#SHAPE}
     */
    @Override
    public Type opType() {
        return Type.SHAPE;
    }
}