All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.ArrayFactory Maven / Gradle / Ivy

package org.nd4j.autodiff;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.opstate.NDArrayInformation;
import org.nd4j.autodiff.opstate.NDArrayVertex;
import org.nd4j.autodiff.samediff.SameDiff;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.*;

/**
 * {@link AbstractFactory} implementation whose operations act on {@link ArrayField} values.
 *
 * <p>Almost every operation delegates directly to the matching method on the
 * {@link ArrayField} argument, or on the value returned by {@code getValue(true)} of a
 * {@link DifferentialFunction} argument. The constant factories ({@link #zero},
 * {@link #one}, {@link #scalar}) build a new {@link NDArrayVertex} using the owning
 * {@link SameDiff} graph's next vertex id.
 *
 * <p>The {@link #ArrayFactory(SameDiff)} constructor records this class's declared methods
 * by name so operations can be dispatched reflectively via {@link #invoke(String, Object[])}.
 */
@AllArgsConstructor
@Data
public class ArrayFactory implements AbstractFactory {

    /** Alpha documented as the default for the single-argument leaky relu variants. */
    private static final double DEFAULT_LEAKY_RELU_ALPHA = 0.01;

    /** The graph used to allocate vertex ids for constants created by this factory. */
    private SameDiff graph;

    /**
     * Declared methods of this class keyed by name, used by {@link #invoke}.
     * Overloaded names (e.g. {@code min}, {@code max}, {@code leakyRelu}) keep only one of
     * their overloads; which one is unspecified because {@link Class#getDeclaredMethods()}
     * returns methods in no particular order.
     */
    private Map<String, Method> methodNames;

    /**
     * Creates a factory bound to {@code graph} and builds the reflective method registry.
     *
     * @param graph the graph whose vertex ids are used for newly created constants
     */
    public ArrayFactory(SameDiff graph) {
        this.graph = graph;
        this.methodNames = new HashMap<>();
        for (Method method : getClass().getDeclaredMethods()) {
            // Compiler-generated bridge methods and private helpers are not operations;
            // skipping them keeps a bridge from shadowing the real method of the same name.
            if (method.isSynthetic() || Modifier.isPrivate(method.getModifiers())) {
                continue;
            }
            methodNames.put(method.getName(), method);
        }
    }

    /** Returns the graph this factory is bound to. */
    @Override
    public SameDiff sameDiff() {
        return graph;
    }

    /** Returns a fresh, mutable list of the method names registered for {@link #invoke}. */
    @Override
    public List<String> methodNames() {
        return new ArrayList<>(methodNames.keySet());
    }

    /**
     * Reflectively invokes the registered method called {@code name} on this factory.
     *
     * @param name the method name
     * @param args the arguments to pass
     * @return the method's result, cast to {@link ArrayField}
     * @throws IllegalArgumentException if no method with that name is registered
     * @throws RuntimeException wrapping any checked exception; unchecked exceptions thrown by
     *         the target method are rethrown as-is
     */
    @Override
    public ArrayField invoke(String name, Object[] args) {
        Method method = methodNames.get(name);
        if (method == null) {
            throw new IllegalArgumentException(
                    "No method named '" + name + "' registered on " + getClass().getName());
        }
        try {
            return (ArrayField) method.invoke(this, args);
        } catch (InvocationTargetException e) {
            // Surface the target's own exception instead of a doubly wrapped one.
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new RuntimeException(cause);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /** Delegates to {@code i_x.eq(i_y)}. */
    @Override
    public ArrayField eq(ArrayField i_x, ArrayField i_y) {
        return i_x.eq(i_y);
    }

    /** Delegates to {@code i_x.neq(i_y)}. */
    @Override
    public ArrayField neq(ArrayField i_x, ArrayField i_y) {
        return i_x.neq(i_y);
    }

    /** Delegates to {@code i_x.or(i_y)}. */
    @Override
    public ArrayField or(ArrayField i_x, ArrayField i_y) {
        return i_x.or(i_y);
    }

    /** Adds the scalar {@code value} (as a double) via {@code i_x.add(double)}. */
    @Override
    public ArrayField add(ArrayField i_x, Number value) {
        return i_x.add(value.doubleValue());
    }

    /** Subtracts the scalar {@code value} (as a double) via {@code i_x.minus(double)}. */
    @Override
    public ArrayField sub(ArrayField i_x, Number value) {
        return i_x.minus(value.doubleValue());
    }

    /**
     * Multiplies by the scalar {@code value} via {@code i_x.mul(double)}.
     *
     * <p>The full double is passed through, as in {@link #add}, {@link #sub} and
     * {@link #div}; casting to {@code long} would truncate fractional factors (0.5 → 0).
     */
    @Override
    public ArrayField mul(ArrayField i_x, Number value) {
        return i_x.mul(value.doubleValue());
    }

    /** Divides by the scalar {@code value} (as a double) via {@code i_x.div(double)}. */
    @Override
    public ArrayField div(ArrayField i_x, Number value) {
        return i_x.div(value.doubleValue());
    }

    /** Delegates to {@code i_x.broadcast(shape)}. */
    @Override
    public ArrayField broadcast(ArrayField i_x, int... shape) {
        return i_x.broadcast(shape);
    }

    /** Delegates to {@code i_x.repeat(axis)}. */
    @Override
    public ArrayField repeat(ArrayField i_x, int axis) {
        return i_x.repeat(axis);
    }

    /** Delegates to {@code i_x.tile(repeat)}. */
    @Override
    public ArrayField tile(ArrayField i_x, int... repeat) {
        return i_x.tile(repeat);
    }

    /** Delegates to {@code i_x.sum(dimensions)}. */
    @Override
    public ArrayField sum(ArrayField i_x, int... dimensions) {
        return i_x.sum(dimensions);
    }

    /** Delegates to {@code i_x.prod(dimensions)}. */
    @Override
    public ArrayField prod(ArrayField i_x, int... dimensions) {
        return i_x.prod(dimensions);
    }

    /** Delegates to {@code i_x.mean(dimensions)}. */
    @Override
    public ArrayField mean(ArrayField i_x, int... dimensions) {
        return i_x.mean(dimensions);
    }

    /**
     * Delegates to {@code i_x.std(dimensions)}.
     *
     * <p>NOTE(review): {@code biasCorrected} is not forwarded, so the result does not depend
     * on it — confirm whether ArrayField offers a bias-corrected variant.
     */
    @Override
    public ArrayField std(ArrayField i_x, boolean biasCorrected, int... dimensions) {
        return i_x.std(dimensions);
    }

    /**
     * Delegates to {@code i_x.variance(dimensions)}.
     *
     * <p>NOTE(review): {@code biasCorrected} is not forwarded, so the result does not depend
     * on it — confirm whether ArrayField offers a bias-corrected variant.
     */
    @Override
    public ArrayField variance(ArrayField i_x, boolean biasCorrected, int... dimensions) {
        return i_x.variance(dimensions);
    }

    /** Delegates to {@code i_x.max(dimensions)}. */
    @Override
    public ArrayField max(ArrayField i_x, int... dimensions) {
        return i_x.max(dimensions);
    }

    /** Delegates to {@code i_x.min(dimensions)}. */
    @Override
    public ArrayField min(ArrayField i_x, int... dimensions) {
        return i_x.min(dimensions);
    }

    /** Delegates to {@code i_x.norm1(dimensions)}. */
    @Override
    public ArrayField norm1(ArrayField i_x, int... dimensions) {
        return i_x.norm1(dimensions);
    }

    /** Delegates to {@code i_x.norm2(dimensions)}. */
    @Override
    public ArrayField norm2(ArrayField i_x, int... dimensions) {
        return i_x.norm2(dimensions);
    }

    /** Delegates to {@code i_x.normmax(dimensions)}. */
    @Override
    public ArrayField normmax(ArrayField i_x, int... dimensions) {
        return i_x.normmax(dimensions);
    }

    /** Delegates to {@code i_x.negate()}. */
    @Override
    public ArrayField neg(ArrayField i_x) {
        return i_x.negate();
    }

    /** Delegates to {@code i_x.transpose()}. */
    @Override
    public ArrayField transpose(ArrayField i_x) {
        return i_x.transpose();
    }

    /** Delegates to {@code i_x.reshape(shape)}. */
    @Override
    public ArrayField reshape(ArrayField i_x, int[] shape) {
        return i_x.reshape(shape);
    }

    /** Delegates to {@code i_x.valueArrayOf(shape)}. */
    @Override
    public ArrayField valueArrayOf(ArrayField i_x, int[] shape) {
        return i_x.valueArrayOf(shape);
    }

    /** Not supported: no shape is available to build an array from a bare value. */
    @Override
    public ArrayField val(double v) {
        throw new UnsupportedOperationException("val is not supported by ArrayFactory");
    }

    /** Delegates to {@code x.abs()}. */
    @Override
    public ArrayField abs(ArrayField x) {
        return x.abs();
    }

    /** Not supported for array fields. */
    @Override
    public ArrayField min(ArrayField x, ArrayField y) {
        throw new UnsupportedOperationException("min(x, y) is not supported by ArrayFactory");
    }

    /** Not supported for array fields. */
    @Override
    public ArrayField max(ArrayField x, ArrayField y) {
        throw new UnsupportedOperationException("max(x, y) is not supported by ArrayFactory");
    }

    /**
     * Creates a new vertex filled with zeros.
     *
     * @param shape the shape of the new array
     * @return an ArrayField wrapping a new vertex whose scalar value is 0.0
     */
    @Override
    public ArrayField zero(int[] shape) {
        return newConstant("zero-" + UUID.randomUUID().toString(), 0.0, shape);
    }

    /**
     * Creates a new vertex filled with ones.
     *
     * @param shape the shape of the new array
     * @return an ArrayField wrapping a new vertex whose scalar value is 1.0
     */
    @Override
    public ArrayField one(int[] shape) {
        return newConstant("one-" + UUID.randomUUID().toString(), 1.0, shape);
    }

    /**
     * Creates a scalar vertex of shape {@code [1, 1]}.
     *
     * <p>The vertex id is the value's string form, so equal values get equal ids.
     *
     * @param value the scalar value
     * @return an ArrayField wrapping a new vertex holding {@code value}
     */
    @Override
    public ArrayField scalar(double value) {
        return newConstant(String.valueOf(value), value, new int[]{1, 1});
    }

    /** Builds an ownerless constant vertex with a random array id in {@link #graph}. */
    private ArrayField newConstant(String id, double scalarValue, int[] shape) {
        NDArrayInformation information = NDArrayInformation.builder()
                .arrId(UUID.randomUUID().toString())
                .scalarValue(scalarValue)
                .id(id)
                .owner(null)
                .shape(shape)
                .build();
        return new ArrayField(new NDArrayVertex(graph.getGraph().nextVertexId(), information), graph);
    }

    /** Delegates to {@code x.cos()}. */
    @Override
    public ArrayField cos(ArrayField x) {
        return x.cos();
    }

    /** Delegates to {@code x.acos()}. */
    @Override
    public ArrayField acos(ArrayField x) {
        return x.acos();
    }

    /** Delegates to {@code x.cosh()}. */
    @Override
    public ArrayField cosh(ArrayField x) {
        return x.cosh();
    }

    /** Delegates to {@code x.acosh()}. */
    @Override
    public ArrayField acosh(ArrayField x) {
        return x.acosh();
    }

    /** Delegates to {@code x.sin()}. */
    @Override
    public ArrayField sin(ArrayField x) {
        return x.sin();
    }

    /** Delegates to {@code x.asin()}. */
    @Override
    public ArrayField asin(ArrayField x) {
        return x.asin();
    }

    /** Delegates to {@code x.sinh()}. */
    @Override
    public ArrayField sinh(ArrayField x) {
        return x.sinh();
    }

    /** Delegates to {@code x.asinh()}. */
    @Override
    public ArrayField asinh(ArrayField x) {
        return x.asinh();
    }

    /** Delegates to {@code x.tan()}. */
    @Override
    public ArrayField tan(ArrayField x) {
        return x.tan();
    }

    /** Delegates to {@code x.atan()}. */
    @Override
    public ArrayField atan(ArrayField x) {
        return x.atan();
    }

    /** Not supported for array fields. */
    @Override
    public ArrayField atan2(ArrayField x, ArrayField y) {
        throw new UnsupportedOperationException("atan2 is not supported by ArrayFactory");
    }

    /** Delegates to {@code x.tanh()}. */
    @Override
    public ArrayField tanh(ArrayField x) {
        return x.tanh();
    }

    /** Delegates to {@code x.atanh()}. */
    @Override
    public ArrayField atanh(ArrayField x) {
        return x.atanh();
    }

    /** Delegates to {@code x.exp()}. */
    @Override
    public ArrayField exp(ArrayField x) {
        return x.exp();
    }

    /** Delegates to {@code x.log()}. */
    @Override
    public ArrayField log(ArrayField x) {
        return x.log();
    }

    /** Delegates to {@code x.log10()}. */
    @Override
    public ArrayField log10(ArrayField x) {
        return x.log10();
    }

    /** Not supported for array fields. */
    @Override
    public ArrayField flat(ArrayField x) {
        throw new UnsupportedOperationException("flat is not supported by ArrayFactory");
    }

    /** Not supported for array fields. */
    @Override
    public ArrayField mc(ArrayField x, ArrayField y) {
        throw new UnsupportedOperationException("mc is not supported by ArrayFactory");
    }

    /** Not supported for array fields. */
    @Override
    public ArrayField rand(ArrayField x) {
        throw new UnsupportedOperationException("rand is not supported by ArrayFactory");
    }

    /** Not supported for array fields. */
    @Override
    public ArrayField random(ArrayField x) {
        throw new UnsupportedOperationException("random is not supported by ArrayFactory");
    }

    /** Not supported for array fields. */
    @Override
    public ArrayField gauss(ArrayField x) {
        throw new UnsupportedOperationException("gauss is not supported by ArrayFactory");
    }

    /** Delegates to {@code x.sgn()}. */
    @Override
    public ArrayField sgn(ArrayField x) {
        return x.sgn();
    }

    /** Not supported for array fields. */
    @Override
    public ArrayField ifx(ArrayField x, ArrayField y, ArrayField z) {
        throw new UnsupportedOperationException("ifx is not supported by ArrayFactory");
    }

    /** Not supported for array fields. */
    @Override
    public ArrayField buf(ArrayField x) {
        throw new UnsupportedOperationException("buf is not supported by ArrayFactory");
    }

    /** Not supported for array fields. */
    @Override
    public ArrayField inv(ArrayField x) {
        throw new UnsupportedOperationException("inv is not supported by ArrayFactory");
    }

    /** Not supported for array fields. */
    @Override
    public ArrayField u(ArrayField x) {
        throw new UnsupportedOperationException("u is not supported by ArrayFactory");
    }

    /** Not supported for array fields. */
    @Override
    public ArrayField uramp(ArrayField x) {
        throw new UnsupportedOperationException("uramp is not supported by ArrayFactory");
    }

    /** Delegates to {@code x.pow(y)}. */
    @Override
    public ArrayField pow(ArrayField x, ArrayField y) {
        return x.pow(y);
    }

    /** Delegates to {@code x.pwr(y)}. */
    @Override
    public ArrayField pwr(ArrayField x, ArrayField y) {
        return x.pwr(y);
    }

    /** Delegates to {@code x.pwrs(y)}. */
    @Override
    public ArrayField pwrs(ArrayField x, ArrayField y) {
        return x.pwrs(y);
    }

    /** Delegates to {@code x.sqrt()}. */
    @Override
    public ArrayField sqrt(ArrayField x) {
        return x.sqrt();
    }

    /** Delegates to {@code x.square()}. */
    @Override
    public ArrayField square(ArrayField x) {
        return x.square();
    }

    /** Computes {@code sqrt(x^2 + y^2)} via ArrayField's {@code pow}, {@code add} and {@code sqrt}. */
    @Override
    public ArrayField hypot(ArrayField x, ArrayField y) {
        return x.pow(2).add(y.pow(2)).sqrt();
    }

    /** Delegates to {@code value.floor()}. */
    @Override
    public ArrayField floor(ArrayField value) {
        return value.floor();
    }

    /** Delegates to {@code value.ceil()}. */
    @Override
    public ArrayField ceil(ArrayField value) {
        return value.ceil();
    }

    /** Delegates to {@code value.round()}. */
    @Override
    public ArrayField round(ArrayField value) {
        return value.round();
    }

    /** Delegates to {@code value.relu()}. */
    @Override
    public ArrayField relu(ArrayField value) {
        return value.relu();
    }

    /**
     * Delegates to {@code value.leakyRelu()}.
     *
     * <p>NOTE(review): {@code alpha} is currently ignored and the ArrayField default is used —
     * confirm whether ArrayField exposes an alpha-taking leakyRelu to forward it to.
     */
    @Override
    public ArrayField leakyRelu(ArrayField value, double alpha) {
        return value.leakyRelu();
    }

    /**
     * Leaky relu with an alpha of
     * 0.01
     *
     * @param value the value to transform
     * @return the result of {@code value.leakyRelu()}
     */
    @Override
    public ArrayField leakyRelu(ArrayField value) {
        return value.leakyRelu();
    }

    /** Delegates to {@code value.leakyReluDerivative(alpha)}. */
    @Override
    public ArrayField leakyReluDerivative(ArrayField value, double alpha) {
        return value.leakyReluDerivative(alpha);
    }

    /**
     * Leaky relu derivative with an alpha of 0.01, matching the documented default of
     * {@link #leakyRelu(ArrayField)}.
     *
     * @param value the value to transform
     * @return the result of {@code value.leakyReluDerivative(0.01)}
     */
    @Override
    public ArrayField leakyReluDerivative(ArrayField value) {
        return value.leakyReluDerivative(DEFAULT_LEAKY_RELU_ALPHA);
    }

    /** Delegates to {@code value.hardTanh()}. */
    @Override
    public ArrayField hardTanh(ArrayField value) {
        return value.hardTanh();
    }

    /** Delegates to {@code value.hardTanhDerivative()} (not the forward {@code hardTanh}). */
    @Override
    public ArrayField hardTanhDerivative(ArrayField value) {
        return value.hardTanhDerivative();
    }

    /** Delegates to {@code value.sigmoid()}. */
    @Override
    public ArrayField sigmoid(ArrayField value) {
        return value.sigmoid();
    }

    /** Delegates to {@code value.sigmoidDerivative()}. */
    @Override
    public ArrayField sigmoidDerivative(ArrayField value) {
        return value.sigmoidDerivative();
    }

    /** Delegates to {@code value.softmax()}. */
    @Override
    public ArrayField softmax(ArrayField value) {
        return value.softmax();
    }

    /** Delegates to {@code value.elu()}. */
    @Override
    public ArrayField elu(ArrayField value) {
        return value.elu();
    }

    /** Delegates to {@code value.eluDerivative()}. */
    @Override
    public ArrayField eluDerivative(ArrayField value) {
        return value.eluDerivative();
    }

    /** Delegates to {@code value.step()}. */
    @Override
    public ArrayField step(ArrayField value) {
        return value.step();
    }

    /** Same as {@link #sgn(ArrayField)}: delegates to {@code value.sgn()}. */
    @Override
    public ArrayField sign(ArrayField value) {
        return value.sgn();
    }

    /** Delegates to {@code value.softsign()}. */
    @Override
    public ArrayField softsign(ArrayField value) {
        return value.softsign();
    }

    /** Delegates to {@code value.softsignDerivative()} (method name spelling kept for API compatibility). */
    @Override
    public ArrayField softsignDeriviative(ArrayField value) {
        return value.softsignDerivative();
    }

    /** Delegates to {@code value.softplus()}. */
    @Override
    public ArrayField softplus(ArrayField value) {
        return value.softplus();
    }

    /** Delegates to {@code value.rollAxis(axis)}. */
    @Override
    public ArrayField rollAxis(ArrayField value, int axis) {
        return value.rollAxis(axis);
    }

    /** Calls {@code lossSquaredHinge} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossSquaredHinge(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossSquaredHinge(i_y.getValue(true), dimensions);
    }

    /** Calls {@code lossPoisson} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossPoisson(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossPoisson(i_y.getValue(true), dimensions);
    }

    /** Calls {@code lossNegativeLogLikelihood} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossNegativeLogLikelihood(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossNegativeLogLikelihood(i_y.getValue(true), dimensions);
    }

    /** Calls {@code lossMSLE} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossMSLE(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossMSLE(i_y.getValue(true), dimensions);
    }

    /** Calls {@code lossMCXENT} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossMCXENT(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossMCXENT(i_y.getValue(true), dimensions);
    }

    /** Calls {@code lossMSE} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossMSE(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossMSE(i_y.getValue(true), dimensions);
    }

    /** Calls {@code lossMAPE} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossMAPE(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossMAPE(i_y.getValue(true), dimensions);
    }

    /** Calls {@code lossMAE} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossMAE(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossMAE(i_y.getValue(true), dimensions);
    }

    /** Calls {@code lossL2} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossL2(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossL2(i_y.getValue(true), dimensions);
    }

    /** Calls {@code lossL1} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossL1(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossL1(i_y.getValue(true), dimensions);
    }

    /** Calls {@code lossKLD} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossKLD(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossKLD(i_y.getValue(true), dimensions);
    }

    /** Calls {@code lossHinge} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossHinge(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossHinge(i_y.getValue(true), dimensions);
    }

    /** Calls {@code lossCosineSimilarity} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossCosineSimilarity(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossCosineSimilarity(i_y.getValue(true), dimensions);
    }

    /** Calls {@code lossBinaryXENT} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField lossBinaryXENT(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).lossBinaryXENT(i_y.getValue(true), dimensions);
    }

    /** Calls {@code manhattanDistance} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField manhattanDistance(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).manhattanDistance(i_y.getValue(true), dimensions);
    }

    /** Calls {@code euclideanDistance} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField euclideanDistance(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).euclideanDistance(i_y.getValue(true), dimensions);
    }

    /** Calls {@code cosineSimilarity} on {@code iX.getValue(true)} with {@code i_y.getValue(true)}. */
    @Override
    public ArrayField cosineSimilarity(DifferentialFunction iX, DifferentialFunction i_y, int[] dimensions) {
        return iX.getValue(true).cosineSimilarity(i_y.getValue(true), dimensions);
    }

    /** Delegates to {@code input.expandDims(dim)}. */
    @Override
    public ArrayField expandDims(ArrayField input, int dim) {
        return input.expandDims(dim);
    }

    /** Calls {@code mmul} on {@code input.getValue(true)} with {@code y.getValue(true)}. */
    @Override
    public ArrayField mmul(DifferentialFunction input, DifferentialFunction y) {
        return input.getValue(true).mmul(y.getValue(true));
    }

    /**
     * Calls {@code tensorMmul} on {@code arrayField.getValue(true)}.
     *
     * <p>NOTE(review): unlike {@link #mmul}, {@code y} is passed as the DifferentialFunction
     * itself rather than {@code y.getValue(true)} — confirm against ArrayField#tensorMmul.
     */
    @Override
    public ArrayField tensorMmul(DifferentialFunction arrayField, DifferentialFunction y, int[][] dimensions) {
        return arrayField.getValue(true).tensorMmul(y, dimensions);
    }

    /** Delegates to {@code value.permute(dimensions)}. */
    @Override
    public ArrayField permute(ArrayField value, int[] dimensions) {
        return value.permute(dimensions);
    }

    /** Returns a description listing the registered method map (replaces Lombok's toString). */
    @Override
    public String toString() {
        return "ArrayFactory{" +
                "methodNames=" + methodNames +
                '}';
    }

    /** Delegates to {@code value.set(value1)}. */
    @Override
    public ArrayField set(ArrayField value, ArrayField value1) {
        return value.set(value1);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy