All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.autodiff.functions.Constant Maven / Gradle / Ivy

There is a newer version: 0.9.1
Show newest version
package org.nd4j.autodiff.functions;

import java.util.List;

import lombok.Data;
import org.nd4j.autodiff.AbstractIdentityFactory;
import org.nd4j.autodiff.ArrayField;
import org.nd4j.autodiff.Field;
import org.nd4j.autodiff.samediff.SDGraph;

@Data
public class Constant> extends DifferentialFunction {

    protected X m_x;
    protected AbstractIdentityFactory m_factory;
    protected int[] shape;

    protected Constant(SDGraph graph,
                       X i_v,
                       int[] shape,
                       AbstractIdentityFactory i_factory,
                       boolean inPlace) {
        super(graph,new Object[]{i_v,inPlace});
        this.shape = shape;
        if(i_factory == null) {
            i_factory = (AbstractIdentityFactory) graph.getSameDiff().getArrayFactory();
        }
        if (i_v != null && i_factory != null) {
            m_x = i_v;
            m_factory = i_factory;

        } else {
            throw new IllegalArgumentException("Input not null value.");
        }

        if(i_v instanceof ArrayField) {
            ArrayField arrayField = (ArrayField) i_v;
            this.vertexId = arrayField.getVertex().vertexID();
        }
    }

    protected Constant(SDGraph graph,
                       X i_v,
                       int[] shape,
                       AbstractIdentityFactory i_factory) {
        this(graph,i_v,shape,i_factory,false);
    }

    /**
     * Get the result shape for this function
     *
     * @return
     */
    @Override
    public int[] getResultShape() {
        return shape;
    }

    @Override
    public boolean isConstant() {
        return true;
    }

    @Override
    public X doGetValue() {
        return m_x;
    }

    @Override
    public double getReal() {
        return m_x.getReal();
    }

    @Override
    public DifferentialFunction[] args() {
        return new DifferentialFunction[] {this};
    }

    @Override
    public DifferentialFunction arg() {
        return this;
    }

    @Override
    public DifferentialFunction diff(Variable i_v) {
        return new Zero<>(graph,shape, m_factory);
    }

    @Override
    public String toString() {
        return getValue(true).toString();
    }

    @Override
    public String doGetFormula(List> variables) {
        return getValue(true).toString();
    }

    @Override
    public String functionName() {
        return "constant";
    }



    @Override
    public Constant inverse() {
        Constant ret = new Constant<>(graph, m_x.inverse(),shape, m_factory);
        return ret;
    }

    @Override
    public Constant negate() {
        Constant ret =  new Constant<>(graph, m_x.negate(),shape, m_factory);
        return ret;
    }

    @Override
    public DifferentialFunction dup() {
        return new Constant<>(graph,m_x,shape,getM_factory());
    }

    // This class must be immutable.
    // set and assign must not be implemented.
    @SuppressWarnings("unused")
    private final void set(X i_x) {
    }

    @SuppressWarnings("unused")
    private final void assign(X i_x) {
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy