All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.autodiff.functions.Variable Maven / Gradle / Ivy

There is a newer version: 0.9.1
Show newest version
package org.nd4j.autodiff.functions;

import java.util.List;

import lombok.Data;
import org.nd4j.autodiff.AbstractIdentityFactory;
import org.nd4j.autodiff.ArrayField;
import org.nd4j.autodiff.Field;
import org.nd4j.autodiff.samediff.SDGraph;

@Data
public class Variable> extends DifferentialFunction {

    private X m_x;
    private AbstractIdentityFactory m_factory;
    private String m_name;
    private PreEvaluator preEvaluator;

    protected Variable(SDGraph graph,
                       String i_name,
                       X i_v,
                       AbstractIdentityFactory i_factory) {
        this(graph,i_name, i_v, i_factory, null);
    }

    protected Variable(SDGraph graph,
                       String i_name,
                       X i_v,
                       AbstractIdentityFactory i_factory,
                       PreEvaluator preEvaluator) {
        super(graph,null);
        this.preEvaluator = preEvaluator;
        setName(i_name);
        if (i_v != null && i_factory != null) {
            m_x = i_v;
            m_factory = i_factory;
        } else {
            throw new IllegalArgumentException("Input not null value.");
        }

        if(i_v instanceof ArrayField) {
            ArrayField arrayField = (ArrayField) i_v;
            this.vertexId = arrayField.getVertex().vertexID();
        }
    }


    private void setName(String i_name) {
        if (i_name != null) {
            m_name = i_name;// new String(i_name);
        } else {
            throw new IllegalArgumentException("Input not null value.");
        }
    }

    public String getName() {
        return m_name;
    }

    public void set(X i_v) {
        if (i_v != null) {
            m_x = i_v;
        } else {
            throw new IllegalArgumentException("Input not null value.");
        }
    }

    @Override
    public boolean isVariable() {
        return true;
    }

    @Override
    public X doGetValue() {
        if (preEvaluator != null) {
            preEvaluator.update(this);
        }
        return m_x;
    }

    @Override
    public double getReal() {
        if (preEvaluator != null) {
            preEvaluator.update(this);
        }
        return m_x.getReal();
    }

    @Override
    public DifferentialFunction[] args() {
        return new DifferentialFunction[] {this};
    }

    @Override
    public DifferentialFunction arg() {
        return this;
    }

    @Override
    public Constant diff(Variable i_v) {
        if(m_x instanceof ArrayField) {
            ArrayField arrayField = (ArrayField) m_x;
            Constant ret =  (this.equals(i_v) ? new One<>(graph, arrayField.getInput().getShape(),m_factory) : new Zero<>(graph,arrayField.getInput().getShape(), m_factory));

            /*addEdges(graph,
                    this,ret,
                    "diff",
                    OpState.OpType.TRANSFORM,
                    arrayField.getInput().getShape());*/
            return ret;

        }

        throw new IllegalStateException("Illegal type for variable. Should be ArrayField");
    }


    /**
     * Get the result shape for this function
     * @return
     */
    @Override
    public int[] getResultShape() {
        ArrayField arrayField = (ArrayField) m_x;
        return arrayField.getInput().getShape();
    }


    @Override
    public String toString() {
        return getName() + ":" + getValue(true);
    }

    @Override
    public String doGetFormula(List> variables) {
        variables.add(this);
        return getName();
    }

    @Override
    public String functionName() {
        return m_name;
    }

    @Override
    public DifferentialFunction div(DifferentialFunction i_v) {
        return (i_v == this) ? new One<>(graph,i_v.getResultShape(), m_factory) : super.mul(i_v.inverse());
    }

    @Override
    public DifferentialFunction dup() {
        return new Variable<>(graph, getName(), m_x, m_factory);
    }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy