All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.functions.DifferentialFunction Maven / Gradle / Ivy

package org.nd4j.autodiff.functions;

import java.util.List;
import java.util.UUID;

import com.google.common.base.Preconditions;
import lombok.*;
import org.nd4j.autodiff.AbstractIdentityFactory;
import org.nd4j.autodiff.ArrayFactory;
import org.nd4j.autodiff.ArrayField;
import org.nd4j.autodiff.Field;
import org.nd4j.autodiff.graph.Graph;
import org.nd4j.autodiff.opstate.NDArrayInformation;
import org.nd4j.autodiff.opstate.NDArrayVertex;
import org.nd4j.autodiff.opstate.OpState;
import org.nd4j.autodiff.samediff.SDGraph;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.util.ArrayUtil;


/**
 * Base class for nodes of the SameDiff automatic-differentiation graph.
 *
 * <p>A {@code DifferentialFunction} pairs a value of type {@code X} (produced by
 * {@link #doGetValue()}) with the {@link SameDiff} instance it belongs to, the id of the
 * graph vertex that holds its result, and the {@link OpState} of the op that produced it.
 *
 * <p>NOTE(review): the generic type arguments in this copy of the file appear to have been
 * stripped (e.g. {@code DifferentialFunction>} is unbalanced); declarations below are
 * reproduced exactly as found — restore them from the upstream source.
 */
@AllArgsConstructor
@Data
@NoArgsConstructor
public abstract class DifferentialFunction>
        implements Field>,
        Differential> {

    /** The SameDiff instance whose graph holds this function's vertices and edges. */
    @Getter
    @Setter
    protected SameDiff sameDiff;
    /** Op state of the op that produced this function; assigned by {@code addEdges}. */
    @Getter
    protected OpState opState;
    /** Id of the graph vertex holding this function's result. */
    @Getter
    @Setter
    protected int vertexId;
    /** Op-specific extra arguments; {@link #getInPlace(Object[])} scans them for a Boolean. */
    protected Object[] extraArgs;


    /**
     * Creates a function bound to the given SameDiff instance.
     *
     * @param sameDiff  the SameDiff instance this function belongs to
     * @param extraArgs op-specific extra arguments (may be {@code null})
     */
    public DifferentialFunction(SameDiff sameDiff, Object[] extraArgs) {
        this.sameDiff = sameDiff;
        this.extraArgs = extraArgs;
    }


    /**
     * Returns the shape of this function's result, as recorded in its op state.
     *
     * @return the result shape
     * @throws IllegalStateException if no op state has been assigned yet
     */
    public int[] getResultShape() {
        if (opState == null) {
            throw new IllegalStateException("Unable to get result shape with null op state");
        }
        return opState.getResult().getShape();
    }

    /**
     * Computes the value of this function. Called by {@link #getValue(boolean)}, which
     * handles graph freezing around this call.
     *
     * @return the value of this function
     */
    public abstract X doGetValue();


    /**
     * Returns the value of this function, optionally freezing the graph while computing it.
     *
     * <p>If {@code freeze} is {@code true} and the graph is not already frozen, the graph is
     * frozen for the duration of {@link #doGetValue()} and unfrozen afterwards — also when the
     * computation throws. A graph that was already frozen is left frozen.
     *
     * <p>If {@code freeze} is {@code false} and the value is an {@link ArrayField}, the field
     * is re-attached to this function's result vertex and SameDiff instance.
     *
     * @param freeze whether to freeze the graph or not,
     *               this means whether to add nodes to the internal
     *               computation graph or not
     * @return the value of this function
     * @throws IllegalStateException if the result vertex, or its value, is missing
     */
    public  X getValue(boolean freeze) {
        // Only unfreeze afterwards if this call was the one that froze the graph.
        boolean shouldToggleFreeze = freeze && !this.sameDiff.getGraph().isFrozen();
        if (shouldToggleFreeze) {
            this.sameDiff.getGraph().freeze();
        }
        try {
            X val = doGetValue();
            if (val instanceof ArrayField && !freeze) {
                attachToResultVertex((ArrayField) val);
            }
            return val;
        } finally {
            // Restore the graph state even on failure so it is never left frozen by accident.
            if (shouldToggleFreeze) {
                this.sameDiff.getGraph().unfreeze();
            }
        }
    }

    /**
     * Points {@code arrayField} at this function's result vertex and SameDiff instance and
     * copies the vertex's array information into it.
     *
     * @throws IllegalStateException if the vertex, or its value, is missing, or the field's
     *                               ops do not match this function's SameDiff instance
     */
    private void attachToResultVertex(ArrayField arrayField) {
        NDArrayVertex vertex = (NDArrayVertex) getSameDiff().getGraph().getVertex(getVertexId());
        // Validate before mutating the field so a failed lookup leaves it untouched.
        Preconditions.checkState(vertex != null, "Vertex %s was null.", getVertexId());
        Preconditions.checkState(vertex.getValue() != null,"Vertex did not have a value set.");
        arrayField.setVertex(vertex);
        arrayField.setOps(this.sameDiff);
        arrayField.getInput().setScalarValue(vertex.getValue().getScalarValue());
        arrayField.setInput(vertex.getValue());
        Preconditions.checkState(sameDiff == arrayField.getOps(),"Passed in array factory != the passed in graph. Unable to instantiate.");
    }

    @Override
    public abstract double getReal();


    /**
     * Returns a textual formula for this function, rendered with the graph frozen. A graph
     * that was already frozen (e.g. by an enclosing evaluation) is left frozen afterwards.
     *
     * @param variables the variables to render the formula in terms of
     * @return the formula string produced by {@link #doGetFormula(List)}
     */
    @Override
    public  String getFormula(List> variables) {
        boolean shouldToggleFreeze = !sameDiff.getGraph().isFrozen();
        if (shouldToggleFreeze) {
            sameDiff.getGraph().freeze();
        }
        try {
            return doGetFormula(variables);
        } finally {
            if (shouldToggleFreeze) {
                sameDiff.getGraph().unfreeze();
            }
        }
    }

    /**
     * Renders the formula for this function; called by {@link #getFormula(List)} with the
     * graph frozen.
     */
    public abstract String doGetFormula(List> variables);

    /** Returns the name of this function. */
    public abstract String functionName();

    @Override
    public abstract String toString();


    /** Returns whether this function is a constant; {@code false} by default. */
    public boolean isConstant() {
        return false;
    }


    /** Returns whether this function is a variable; {@code false} by default. */
    public boolean isVariable() {
        return false;
    }

    @Override
    public abstract DifferentialFunction diff(DifferentialFunction i_v1);

    /**
     * Checks that {@code function} belongs to the same SameDiff instance as this function.
     *
     * @throws NullPointerException  if {@code function} is {@code null}
     * @throws IllegalStateException if the two functions belong to different SameDiff instances
     */
    private void validateDifferentialFunctionGraph(DifferentialFunction function) {
        Preconditions.checkNotNull(function, "Passed in function was null.");
        // Template form: the functions' toString() is only evaluated if the check fails.
        Preconditions.checkState(function.getSameDiff() == this.getSameDiff(),
                "Function applications must be contained in same graph. The left %s must match this function %s",
                function, this);
    }

    // ---------------------------------------------------------------------------------
    // Binary ops against another function. Each checks that both functions share a
    // SameDiff instance, evaluates both values with the graph frozen, and wraps the
    // result in a Constant shaped like i_v's result.
    // NOTE(review): these compute i_v.op(this) — the ARGUMENT is the receiver. For
    // non-commutative ops (sub, div, rsub, rdiv and their in-place forms) confirm this
    // operand order is intended. set() is the exception: it computes this.set(i_v).
    // ---------------------------------------------------------------------------------

    @Override
    public DifferentialFunction rdivi(DifferentialFunction i_v) {
        validateDifferentialFunctionGraph(i_v);
        X ret = i_v.getValue(true).rdivi(getValue(true));
        return new Constant<>(sameDiff, ret, i_v.getResultShape());
    }

    @Override
    public DifferentialFunction rsubi(DifferentialFunction i_v) {
        validateDifferentialFunctionGraph(i_v);
        X ret = i_v.getValue(true).rsubi(getValue(true));
        return new Constant<>(sameDiff, ret, i_v.getResultShape());
    }

    @Override
    public DifferentialFunction addi(DifferentialFunction i_v) {
        validateDifferentialFunctionGraph(i_v);
        X ret = i_v.getValue(true).addi(getValue(true));
        return new Constant<>(sameDiff, ret, i_v.getResultShape());
    }

    @Override
    public DifferentialFunction muli(DifferentialFunction i_v) {
        validateDifferentialFunctionGraph(i_v);
        X ret = i_v.getValue(true).muli(getValue(true));
        return new Constant<>(sameDiff, ret, i_v.getResultShape());
    }

    @Override
    public DifferentialFunction subi(DifferentialFunction i_v) {
        validateDifferentialFunctionGraph(i_v);
        X ret = i_v.getValue(true).subi(getValue(true));
        return new Constant<>(sameDiff, ret, i_v.getResultShape());
    }

    @Override
    public DifferentialFunction divi(DifferentialFunction i_v) {
        validateDifferentialFunctionGraph(i_v);
        X ret = i_v.getValue(true).divi(getValue(true));
        return new Constant<>(sameDiff, ret, i_v.getResultShape());
    }

    // ---------------------------------------------------------------------------------
    // Unary and scalar ops, in-place variants. These pass `true` as the trailing
    // constructor argument of Inverse/Negative/PolynomialTerm/Scalar (presumably an
    // in-place flag — confirm), and scalar ops wrap the double in a Scalar before
    // delegating to the function overloads.
    // ---------------------------------------------------------------------------------

    @Override
    public DifferentialFunction inversei() {
        DifferentialFunction ret = new Inverse<>(sameDiff,this,true);
        return ret;
    }

    @Override
    public DifferentialFunction negatei() {
        DifferentialFunction ret = new Negative<>(sameDiff,this,true);
        return ret;
    }

    @Override
    public DifferentialFunction muli(double i_n) {
        PolynomialTerm ret =  new PolynomialTerm<>(sameDiff,i_n, this, 1,true);
        return ret;
    }

    @Override
    public DifferentialFunction powi(int i_n) {
        PolynomialTerm ret = new PolynomialTerm<>(sameDiff,1L,
                this, i_n,true);
        return ret;
    }

    @Override
    public DifferentialFunction addi(double i_v) {
        Scalar constant = new Scalar<>(sameDiff, i_v,true);
        return constant.addi(this);
    }

    @Override
    public DifferentialFunction subi(double i_v) {
        Scalar constant = new Scalar<>(sameDiff, i_v,true);
        return constant.subi(this);
    }

    @Override
    public DifferentialFunction divi(double v) {
        Scalar constant = new Scalar<>(sameDiff, v,true);
        return this.divi(constant);
    }

    @Override
    public DifferentialFunction rsubi(double v) {
        Scalar constant = new Scalar<>(sameDiff, v,true);
        return this.rsubi(constant);
    }

    @Override
    public DifferentialFunction rdivi(double v) {
        Scalar constant = new Scalar<>(sameDiff, v,true);
        return this.rdivi(constant);
    }

    @Override
    public DifferentialFunction set(DifferentialFunction i_v) {
        validateDifferentialFunctionGraph(i_v);
        X ret = getValue(true).set(i_v.getValue(true));
        return new Constant<>(sameDiff, ret, i_v.getResultShape());
    }

    // Out-of-place binary ops against another function (see the operand-order note above).

    @Override
    public DifferentialFunction rdiv(DifferentialFunction i_v) {
        validateDifferentialFunctionGraph(i_v);
        X ret = i_v.getValue(true).rdiv(getValue(true));
        return new Constant<>(sameDiff, ret, i_v.getResultShape());
    }

    @Override
    public DifferentialFunction rsub(DifferentialFunction i_v) {
        validateDifferentialFunctionGraph(i_v);
        X ret = i_v.getValue(true).rsub(getValue(true));
        return new Constant<>(sameDiff, ret, i_v.getResultShape());
    }

    @Override
    public DifferentialFunction add(DifferentialFunction i_v) {
        validateDifferentialFunctionGraph(i_v);
        X ret = i_v.getValue(true).add(getValue(true));
        return new Constant<>(sameDiff, ret, i_v.getResultShape());
    }

    @Override
    public DifferentialFunction mul(DifferentialFunction i_v) {
        validateDifferentialFunctionGraph(i_v);
        X ret = i_v.getValue(true).mul(getValue(true));
        return new Constant<>(sameDiff, ret, i_v.getResultShape());
    }

    @Override
    public DifferentialFunction sub(DifferentialFunction i_v) {
        validateDifferentialFunctionGraph(i_v);
        X ret = i_v.getValue(true).sub(getValue(true));
        return new Constant<>(sameDiff, ret, i_v.getResultShape());
    }

    @Override
    public DifferentialFunction div(DifferentialFunction i_v) {
        validateDifferentialFunctionGraph(i_v);
        X ret = i_v.getValue(true).div(getValue(true));
        return new Constant<>(sameDiff, ret, i_v.getResultShape());
    }

    // Out-of-place unary and scalar ops.

    @Override
    public DifferentialFunction inverse() {
        DifferentialFunction ret = new Inverse<>(sameDiff,this.mul(1.0));
        return ret;
    }

    @Override
    public DifferentialFunction negate() {
        DifferentialFunction ret = new Negative<>(sameDiff,this.mul(1.0));
        return ret;
    }

    @Override
    public DifferentialFunction mul(double i_n) {
        Scalar constant = new Scalar<>(sameDiff, i_n);
        return this.mul(constant);
    }

    @Override
    public DifferentialFunction pow(int i_n) {
        PolynomialTerm ret = new PolynomialTerm<>(sameDiff,1L, this, i_n);
        return ret;
    }

    @Override
    public DifferentialFunction add(double i_v) {
        Scalar constant = new Scalar<>(sameDiff, i_v);
        return constant.add(this);
    }

    @Override
    public DifferentialFunction sub(double i_v) {
        Scalar constant = new Scalar<>(sameDiff, i_v);
        return constant.sub(this);
    }

    @Override
    public DifferentialFunction rsub(double v) {
        Scalar constant = new Scalar<>(sameDiff, v);
        return this.rsub(constant);
    }

    @Override
    public DifferentialFunction rdiv(double v) {
        Scalar constant = new Scalar<>(sameDiff, v);
        return this.rdiv(constant);
    }

    /**
     * Same as the seven-argument {@code addEdges}, with {@code null} extra arguments.
     */
    protected void addEdges(SameDiff sameDiff,
                            DifferentialFunction i_v1,
                            DifferentialFunction i_v2,
                            String opName,
                            OpState.OpType opType,
                            int[] shape) {
        addEdges(sameDiff,
                i_v1,
                i_v2,
                opName,
                opType,
                shape,
                null);
    }

    /** Returns the argument functions of this function; none by default. */
    @Override
    public DifferentialFunction[] args() {
        return new DifferentialFunction[0];
    }

    // ---------------------------------------------------------------------------------
    // Default stubs: every method below returns null in this base class. Presumably
    // subclasses override the ones they support — TODO confirm. NOTE(review): div(double),
    // minus(double) and prod(double) also return null here even though divi(double),
    // sub(double) and mul(double) are implemented above.
    // ---------------------------------------------------------------------------------

    @Override
    public DifferentialFunction pow(DifferentialFunction a) {
        return null;
    }

    @Override
    public DifferentialFunction floor() {
        return null;
    }

    @Override
    public DifferentialFunction ceil() {
        return null;
    }

    @Override
    public DifferentialFunction round() {
        return null;
    }

    @Override
    public DifferentialFunction abs() {
        return null;
    }

    @Override
    public DifferentialFunction sqrt() {
        return null;
    }

    @Override
    public DifferentialFunction minus(double v) {
        return null;
    }

    @Override
    public DifferentialFunction prod(double v) {
        return null;
    }

    @Override
    public DifferentialFunction div(double v) {
        return null;
    }

    @Override
    public DifferentialFunction pow(double v) {
        return null;
    }

    @Override
    public DifferentialFunction cos() {
        return null;
    }

    @Override
    public DifferentialFunction acos() {
        return null;
    }

    @Override
    public DifferentialFunction cosh() {
        return null;
    }

    @Override
    public DifferentialFunction acosh() {
        return null;
    }

    @Override
    public DifferentialFunction sin() {
        return null;
    }

    @Override
    public DifferentialFunction asin() {
        return null;
    }

    @Override
    public DifferentialFunction sinh() {
        return null;
    }

    @Override
    public DifferentialFunction asinh() {
        return null;
    }

    @Override
    public DifferentialFunction tan() {
        return null;
    }

    @Override
    public DifferentialFunction atan() {
        return null;
    }

    @Override
    public DifferentialFunction tanh() {
        return null;
    }

    @Override
    public DifferentialFunction atanh() {
        return null;
    }

    @Override
    public DifferentialFunction exp() {
        return null;
    }

    @Override
    public DifferentialFunction log() {
        return null;
    }

    @Override
    public DifferentialFunction log10() {
        return null;
    }

    @Override
    public DifferentialFunction sgn() {
        return null;
    }

    @Override
    public DifferentialFunction pwr(DifferentialFunction y) {
        return null;
    }

    @Override
    public DifferentialFunction pwrs(DifferentialFunction y) {
        return null;
    }

    @Override
    public DifferentialFunction square() {
        return null;
    }

    @Override
    public DifferentialFunction relu() {
        return null;
    }

    @Override
    public DifferentialFunction hardTanh() {
        return null;
    }

    @Override
    public DifferentialFunction hardTanhDerivative() {
        return null;
    }

    @Override
    public DifferentialFunction leakyRelu() {
        return null;
    }

    @Override
    public DifferentialFunction elu() {
        return null;
    }

    @Override
    public DifferentialFunction eluDerivative() {
        return null;
    }

    @Override
    public DifferentialFunction leakyRelu(double cutoff) {
        return null;
    }

    @Override
    public DifferentialFunction leakyReluDerivative() {
        return null;
    }

    @Override
    public DifferentialFunction leakyReluDerivative(double cutoff) {
        return null;
    }

    @Override
    public DifferentialFunction sigmoid() {
        return null;
    }

    @Override
    public DifferentialFunction sigmoidDerivative() {
        return null;
    }

    @Override
    public DifferentialFunction step() {
        return null;
    }

    @Override
    public DifferentialFunction softsign() {
        return null;
    }

    @Override
    public DifferentialFunction softsignDerivative() {
        return null;
    }

    @Override
    public DifferentialFunction softmax() {
        return null;
    }

    @Override
    public DifferentialFunction softplus() {
        return null;
    }

    @Override
    public DifferentialFunction reshape(int[] shape) {
        return null;
    }

    @Override
    public DifferentialFunction transpose() {
        return null;
    }

    @Override
    public DifferentialFunction permute(int[] dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction expandDims(int dim) {
        return null;
    }

    @Override
    public DifferentialFunction sum(int[] dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction prod(int[] dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction mean(int[] dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction std(int[] dimensions, boolean biasCorrected) {
        return null;
    }

    @Override
    public DifferentialFunction variance(int[] dimensions, boolean biasCorrected) {
        return null;
    }

    @Override
    public DifferentialFunction std(int[] dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction variance(int[] dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction max(int[] dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction min(int[] dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction norm1(int[] dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction norm2(int[] dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction normmax(int[] dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction valueArrayOf(int[] shape) {
        return null;
    }

    @Override
    public DifferentialFunction tile(int[] repeat) {
        return null;
    }

    @Override
    public DifferentialFunction repeat(int axis) {
        return null;
    }

    @Override
    public DifferentialFunction broadcast(int[] shape) {
        return null;
    }

    @Override
    public DifferentialFunction eq(DifferentialFunction i_y) {
        return null;
    }

    @Override
    public DifferentialFunction neq(DifferentialFunction i_y) {
        return null;
    }

    @Override
    public DifferentialFunction or(DifferentialFunction i_y) {
        return null;
    }

    @Override
    public DifferentialFunction rollAxis(int axis) {
        return null;
    }

    @Override
    public DifferentialFunction cosineSimilarity(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction euclideanDistance(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction manhattanDistance(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossBinaryXENT(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossCosineSimilarity(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossHinge(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossKLD(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossL1(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossL2(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossMAE(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossMAPE(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossMSE(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossMCXENT(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossMSLE(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossNegativeLogLikelihood(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossPoisson(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    @Override
    public DifferentialFunction lossSquaredHinge(DifferentialFunction i_y, int... dimensions) {
        return null;
    }

    /**
     * Adds a result vertex for a binary op on {@code i_v1} and {@code i_v2} to the graph,
     * adds one edge from each input's vertex to it, and records the resulting op state and
     * result vertex id on this function.
     *
     * <p>Only {@link ArrayField}-valued inputs are wired into the graph; for any other value
     * type this method does nothing.
     *
     * @param sameDiff  the SameDiff instance whose graph receives the vertex and edges
     * @param i_v1      the first input
     * @param i_v2      the second input
     * @param opName    name of the op, used in the op states and the result id
     * @param opType    type of the op
     * @param shape     shape of the result
     * @param extraArgs extra arguments stored on the op states (may be {@code null})
     * @throws ND4JIllegalStateException if the new vertex id collides with an input's vertex id
     * @throws IllegalStateException     if the inputs belong to a different SameDiff instance
     */
    protected void addEdges(SameDiff sameDiff,
                            DifferentialFunction i_v1,
                            DifferentialFunction i_v2,
                            String opName,
                            OpState.OpType opType,
                            int[] shape, Object[] extraArgs) {
        // Evaluate the first operand once and reuse it below.
        Object firstValue = i_v1.getValue(true);
        if (!(firstValue instanceof ArrayField)) {
            // Non-array values are not represented in the graph.
            return;
        }
        validateDifferentialFunctionGraph(i_v1);
        validateDifferentialFunctionGraph(i_v2);

        // TODO (original author): getValue() generates invalid vertex ids; find a way to get
        // the proper vertex metadata, i.e. derive the vertex id for each of these equations.
        ArrayField v1 = (ArrayField) firstValue;
        int v1VertexId = i_v1.resultVertexId();
        ArrayField v2 = (ArrayField) i_v2.getValue(true);
        int v2VertexId = i_v2.resultVertexId();
        int length = ArrayUtil.prod(shape);

        NDArrayInformation arrInfo = NDArrayInformation.builder()
                .arrId(UUID.randomUUID().toString())
                .id(opName +"(" + v1.getInput().getId() + "," + v2.getInput().getId() + ")")
                .shape(shape).build();

        // The result vertex.
        NDArrayVertex newVertex = new NDArrayVertex(sameDiff.getGraph().nextVertexId(), arrInfo);
        if (newVertex.vertexID() == v2VertexId || newVertex.vertexID() == v1VertexId) {
            throw new ND4JIllegalStateException("Illegal vertex id specified in new vertex." +
                    " Perhaps a mismatched graph call? Another likely cause is applyGraph");
        }
        this.vertexId = newVertex.vertexID();
        sameDiff.getGraph().addVertex(newVertex);

        // Op state for the edge coming from the second input.
        OpState v2OpState;
        if (v1.equals(v2)) {
            // Both inputs are the same array: give the second one its own vertex so the
            // result still has two distinct incoming edges.
            NDArrayVertex dupVertex = new NDArrayVertex(sameDiff.getGraph().nextVertexId(),
                    NDArrayInformation.builder()
                            .shape(v1.getInput().getShape())
                            .id(v1.getInput().getId()).build());
            v2VertexId = dupVertex.vertexID();
            sameDiff.getGraph().addVertex(dupVertex);
            v2OpState = OpState.builder()
                    .opType(opType)
                    .opName(opName)
                    .id(opName + "(" + dupVertex.getValue().getId() + " -> " + newVertex.getValue().getId() + ")")
                    .vertexIds(new String[]{String.valueOf(v2VertexId),String.valueOf(newVertex.vertexID())})
                    .n(length)
                    .extraArgs(extraArgs)
                    .result(arrInfo)
                    .build();
        } else {
            // NOTE(review): this id is built from v1's vertex although the op state describes
            // the edge from v2's vertex — looks like a copy/paste slip; confirm before changing.
            v2OpState = OpState.builder()
                    .opType(opType)
                    .opName(opName)
                    .id(opName + "(" + v1.getVertex().getValue().getId() + " -> " + newVertex.getValue().getId() + ")")
                    .vertexIds(new String[]{String.valueOf(v2VertexId),String.valueOf(newVertex.vertexID())})
                    .n(length)
                    .extraArgs(extraArgs)
                    .result(arrInfo)
                    .build();
        }

        // Op state for the edge coming from the first input.
        OpState v1OpState = OpState.builder()
                .opType(opType)
                .opName(opName)
                .id(opName + "(" + v1.getVertex().getValue().getId() + " -> " + newVertex.getValue().getId() + ")")
                .vertexIds(new String[]{String.valueOf(v1VertexId),String.valueOf(newVertex.vertexID())})
                .n(length)
                .extraArgs(extraArgs)
                .result(arrInfo)
                .build();

        // The edge from the first input is always added as normal.
        sameDiff.getGraph().addEdge(v1VertexId,
                newVertex.vertexID(),
                v1OpState,true);
        sameDiff.getGraph().addEdge(v2VertexId,
                newVertex.vertexID(),
                v2OpState,true);
        newVertex.setOpState(v1OpState);
        arrInfo.setOwner(v1OpState);

        this.opState = v2OpState;
    }


    /**
     * Adds a {@link OpState.OpType#TRANSFORM} edge pair for {@code i_v1} and {@code i_v2},
     * using the shape of {@code i_v1}'s array as the result shape.
     *
     * @throws UnsupportedOperationException if {@code i_v1}'s value is not an {@link ArrayField}
     * @throws IllegalStateException         if the inputs belong to a different SameDiff instance
     */
    protected void addEdges(SameDiff sameDiff,
                            DifferentialFunction i_v1,
                            DifferentialFunction i_v2,
                            String opName) {
        validateDifferentialFunctionGraph(i_v1);
        validateDifferentialFunctionGraph(i_v2);
        Object value = i_v1.getValue(true);
        if (!(value instanceof ArrayField)) {
            throw new UnsupportedOperationException("Only supporting array fields");
        }
        ArrayField arrayField = (ArrayField) value;
        addEdges(sameDiff,
                i_v1,
                i_v2,
                opName,
                OpState.OpType.TRANSFORM,
                arrayField.getInput().getShape());
    }

    /**
     * Returns the first {@link Boolean} found in {@code extraArgs}, or {@code false} if the
     * array is {@code null} or holds no Boolean.
     */
    protected boolean getInPlace(Object[] extraArgs) {
        if (extraArgs != null) {
            for (Object arg : extraArgs) {
                if (arg instanceof Boolean) {
                    return (Boolean) arg;
                }
            }
        }
        return false;
    }


    /** Returns a copy of this function. */
    public abstract DifferentialFunction dup();

    /** Returns the id of the graph vertex holding this function's result. */
    public  int resultVertexId() {
        return vertexId;
    }

    /**
     * Checks that {@code function} is non-null and belongs to the same SameDiff instance as
     * this function.
     *
     * @throws IllegalStateException if {@code function} is {@code null} or belongs to a
     *                               different SameDiff instance
     */
    protected void validateDifferentialFunctionsameDiff(
            DifferentialFunction function) {
        Preconditions.checkState(function != null,"Passed in function was null.");
        // Template form: the functions' toString() is only evaluated if the check fails.
        Preconditions.checkState(function.getSameDiff() == this.getSameDiff(),
                "Function applications must be contained in same sameDiff. The left %s must match this function %s",
                function, this);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy