All downloads are free. Search and download functionality uses the official Maven repository.

io.improbable.keanu.tensor.lng.LongTensor Maven / Gradle / Ivy

package io.improbable.keanu.tensor.lng;

import io.improbable.keanu.tensor.FixedPointTensor;
import io.improbable.keanu.tensor.TensorFactories;
import org.apache.commons.lang3.ArrayUtils;

import static io.improbable.keanu.tensor.TensorShape.getAbsoluteDimension;

/**
 * A tensor whose elements are {@code long} values.
 *
 * <p>The static methods on this interface are convenience entry points. The construction methods delegate to
 * {@code TensorFactories.longTensorFactory}, {@link #stack} and {@link #concat} combine several tensors into
 * one, and {@link #min}/{@link #max} operate on a {@code duplicate()} of their first argument.
 *
 * <p>NOTE(review): {@code stack}, {@code min} and {@code max} assign the results of {@code reshape},
 * {@code duplicate}, {@code minInPlace} and {@code maxInPlace} to {@code LongTensor}. That only type-checks if
 * {@code FixedPointTensor} is parameterized with {@code LongTensor} as its self type. The type arguments on the
 * {@code extends} clause below may have been lost from this copy of the file; confirm against upstream.
 */
public interface LongTensor extends FixedPointTensor {

    /**
     * Creates a tensor of the given shape from a single value.
     *
     * <p>Delegates to {@code TensorFactories.longTensorFactory.create(value, shape)}. Presumably every element is
     * set to {@code value}; confirm against the factory implementation.
     *
     * @param value the value used to populate the tensor
     * @param shape the shape of the tensor to create
     * @return the created tensor
     */
    static LongTensor create(long value, long[] shape) {
        return TensorFactories.longTensorFactory.create(value, shape);
    }

    /**
     * Creates a tensor of the given shape from an array of element values.
     *
     * <p>Delegates to {@code TensorFactories.longTensorFactory.create(values, shape)}. The length of
     * {@code values} is presumably expected to equal the product of {@code shape}; TODO confirm the element
     * ordering the factory assumes.
     *
     * @param values the element values
     * @param shape  the shape of the tensor to create
     * @return the created tensor
     */
    static LongTensor create(long[] values, long... shape) {
        return TensorFactories.longTensorFactory.create(values, shape);
    }

    /**
     * Creates a rank-1 tensor with shape {@code [values.length]} holding {@code values}.
     *
     * <p>Behaves the same as {@link #vector(long...)}.
     *
     * @param values the element values
     * @return a vector tensor containing {@code values}
     */
    static LongTensor create(long... values) {
        return create(values, values.length);
    }

    /**
     * Creates a tensor of the given shape via {@code TensorFactories.longTensorFactory.ones(shape)}.
     *
     * @param shape the shape of the tensor to create
     * @return the created tensor (presumably all elements equal to 1)
     */
    static LongTensor ones(long... shape) {
        return TensorFactories.longTensorFactory.ones(shape);
    }

    /**
     * Creates a tensor via {@code TensorFactories.longTensorFactory.eye(n)}.
     *
     * @param n the size passed to the factory (presumably the side length of an {@code n x n} identity matrix)
     * @return the created tensor
     */
    static LongTensor eye(int n) {
        return TensorFactories.longTensorFactory.eye(n);
    }

    /**
     * Creates a tensor of the given shape via {@code TensorFactories.longTensorFactory.zeros(shape)}.
     *
     * @param shape the shape of the tensor to create
     * @return the created tensor (presumably all elements equal to 0)
     */
    static LongTensor zeros(long... shape) {
        return TensorFactories.longTensorFactory.zeros(shape);
    }

    /**
     * Creates a scalar tensor via {@code TensorFactories.longTensorFactory.scalar(scalarValue)}.
     *
     * @param scalarValue the single value held by the tensor
     * @return the created scalar tensor
     */
    static LongTensor scalar(long scalarValue) {
        return TensorFactories.longTensorFactory.scalar(scalarValue);
    }

    /**
     * Creates a rank-1 tensor with shape {@code [values.length]} holding {@code values}.
     *
     * @param values the element values
     * @return a vector tensor containing {@code values}
     */
    static LongTensor vector(long... values) {
        return create(values, values.length);
    }

    /**
     * Stacks tensors along a new dimension.
     *
     * <p>A length-1 axis is inserted into the first tensor's shape at the resolved stacking position. Every
     * input is reshaped to that shape, and the results are concatenated along the new axis. The result
     * therefore has one more dimension than the inputs, which are expected to share the first tensor's shape.
     *
     * @param dimension the position of the new dimension in the result. It is resolved by
     *                  {@code getAbsoluteDimension} against the result rank (input rank + 1), presumably so
     *                  that negative values count from the end.
     * @param toStack   the tensors to stack; must contain at least one tensor
     * @return the stacked tensor
     * @throws IllegalArgumentException if {@code toStack} is empty
     */
    static LongTensor stack(int dimension, LongTensor... toStack) {
        if (toStack.length == 0) {
            // Without this guard the toStack[0] access below fails with an uninformative index exception.
            throw new IllegalArgumentException("Cannot stack an empty array of tensors");
        }

        long[] shape = toStack[0].getShape();
        // The result has rank shape.length + 1, so the dimension is resolved against that rank.
        int absoluteDimension = getAbsoluteDimension(dimension, shape.length + 1);
        // Insert a length-1 axis at the stacking position, e.g. [2, 3] stacked on dimension 0 -> [1, 2, 3].
        long[] stackedShape = ArrayUtils.insert(absoluteDimension, shape, 1);

        LongTensor[] reshaped = new LongTensor[toStack.length];
        for (int i = 0; i < toStack.length; i++) {
            reshaped[i] = toStack[i].reshape(stackedShape);
        }

        return concat(absoluteDimension, reshaped);
    }

    /**
     * Concatenates tensors along dimension 0.
     *
     * @param toConcat the tensors to concatenate
     * @return the concatenated tensor
     */
    static LongTensor concat(LongTensor... toConcat) {
        return concat(0, toConcat);
    }

    /**
     * Concatenates tensors along the given dimension.
     *
     * <p>Delegates to {@code TensorFactories.longTensorFactory.concat(dimension, toConcat)}.
     *
     * @param dimension the dimension to concatenate along
     * @param toConcat  the tensors to concatenate
     * @return the concatenated tensor
     */
    static LongTensor concat(int dimension, LongTensor... toConcat) {
        return TensorFactories.longTensorFactory.concat(dimension, toConcat);
    }

    /**
     * Returns the result of applying {@code minInPlace(b)} to a duplicate of {@code a}.
     *
     * <p>Presumably this is the element-wise minimum of {@code a} and {@code b}. Assuming {@code duplicate()}
     * returns an independent copy, {@code a} itself is not modified.
     *
     * @param a the first operand (duplicated before the operation)
     * @param b the second operand
     * @return the combined tensor
     */
    static LongTensor min(LongTensor a, LongTensor b) {
        return a.duplicate().minInPlace(b);
    }

    /**
     * Returns the result of applying {@code maxInPlace(b)} to a duplicate of {@code a}.
     *
     * <p>Presumably this is the element-wise maximum of {@code a} and {@code b}. Assuming {@code duplicate()}
     * returns an independent copy, {@code a} itself is not modified.
     *
     * @param a the first operand (duplicated before the operation)
     * @param b the second operand
     * @return the combined tensor
     */
    static LongTensor max(LongTensor a, LongTensor b) {
        return a.duplicate().maxInPlace(b);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy