All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.api.dimensionfunctions.DimensionFunctions Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.api.dimensionfunctions;

import com.google.common.base.Function;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.reduceops.Ops;

/**
 * Dimension-wise functions.
 *
 * <p>Factory methods returning Guava {@link Function}s from {@link INDArray} to {@link INDArray}.
 * The overloads taking an {@code int dimension} delegate to the identically named
 * {@link INDArray} method for that dimension. The no-argument overloads apply the matching
 * {@link Ops} routine to the whole array and wrap the result with {@code Nd4j.scalar(...)}.
 * The exception is {@link #cumsum()}, which rewrites its argument in place.
 *
 * @author Adam Gibson
 */
public class DimensionFunctions {

    /**
     * Returns a function that computes {@code input.normmax(dimension)}.
     *
     * @param dimension the dimension passed to {@code INDArray.normmax}
     * @return a function mapping an array to {@code input.normmax(dimension)}
     */
    public static Function<INDArray, INDArray> normmax(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.normmax(dimension);
            }
        };
    }

    /**
     * Returns a function that computes {@code input.norm2(dimension)}.
     *
     * @param dimension the dimension passed to {@code INDArray.norm2}
     * @return a function mapping an array to {@code input.norm2(dimension)}
     */
    public static Function<INDArray, INDArray> norm2(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.norm2(dimension);
            }
        };
    }

    /**
     * Returns a function that computes {@code input.norm1(dimension)}.
     *
     * @param dimension the dimension passed to {@code INDArray.norm1}
     * @return a function mapping an array to {@code input.norm1(dimension)}
     */
    public static Function<INDArray, INDArray> norm1(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.norm1(dimension);
            }
        };
    }


    /**
     * Returns a function that computes {@code input.sum(dimension)}.
     *
     * @param dimension the dimension passed to {@code INDArray.sum}
     * @return a function mapping an array to {@code input.sum(dimension)}
     */
    public static Function<INDArray, INDArray> sum(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.sum(dimension);
            }
        };
    }

    /**
     * Returns a function that computes {@code input.var(dimension)}.
     *
     * @param dimension the dimension passed to {@code INDArray.var}
     * @return a function mapping an array to {@code input.var(dimension)}
     */
    public static Function<INDArray, INDArray> var(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.var(dimension);
            }
        };
    }

    /**
     * Returns a function that computes {@code input.std(dimension)}.
     *
     * @param dimension the dimension passed to {@code INDArray.std}
     * @return a function mapping an array to {@code input.std(dimension)}
     */
    public static Function<INDArray, INDArray> std(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.std(dimension);
            }
        };
    }

    /**
     * Returns a function that computes {@code input.prod(dimension)}.
     *
     * @param dimension the dimension passed to {@code INDArray.prod}
     * @return a function mapping an array to {@code input.prod(dimension)}
     */
    public static Function<INDArray, INDArray> prod(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.prod(dimension);
            }
        };
    }

    /**
     * Returns a function that computes {@code input.cumsum(dimension)}.
     *
     * <p>Whether this variant copies or mutates its input depends on {@code INDArray.cumsum};
     * NOTE(review): confirm against the INDArray implementation in use.
     *
     * @param dimension the dimension passed to {@code INDArray.cumsum}
     * @return a function mapping an array to {@code input.cumsum(dimension)}
     */
    public static Function<INDArray, INDArray> cumsum(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.cumsum(dimension);
            }
        };
    }


    /**
     * Returns a function that computes {@code input.mean(dimension)}.
     *
     * @param dimension the dimension passed to {@code INDArray.mean}
     * @return a function mapping an array to {@code input.mean(dimension)}
     */
    public static Function<INDArray, INDArray> mean(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.mean(dimension);
            }
        };
    }

    /**
     * Returns a function that computes {@code input.min(dimension)}.
     *
     * @param dimension the dimension passed to {@code INDArray.min}
     * @return a function mapping an array to {@code input.min(dimension)}
     */
    public static Function<INDArray, INDArray> min(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.min(dimension);
            }
        };
    }

    /**
     * Returns a function that computes {@code input.max(dimension)}.
     *
     * @param dimension the dimension passed to {@code INDArray.max}
     * @return a function mapping an array to {@code input.max(dimension)}
     */
    public static Function<INDArray, INDArray> max(final int dimension) {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return input.max(dimension);
            }
        };
    }



    /**
     * Returns a function that wraps {@code Ops.norm2(input)} in a scalar array.
     *
     * @return a function mapping an array to {@code Nd4j.scalar(Ops.norm2(input))}
     */
    public static Function<INDArray, INDArray> norm2() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.norm2(input));
            }
        };
    }

    /**
     * Returns a function that wraps {@code Ops.norm1(input)} in a scalar array.
     *
     * @return a function mapping an array to {@code Nd4j.scalar(Ops.norm1(input))}
     */
    public static Function<INDArray, INDArray> norm1() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.norm1(input));
            }
        };
    }


    /**
     * Returns a function that wraps {@code Ops.sum(input)} in a scalar array.
     *
     * @return a function mapping an array to {@code Nd4j.scalar(Ops.sum(input))}
     */
    public static Function<INDArray, INDArray> sum() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.sum(input));
            }
        };
    }

    /**
     * Returns a function that wraps {@code Ops.var(input)} in a scalar array.
     *
     * @return a function mapping an array to {@code Nd4j.scalar(Ops.var(input))}
     */
    public static Function<INDArray, INDArray> var() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.var(input));
            }
        };
    }

    /**
     * Returns a function that wraps {@code Ops.std(input)} in a scalar array.
     *
     * @return a function mapping an array to {@code Nd4j.scalar(Ops.std(input))}
     */
    public static Function<INDArray, INDArray> std() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.std(input));
            }
        };
    }

    /**
     * Returns a function that wraps {@code Ops.prod(input)} in a scalar array.
     *
     * @return a function mapping an array to {@code Nd4j.scalar(Ops.prod(input))}
     */
    public static Function<INDArray, INDArray> prod() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.prod(input));
            }
        };
    }

    /**
     * Returns a function that replaces its input with its running (cumulative) sum, in place.
     *
     * <p>Elements are visited in linear-index order, {@code 0 .. input.length() - 1}. Each
     * element is overwritten with the sum of itself and every element before it. The running
     * sum is accumulated as a {@code double} for every data type. The same, mutated, array
     * instance is returned, so callers that still need the original values must pass a copy.
     *
     * @return a function that overwrites its argument with its cumulative sum and returns it
     */
    public static Function<INDArray, INDArray> cumsum() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                // One double accumulator for all data types: the former FLOAT/non-FLOAT
                // branches performed the identical operation.
                double runningSum = 0.0;
                for (int i = 0; i < input.length(); i++) {
                    runningSum += input.getDouble(i);
                    input.putScalar(i, runningSum);
                }
                return input;
            }
        };
    }


    /**
     * Returns a function that wraps {@code Ops.mean(input)} in a scalar array.
     *
     * @return a function mapping an array to {@code Nd4j.scalar(Ops.mean(input))}
     */
    public static Function<INDArray, INDArray> mean() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.mean(input));
            }
        };
    }

    /**
     * Returns a function that wraps {@code Ops.min(input)} in a scalar array.
     *
     * @return a function mapping an array to {@code Nd4j.scalar(Ops.min(input))}
     */
    public static Function<INDArray, INDArray> min() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.min(input));
            }
        };
    }

    /**
     * Returns a function that wraps {@code Ops.max(input)} in a scalar array.
     *
     * @return a function mapping an array to {@code Nd4j.scalar(Ops.max(input))}
     */
    public static Function<INDArray, INDArray> max() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.max(input));
            }
        };
    }

    /**
     * Returns a function that wraps {@code Ops.normmax(input)} in a scalar array.
     *
     * @return a function mapping an array to {@code Nd4j.scalar(Ops.normmax(input))}
     */
    public static Function<INDArray, INDArray> normmax() {
        return new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray input) {
                return Nd4j.scalar(Ops.normmax(input));
            }
        };
    }



}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy