All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.autodiff.functions.Scalar Maven / Gradle / Ivy

package org.nd4j.autodiff.functions;

import org.nd4j.autodiff.AbstractIdentityFactory;
import org.nd4j.autodiff.ArrayField;
import org.nd4j.autodiff.Field;
import org.nd4j.autodiff.opstate.OpState;
import org.nd4j.autodiff.samediff.SDGraph;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp;


/**
 * Scalar value
 * @param 
 */
public class Scalar> extends Constant {

    protected double value;

    public Scalar(SameDiff sameDiff,
                  double value) {
        this(sameDiff, value, false);

    }

    public Scalar(SameDiff sameDiff,
                  double value,boolean inPlace) {
        super(sameDiff, (X) sameDiff.getArrayFactory().scalar(value),new int[]{1,1},inPlace);
        this.value = value;

    }


    @Override
    public DifferentialFunction mul(DifferentialFunction i_v) {
        DifferentialFunction dup = i_v.dup();
        if(i_v.getValue(true) instanceof ArrayField) {
            ArrayField arrayField = (ArrayField) i_v.getValue(true);
            addEdges(sameDiff,
                    dup,
                    this,
                    new MulOp().name(),
                    OpState.OpType.TRANSFORM,
                    arrayField.getInput().getShape());
        }

        return dup;
    }


    @Override
    public DifferentialFunction dup() {
        return new Scalar<>(sameDiff, value);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy