All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.functions.One Maven / Gradle / Ivy

package org.nd4j.autodiff.functions;

import org.nd4j.autodiff.ArrayField;
import org.nd4j.autodiff.Field;
import org.nd4j.autodiff.opstate.OpState;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp;


public class One> extends Constant {


    public One(SameDiff sameDiff,
               int[] shape) {
        super(sameDiff, (X) sameDiff.getArrayFactory().one(shape),shape);
        this.shape = shape;
        ArrayField arrayField = (ArrayField) m_x;
        arrayField.getInput().setScalarValue(1.0);
    }




    @Override
    public DifferentialFunction mul(DifferentialFunction i_v) {
        DifferentialFunction dup = i_v.dup();
        if(i_v.getValue(true) instanceof ArrayField) {
            ArrayField arrayField = (ArrayField) i_v.getValue(true);
            addEdges(sameDiff,
                    dup,
                    this,
                    new MulOp().name(),
                    OpState.OpType.TRANSFORM,
                    arrayField.getInput().getShape());
        }

        return dup;
    }



    @Override
    public DifferentialFunction dup() {
        return new One<>(sameDiff, shape);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy