// org.nd4j.linalg.api.dimensionfunctions.DimensionFunctions (Maven / Gradle / Ivy)
package org.nd4j.linalg.api.dimensionfunctions;
import com.google.common.base.Function;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.reduceops.Ops;
/**
 * Factory methods producing Guava {@link Function}s that map one
 * {@link INDArray} to another.
 *
 * <p>The factories come in two families:
 * <ul>
 *   <li>Overloads taking a {@code dimension} return a function that calls the
 *       same-named method on the input array with that dimension.</li>
 *   <li>No-argument overloads return a function that passes the whole input to
 *       the same-named {@link Ops} helper and wraps the result with
 *       {@link Nd4j#scalar}. The exception is {@link #cumsum()}, which
 *       accumulates in place and returns the input array itself.</li>
 * </ul>
 *
 * @author Adam Gibson
 */
public class DimensionFunctions {

    /**
     * Creates a function that evaluates {@code input.normmax(dimension)}.
     *
     * @param dimension the dimension forwarded to {@code normmax}
     * @return a function returning the result of {@code input.normmax(dimension)}
     */
    public static Function<INDArray, INDArray> normmax(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.normmax(dimension);
            }
        };
    }

    /**
     * Creates a function that evaluates {@code input.norm2(dimension)}.
     *
     * @param dimension the dimension forwarded to {@code norm2}
     * @return a function returning the result of {@code input.norm2(dimension)}
     */
    public static Function<INDArray, INDArray> norm2(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.norm2(dimension);
            }
        };
    }

    /**
     * Creates a function that evaluates {@code input.norm1(dimension)}.
     *
     * @param dimension the dimension forwarded to {@code norm1}
     * @return a function returning the result of {@code input.norm1(dimension)}
     */
    public static Function<INDArray, INDArray> norm1(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.norm1(dimension);
            }
        };
    }

    /**
     * Creates a function that evaluates {@code input.sum(dimension)}.
     *
     * @param dimension the dimension forwarded to {@code sum}
     * @return a function returning the result of {@code input.sum(dimension)}
     */
    public static Function<INDArray, INDArray> sum(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.sum(dimension);
            }
        };
    }

    /**
     * Creates a function that evaluates {@code input.var(dimension)}.
     *
     * @param dimension the dimension forwarded to {@code var}
     * @return a function returning the result of {@code input.var(dimension)}
     */
    public static Function<INDArray, INDArray> var(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.var(dimension);
            }
        };
    }

    /**
     * Creates a function that evaluates {@code input.std(dimension)}.
     *
     * @param dimension the dimension forwarded to {@code std}
     * @return a function returning the result of {@code input.std(dimension)}
     */
    public static Function<INDArray, INDArray> std(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.std(dimension);
            }
        };
    }

    /**
     * Creates a function that evaluates {@code input.prod(dimension)}.
     *
     * @param dimension the dimension forwarded to {@code prod}
     * @return a function returning the result of {@code input.prod(dimension)}
     */
    public static Function<INDArray, INDArray> prod(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.prod(dimension);
            }
        };
    }

    /**
     * Creates a function that evaluates {@code input.cumsum(dimension)}.
     *
     * @param dimension the dimension forwarded to {@code cumsum}
     * @return a function returning the result of {@code input.cumsum(dimension)}
     */
    public static Function<INDArray, INDArray> cumsum(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.cumsum(dimension);
            }
        };
    }

    /**
     * Creates a function that evaluates {@code input.mean(dimension)}.
     *
     * @param dimension the dimension forwarded to {@code mean}
     * @return a function returning the result of {@code input.mean(dimension)}
     */
    public static Function<INDArray, INDArray> mean(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.mean(dimension);
            }
        };
    }

    /**
     * Creates a function that evaluates {@code input.min(dimension)}.
     *
     * @param dimension the dimension forwarded to {@code min}
     * @return a function returning the result of {@code input.min(dimension)}
     */
    public static Function<INDArray, INDArray> min(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.min(dimension);
            }
        };
    }

    /**
     * Creates a function that evaluates {@code input.max(dimension)}.
     *
     * @param dimension the dimension forwarded to {@code max}
     * @return a function returning the result of {@code input.max(dimension)}
     */
    public static Function<INDArray, INDArray> max(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.max(dimension);
            }
        };
    }

    /**
     * Creates a function that passes the whole input to {@code Ops.norm2}.
     *
     * @return a function returning {@code Nd4j.scalar(Ops.norm2(input))}
     */
    public static Function<INDArray, INDArray> norm2() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.norm2(input));
            }
        };
    }

    /**
     * Creates a function that passes the whole input to {@code Ops.norm1}.
     *
     * @return a function returning {@code Nd4j.scalar(Ops.norm1(input))}
     */
    public static Function<INDArray, INDArray> norm1() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.norm1(input));
            }
        };
    }

    /**
     * Creates a function that passes the whole input to {@code Ops.sum}.
     *
     * @return a function returning {@code Nd4j.scalar(Ops.sum(input))}
     */
    public static Function<INDArray, INDArray> sum() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.sum(input));
            }
        };
    }

    /**
     * Creates a function that passes the whole input to {@code Ops.var}.
     *
     * @return a function returning {@code Nd4j.scalar(Ops.var(input))}
     */
    public static Function<INDArray, INDArray> var() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.var(input));
            }
        };
    }

    /**
     * Creates a function that passes the whole input to {@code Ops.std}.
     *
     * @return a function returning {@code Nd4j.scalar(Ops.std(input))}
     */
    public static Function<INDArray, INDArray> std() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.std(input));
            }
        };
    }

    /**
     * Creates a function that passes the whole input to {@code Ops.prod}.
     *
     * @return a function returning {@code Nd4j.scalar(Ops.prod(input))}
     */
    public static Function<INDArray, INDArray> prod() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.prod(input));
            }
        };
    }

    /**
     * Creates a function that turns its input into a running total over the
     * linear indices {@code 0 .. input.length() - 1}.
     *
     * <p>Unlike the other no-argument factories, the returned function
     * <strong>overwrites the input array</strong>: element {@code i} is
     * replaced by the sum of elements {@code 0..i}, and the same array
     * instance is returned. Callers that need the original values must copy
     * the array before applying the function.
     *
     * @return a function that accumulates in place and returns its argument
     */
    public static Function<INDArray, INDArray> cumsum() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                double runningTotal = 0.0;
                for (int i = 0; i < input.length(); i++) {
                    // Values are read as double regardless of the buffer's data type.
                    runningTotal += input.getDouble(i);
                    // Written back before the next read, so each slot ends up
                    // holding the total of everything up to and including it.
                    input.putScalar(i, runningTotal);
                }
                return input;
            }
        };
    }

    /**
     * Creates a function that passes the whole input to {@code Ops.mean}.
     *
     * @return a function returning {@code Nd4j.scalar(Ops.mean(input))}
     */
    public static Function<INDArray, INDArray> mean() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.mean(input));
            }
        };
    }

    /**
     * Creates a function that passes the whole input to {@code Ops.min}.
     *
     * @return a function returning {@code Nd4j.scalar(Ops.min(input))}
     */
    public static Function<INDArray, INDArray> min() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.min(input));
            }
        };
    }

    /**
     * Creates a function that passes the whole input to {@code Ops.max}.
     *
     * @return a function returning {@code Nd4j.scalar(Ops.max(input))}
     */
    public static Function<INDArray, INDArray> max() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.max(input));
            }
        };
    }

    /**
     * Creates a function that passes the whole input to {@code Ops.normmax}.
     *
     * @return a function returning {@code Nd4j.scalar(Ops.normmax(input))}
     */
    public static Function<INDArray, INDArray> normmax() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.normmax(input));
            }
        };
    }
}