neureka.backend.main.algorithms.ScalarBroadcast Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of neureka Show documentation
A platform independent tensor library written in Java.
The newest version!
package neureka.backend.main.algorithms;
import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.AbstractFunDeviceAlgorithm;
import neureka.backend.api.template.algorithms.FallbackAlgorithm;
import neureka.backend.main.implementations.fun.api.CPUFun;
import neureka.backend.main.implementations.fun.api.ScalarFun;
import neureka.backend.main.implementations.scalar.CPUScalarBroadcastFunction;
import neureka.math.args.Arg;
import neureka.devices.Device;
import neureka.devices.host.CPU;
import neureka.devices.opencl.OpenCLDevice;
import neureka.dtype.NumericType;
import neureka.ndim.NDimensional;
public class ScalarBroadcast extends AbstractFunDeviceAlgorithm
{
public ScalarBroadcast(ScalarFun fun) {
super("scalar broadcast");
setAutogradModeFor(
call -> call
.validate().allNotNullHaveSame(NDimensional::shape)
.ifValid(AutoDiffMode.FORWARD_AND_BACKWARD)
.orElse(AutoDiffMode.BACKWARD_ONLY)
);
setIsSuitableFor( call ->
call.validate()
.allNotNull( t -> t.getDataType().typeClassImplements(NumericType.class) )
.tensors( tensors -> {
if ( tensors.length != 2 ) return false;
if ( !tensors[1].isVirtual() ) return false;
if ( tensors[0] != null && tensors[0].isVirtual() ) return false;
return tensors[0] == null && tensors[1] != null || tensors[0].shape().equals(tensors[1].shape());
})
.suitabilityIfValid( SuitabilityPredicate.VERY_GOOD )
);
setCallPreparation(
call -> {
Device device = call.getDeviceFor(Number.class);
assert call.input( 0 ) == null; // Creating a new tensor:
Shape outShape = call.input( 1 ).shape();
Class
© 2015 - 2025 Weber Informatics LLC | Privacy Policy