All downloads are FREE. Search and download functionality uses the official Maven repository.

org.nd4j.autodiff.functions.Scalar Maven / Gradle / Ivy

There is a newer version: 0.9.1
Show newest version
package org.nd4j.autodiff.functions;

import org.nd4j.autodiff.AbstractIdentityFactory;
import org.nd4j.autodiff.ArrayField;
import org.nd4j.autodiff.Field;
import org.nd4j.autodiff.opstate.OpState;
import org.nd4j.autodiff.samediff.SDGraph;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp;


/**
 * Scalar value
 * @param 
 */
public class Scalar> extends Constant {

    protected double value;

    public Scalar(SDGraph graph,
                  double value,
                  AbstractIdentityFactory i_factory) {
        this(graph, value, i_factory, false);

    }

    public Scalar(SDGraph graph,
                  double value,
                  AbstractIdentityFactory i_factory,boolean inPlace) {
        super(graph,i_factory.scalar(value),new int[]{1,1} ,i_factory,inPlace);
        this.value = value;
    }


    @Override
    public DifferentialFunction mul(DifferentialFunction i_v) {
        DifferentialFunction dup = i_v.dup();
        if(i_v.getValue(true) instanceof ArrayField) {
            ArrayField arrayField = (ArrayField) i_v.getValue(true);
            addEdges(graph,
                    dup,
                    this,
                    new MulOp().name(),
                    OpState.OpType.TRANSFORM,
                    arrayField.getInput().getShape());
        }

        return dup;
    }


    @Override
    public DifferentialFunction dup() {
        return new Scalar<>(graph, value,getM_factory());
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy