
// neureka.backend.standard.operations.operator.Power Maven / Gradle / Ivy
package neureka.backend.standard.operations.operator;
import neureka.Neureka;
import neureka.Tsr;
import neureka.autograd.DefaultADAgent;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.algorithms.Algorithm;
import neureka.backend.api.operations.AbstractOperation;
import neureka.backend.api.operations.Operation;
import neureka.backend.api.operations.OperationContext;
import neureka.backend.standard.algorithms.Broadcast;
import neureka.backend.standard.algorithms.Convolution;
import neureka.backend.standard.algorithms.Operator;
import neureka.backend.standard.algorithms.Scalarization;
import neureka.backend.standard.implementations.CLImplementation;
import neureka.backend.standard.implementations.HostImplementation;
import neureka.calculus.Function;
import neureka.devices.Device;
import neureka.devices.host.HostCPU;
import neureka.devices.opencl.OpenCLDevice;
import neureka.ndim.config.NDConfiguration;
import org.jetbrains.annotations.Contract;
import java.util.List;
public class Power extends AbstractOperation
{
/**
 * Creator of the element-wise power kernels used by the broadcast host implementation.
 * The returned lambda receives three index objects (first one unused here) and reads
 * the base from input 1 and the exponent from input 2 via {@code i()}.
 * <ul>
 *   <li>{@code d < 0}  : {@code base ^ exponent}</li>
 *   <li>{@code d == 0} : {@code exponent * base ^ (exponent - 1)}</li>
 *   <li>otherwise      : {@code base ^ exponent * ln(base)}</li>
 * </ul>
 */
private final static DefaultOperatorCreator _creator = ( inputs, d ) ->
{
    final double[] base = inputs[ 1 ].value64();
    final double[] exponent = inputs[ 2 ].value64();

    // Plain evaluation, no derivative requested.
    if ( d < 0 ) {
        return ( outIdx, baseIdx, expIdx ) ->
                Math.pow( base[ baseIdx.i() ], exponent[ expIdx.i() ] );
    }
    // Derivative with respect to the base.
    if ( d == 0 ) {
        return ( outIdx, baseIdx, expIdx ) ->
                exponent[ expIdx.i() ]
                    * Math.pow( base[ baseIdx.i() ], exponent[ expIdx.i() ] - 1 );
    }
    // Derivative with respect to the exponent.
    return ( outIdx, baseIdx, expIdx ) ->
            Math.pow( base[ baseIdx.i() ], exponent[ expIdx.i() ] )
                * Math.log( base[ baseIdx.i() ] );
};
/**
 * Variant of the element-wise power kernel creator which resolves array positions
 * through the {@link NDConfiguration} of each input ({@code i_of_idx}) instead of
 * calling {@code i()} on the index object.
 * <ul>
 *   <li>{@code d < 0}  : {@code base ^ exponent}</li>
 *   <li>{@code d == 0} : {@code exponent * base ^ (exponent - 1)}</li>
 *   <li>otherwise      : {@code base ^ exponent * ln(base)}</li>
 * </ul>
 */
private final static DefaultOperatorCreator _creatorX = ( inputs, d ) ->
{
    final double[] base = inputs[ 1 ].value64();
    final double[] exponent = inputs[ 2 ].value64();
    final NDConfiguration baseConf = inputs[ 1 ].getNDConf();
    final NDConfiguration expConf = inputs[ 2 ].getNDConf();

    // Plain evaluation, no derivative requested.
    if ( d < 0 ) {
        return ( outIdx, baseIdx, expIdx ) ->
                Math.pow( base[ baseConf.i_of_idx( baseIdx ) ], exponent[ expConf.i_of_idx( expIdx ) ] );
    }
    // Derivative with respect to the base.
    if ( d == 0 ) {
        return ( outIdx, baseIdx, expIdx ) -> {
            final double e = exponent[ expConf.i_of_idx( expIdx ) ];
            return e * Math.pow( base[ baseConf.i_of_idx( baseIdx ) ], e - 1 );
        };
    }
    // Derivative with respect to the exponent.
    return ( outIdx, baseIdx, expIdx ) -> {
        final double b = base[ baseConf.i_of_idx( baseIdx ) ];
        return Math.pow( b, exponent[ expConf.i_of_idx( expIdx ) ] ) * Math.log( b );
    };
};
/**
 * Creates the power operation ({@code "power"}, operator {@code "^"}) and registers
 * its algorithms: {@link Operator} (element-wise), {@link Broadcast} and
 * {@link Scalarization}, each with a {@link HostCPU} and an {@link OpenCLDevice}
 * implementation. It also instantiates the related operations
 * {@code inv_power_left}, {@code inv_power_right} and the convolution variant {@code "p"}
 * with its two companion operations.
 */
public Power()
{
    super("power", "^", -1, true, false, true, false);

    // Renders the operation as "(a ^ b ^ ...)".
    setStringifier(
        children ->
        {
            StringBuilder reconstructed = new StringBuilder();
            for ( int i = 0; i < children.size(); ++i ) {
                reconstructed.append( children.get( i ) );
                if ( i < children.size() - 1 ) reconstructed.append(" ^ ");
            }
            return "(" + reconstructed + ")";
        }
    );

    //_____________________
    // DEFAULT OPERATION :

    // NOTE(review): the creators in this class return lambdas of different arity although
    // they share the declared type; a generic type argument was presumably lost - confirm.

    // Element-wise kernels addressing the data through i() of the index objects:
    //   d < 0  : base ^ exponent
    //   d == 0 : exponent * base ^ (exponent - 1)
    //   else   : base ^ exponent * ln(base)
    DefaultOperatorCreator operationCreator = ( inputs, d ) ->
    {
        double[] t1_val = inputs[ 1 ].value64();
        double[] t2_val = inputs[ 2 ].value64();
        if ( d < 0 ) return ( t1Idx, t2Idx ) ->
                Math.pow(t1_val[ t1Idx.i() ], t2_val[t2Idx.i()]);
        else {
            return ( t1Idx, t2Idx ) ->
            {
                if ( d == 0 ) return
                        t2_val[t2Idx.i()] * Math.pow(
                            t1_val[ t1Idx.i() ],
                            t2_val[t2Idx.i()] - 1
                        );
                else return
                        Math.pow(
                            t1_val[ t1Idx.i() ],
                            t2_val[t2Idx.i()]
                        ) * Math.log(t1_val[ t1Idx.i() ]);
            };
        }
    };

    // Same kernels, but positions are resolved through the NDConfiguration of each input.
    DefaultOperatorCreator operationXCreator = ( inputs, d ) ->
    {
        double[] t1_val = inputs[ 1 ].value64();
        double[] t2_val = inputs[ 2 ].value64();
        NDConfiguration ndc1 = inputs[ 1 ].getNDConf();
        NDConfiguration ndc2 = inputs[ 2 ].getNDConf();
        if ( d < 0 ) return t1Idx ->
                Math.pow(t1_val[ndc1.i_of_idx( t1Idx )], t2_val[ndc2.i_of_idx( t1Idx )]);
        else {
            return t1Idx ->
            {
                double temp1 = t1_val[ndc1.i_of_idx( t1Idx )];
                double temp2 = t2_val[ndc2.i_of_idx( t1Idx )];
                if ( d == 0 ) return temp2 * Math.pow( temp1, temp2 - 1 );
                else return Math.pow( temp1, temp2 ) * Math.log(temp1);
            };
        }
    };

    // Breaks calls with more than three tensors (output + more than two operands) down
    // into a sequence of calls on three tensors. Returns null when nothing had to be done.
    Algorithm.RecursiveJunctionAgent rja = ( call, goDeeperWith ) ->
    {
        Tsr[] tsrs = call.getTensors();
        Device device = call.getDevice();
        int d = call.getDerivativeIndex();
        Operation type = call.getOperation();

        Tsr alternative = null;
        if ( tsrs.length > 3 )
        {
            if ( d < 0 ) {
                // Plain evaluation: execute on the first three tensors, then continue
                // with the array produced by Utility.offsetted(tsrs, 1).
                Tsr[] reduction = new Tsr[]{tsrs[ 0 ], tsrs[ 1 ], tsrs[ 2 ]};
                alternative = goDeeperWith.apply(
                        call.withNew( reduction )
                );
                tsrs[ 0 ] = reduction[ 0 ];

                reduction = Utility.offsetted(tsrs, 1);
                alternative = goDeeperWith.apply(
                        call.withNew( reduction )
                );
                tsrs[ 0 ] = reduction[ 0 ];
            } else {
                if ( d==0 ) {
                    // Derivative w.r.t. the base: run "*" on the subset into a fresh
                    // tensor, then derive the power with that tensor as exponent.
                    Tsr[] reduction = Utility.subset(tsrs, 1, 2, tsrs.length-2);
                    reduction[ 0 ] = Tsr.Create.newTsrLike(tsrs[ 1 ]);
                    alternative = goDeeperWith.apply(
                            new ExecutionCall<>( device, reduction, -1, OperationContext.get().instance("*") )
                    );
                    Tsr exp = reduction[ 0 ];
                    reduction = new Tsr[]{tsrs[ 0 ], tsrs[ 1 ], exp};
                    alternative = goDeeperWith.apply(
                            new ExecutionCall<>( device, reduction, 0, type )
                    );
                    tsrs[ 0 ] = reduction[ 0 ];
                    exp.delete(); // temporary exponent tensor is no longer needed
                } else {
                    // Derivative w.r.t. one of the exponents: derive the "*" call at
                    // index d-1, multiply the result with tsrs[d] and finally derive the
                    // power w.r.t. its exponent (index 1).
                    Tsr[] reduction = Utility.subset(tsrs, 1, 2, tsrs.length-2);
                    reduction[ 0 ] = Tsr.Create.newTsrLike(tsrs[ 1 ]);
                    alternative = goDeeperWith.apply(
                            new ExecutionCall<>( device, reduction, d-1, OperationContext.get().instance("*") )
                    );
                    Tsr inner = reduction[ 0 ];

                    reduction = new Tsr[]{Tsr.Create.newTsrLike(tsrs[ 1 ]), inner, tsrs[d]};
                    alternative = goDeeperWith.apply(
                            new ExecutionCall<>( device, reduction, -1, OperationContext.get().instance("*") )
                    );
                    Tsr exp = reduction[ 0 ];

                    reduction = new Tsr[]{tsrs[ 0 ], tsrs[ 1 ], exp};
                    alternative = goDeeperWith.apply(
                            new ExecutionCall<>( device, reduction, 1, type )
                    );
                    tsrs[ 0 ] = reduction[ 0 ];

                    // temporaries are no longer needed
                    inner.delete();
                    exp.delete();
                }
            }
        }
        return alternative;
    };

    Operator operator = new Operator()
            .setADAgentSupplier(
                ( Function f, ExecutionCall call, boolean forward ) ->
                    getDefaultAlgorithm().supplyADAgentFor( f, call, forward )
            )
            .setRJAgent( rja )
            .build();

    setAlgorithm(Operator.class,
        operator.setImplementationFor(
            HostCPU.class,
            new HostImplementation(
                call ->
                    call.getDevice().getExecutor()
                        .threaded (
                            call.getTensor( 0 ).size(),
                            (Neureka.instance().settings().indexing().isUsingArrayBasedIndexing())
                            ? ( start, end ) ->
                                Operator.operate (
                                    call.getTensor( 0 ),
                                    call.getTensor(1),
                                    call.getTensor(2),
                                    call.getDerivativeIndex(),
                                    start, end,
                                    operationXCreator.create(call.getTensors(), call.getDerivativeIndex())
                                )
                            : ( start, end ) ->
                                Operator.operate (
                                    call.getTensor( 0 ),
                                    call.getTensor(1),
                                    call.getTensor(2),
                                    call.getDerivativeIndex(),
                                    start, end,
                                    operationCreator.create(call.getTensors(), call.getDerivativeIndex())
                                )
                        ),
                3
            )
        ).setImplementationFor(
            OpenCLDevice.class,
            new CLImplementation(
                call ->
                {
                    // Tensor 0 may be absent; in that case all arguments shift by one.
                    int offset = (call.getTensor( 0 ) != null) ? 0 : 1;
                    int gwz = (call.getTensor( 0 ) != null)
                            ? call.getTensor( 0 ).size()
                            : call.getTensor( 1 ).size();
                    call.getDevice().getKernel(call)
                        .pass( call.getTensor( offset ) )
                        .pass( call.getTensor( offset + 1 ) )
                        .pass( call.getTensor( offset + 2 ) )
                        // rank is read from the first present tensor to avoid a
                        // NullPointerException when tensor 0 is null
                        .pass( call.getTensor( offset ).rank() )
                        .pass( call.getDerivativeIndex() )
                        .call( gwz );
                },
                3,
                operator.getKernelSource(), // kernelSource
                "output = pow(input1, input2);",
                "if (d==0) { \n" +
                " output = input2 * pow(input1, input2-1.0f); \n" +
                "} else { \n" +
                " output = pow(input1, input2) * log(input1); \n" +
                "}",
                this // OperationType
            )
        )
    );

    //________________
    // BROADCASTING :

    Broadcast broadcast = new Broadcast()
            .setBackwardADAnalyzer( call -> true )
            .setForwardADAnalyzer( call -> true )
            .setADAgentSupplier(
                ( Function f, ExecutionCall call, boolean forward ) ->
                {
                    // A derivative stored on the call under "derivative" takes precedence.
                    Tsr<?> ctxDerivative = (Tsr<?>) call.getAt("derivative");
                    Function mul = Function.Detached.MUL;
                    if ( ctxDerivative != null ) {
                        return new DefaultADAgent( ctxDerivative )
                                .setForward( (node, forwardDerivative ) -> mul.call(new Tsr[]{forwardDerivative, ctxDerivative}) )
                                .setBackward( (node, forwardDerivative ) -> mul.call(new Tsr[]{forwardDerivative, ctxDerivative}) );
                    }
                    Tsr[] inputs = call.getTensors();
                    int d = call.getDerivativeIndex();
                    if ( forward ) throw new IllegalArgumentException("Broadcast implementation does not support forward-AD!");
                    else
                    {
                        Tsr<?> deriv = f.derive( inputs, d );
                        return new DefaultADAgent( deriv )
                                .setForward( (node, forwardDerivative ) -> mul.call(new Tsr[]{forwardDerivative, deriv}) )
                                .setBackward( (node, backwardError ) -> mul.call(new Tsr[]{backwardError, deriv}) );
                    }
                }
            )
            .setRJAgent( rja )
            .build();

    setAlgorithm(
        Broadcast.class,
        broadcast.setImplementationFor(
            HostCPU.class,
            new HostImplementation(
                call ->
                    call.getDevice().getExecutor()
                        .threaded (
                            call.getTensor( 0 ).size(),
                            (Neureka.instance().settings().indexing().isUsingArrayBasedIndexing())
                            ? ( start, end ) ->
                                Broadcast.broadcast (
                                    call.getTensor( 0 ), call.getTensor(1), call.getTensor(2),
                                    call.getDerivativeIndex(), start, end,
                                    _creatorX.create(call.getTensors(), call.getDerivativeIndex())
                                )
                            : ( start, end ) ->
                                Broadcast.broadcast (
                                    call.getTensor( 0 ), call.getTensor(1), call.getTensor(2),
                                    call.getDerivativeIndex(), start, end,
                                    _creator.create(call.getTensors(), call.getDerivativeIndex())
                                )
                        ),
                3
            )
        ).setImplementationFor(
            OpenCLDevice.class,
            new CLImplementation(
                call -> {
                    // Tensor 0 may be absent; in that case all arguments shift by one.
                    int offset = (call.getTensor( 0 ) != null) ? 0 : 1;
                    int gwz = (call.getTensor( 0 ) != null) ? call.getTensor( 0 ).size() : call.getTensor( 1 ).size();
                    call.getDevice().getKernel(call)
                        .pass( call.getTensor( offset ) )
                        .pass( call.getTensor( offset + 1 ) )
                        .pass( call.getTensor( offset + 2 ) )
                        // rank is read from the first present tensor to avoid a
                        // NullPointerException when tensor 0 is null
                        .pass( call.getTensor( offset ).rank() )
                        .pass( call.getDerivativeIndex() )
                        .call( gwz );
                },
                3,
                broadcast.getKernelSource(), // kernelSource
                "value += pow(src1, src2);",
                // NOTE(review): the else branch takes log(handle) although handle is used
                // as the exponent in pow(target, handle); the host kernels use ln(base).
                // Also '=' vs '+=' differs between the two branches - confirm intent.
                "if (d==0) {\n" +
                " value = (handle * pow(target, handle-(float)1 )) * drain;\n" +
                "} else {\n" +
                " value += (pow(target, handle) * log(handle)) * drain;\n" +
                "}",
                this // OperationType
            )
        )
    );

    //___________________________
    // TENSOR SCALAR OPERATION :

    // Tensor-scalar kernels: the tensor (input 1) is the base, 'value' the scalar exponent.
    //   d < 0  : base ^ value
    //   d == 0 : value * base ^ (value - 1)
    //   else   : base ^ value * ln(base)   (logarithm of the base, as in the kernels above)
    ScalarOperatorCreator scalarCreator =
            ( inputs, value, d ) -> {
                double[] t1_val = inputs[ 1 ].value64();
                if ( d < 0 ) return t1Idx -> Math.pow(t1_val[ t1Idx.i() ], value);
                else {
                    if (d==0) return t1Idx -> value*Math.pow(t1_val[ t1Idx.i() ], value-1);
                    else return t1Idx -> Math.pow(t1_val[ t1Idx.i() ], value)*Math.log(t1_val[ t1Idx.i() ]);
                }
            };

    ScalarOperatorCreator scalarXCreator =
            ( inputs, value, d ) -> {
                double[] t1_val = inputs[ 1 ].value64();
                NDConfiguration ndc1 = inputs[ 1 ].getNDConf();
                if ( d < 0 ) return t1Idx -> Math.pow(t1_val[ndc1.i_of_idx( t1Idx )], value);
                else {
                    if (d==0) return t1Idx -> value*Math.pow(t1_val[ndc1.i_of_idx( t1Idx )], value-1);
                    else return t1Idx -> {
                        double base = t1_val[ndc1.i_of_idx( t1Idx )];
                        return Math.pow(base, value)*Math.log(base);
                    };
                }
            };

    Scalarization scalarization = new Scalarization()
            .setBackwardADAnalyzer( call -> true )
            .setForwardADAnalyzer( call -> true )
            .setADAgentSupplier(
                ( Function f, ExecutionCall call, boolean forward ) ->
                    getDefaultAlgorithm().supplyADAgentFor( f, call, forward )
            )
            .setCallHook( (caller, call ) -> null )
            .setRJAgent( rja )
            .build();

    setAlgorithm(
        Scalarization.class,
        scalarization.setImplementationFor(
            HostCPU.class,
            new HostImplementation(
                call -> {
                    // NOTE(review): the scalar is read from tensor 0 at position 2 and the
                    // creators are always invoked with d = -1 (the derivative index of the
                    // call is ignored here) - confirm that both are intended.
                    double value = call.getTensor( 0 ).value64(2);
                    call.getDevice().getExecutor()
                        .threaded (
                            call.getTensor( 0 ).size(),
                            (Neureka.instance().settings().indexing().isUsingArrayBasedIndexing())
                            ? ( start, end ) ->
                                Scalarization.scalarize (
                                    call.getTensor( 0 ),
                                    start, end,
                                    scalarXCreator.create(call.getTensors(), value, -1)
                                )
                            : ( start, end ) ->
                                Scalarization.scalarize (
                                    call.getTensor( 0 ),
                                    start, end,
                                    scalarCreator.create(call.getTensors(), value, -1)
                                )
                        );
                },
                3
            )
        ).setImplementationFor(
            OpenCLDevice.class,
            new CLImplementation(
                call -> {
                    // The scalar is taken from tensor 2 if that one is virtual or of size 1,
                    // otherwise from tensor 1.
                    int offset = (call.getTensor( 2 ).isVirtual() || call.getTensor( 2 ).size() == 1)?1:0;
                    int gwz = call.getTensor( 0 ).size();
                    call.getDevice().getKernel(call)
                        .pass(call.getTensor( 0 ))
                        .pass(call.getTensor( 0 ))
                        .pass((float)call.getTensor(1+offset).value64( 0 ))
                        .pass( call.getTensor( 0 ).rank() )
                        .pass( call.getDerivativeIndex() )
                        .call( gwz );
                },
                3,
                scalarization.getKernelSource(), // kernelSource
                "output = pow(input1, value);",
                // NOTE(review): the else branch uses log(value), i.e. the logarithm of the
                // exponent; the element-wise kernel uses log(input1) - confirm and align.
                "if ( d==0 ) { \n" +
                " output = value * pow(input1, value-(float)1 ); \n" +
                "} else { \n" +
                " output = pow(input1, value) * log(value); \n" +
                "}",
                this // OperationType
            )
        )
    );

    //__________________________
    // RELATED OPERATION TYPES :

    new AbstractOperation("inv_power_left", ((char) 171) + "^", 3, true, false, false, false) {
        @Override
        public double calculate( double[] inputs, int j, int d, List<Function> src ) {
            return src.get( 0 ).call( inputs, j );
        }
    };

    new AbstractOperation("inv_power_right", "^" + ((char) 187), 3, true, false, false, false) {
        @Override
        public double calculate( double[] inputs, int j, int d, List<Function> src ) {
            return src.get( 0 ).call( inputs, j );
        }
    };

    // Convolution:

    new AbstractOperation(
            "power", "p", 2, true, false, false, false
    ) {
        @Override
        public double calculate( double[] inputs, int j, int d, List<Function> src ) {
            return 0;
        }
    }.setAlgorithm(
        Convolution.class,
        new Convolution()
            .setBackwardADAnalyzer( call -> true )
            .setForwardADAnalyzer(
                call -> {
                    // Forward-AD is only reported as possible if all tensors share one shape.
                    Tsr<?> last = null;
                    for ( Tsr<?> t : call.getTensors() ) {
                        if ( last != null && !last.shape().equals(t.shape()) ) return false;
                        last = t; // Note: shapes are cached!
                    }
                    return true;
                }
            ).setADAgentSupplier(
                ( Function f, ExecutionCall call, boolean forward ) ->
                {
                    // A derivative stored on the call under "derivative" takes precedence.
                    Tsr<?> ctxDerivative = (Tsr<?>) call.getAt("derivative");
                    Function mul = Function.Detached.MUL;
                    if ( ctxDerivative != null ) {
                        return new DefaultADAgent( ctxDerivative )
                                .setForward( (node, forwardDerivative ) -> mul.call(new Tsr[]{forwardDerivative, ctxDerivative}) )
                                .setBackward( (node, forwardDerivative ) -> mul.call(new Tsr[]{forwardDerivative, ctxDerivative}) );
                    }
                    Tsr[] inputs = call.getTensors();
                    int d = call.getDerivativeIndex();
                    if ( forward )
                        throw new IllegalArgumentException("Convolution of does not support forward-AD!");
                    else
                    {
                        Tsr<?> localDerivative = f.derive( inputs, d );
                        return new DefaultADAgent( localDerivative )
                                .setForward( (node, forwardDerivative ) -> mul.call(new Tsr[]{forwardDerivative, localDerivative}) )
                                .setBackward( (node, backwardError ) -> mul.call(new Tsr[]{backwardError, localDerivative}) );
                    }
                }
            )
            .setCallHook( (caller, call ) -> null )
            .setRJAgent( ( call, goDeeperWith ) -> null )
            .setDrainInstantiation(
                call -> {
                    Tsr[] tsrs = call.getTensors();
                    // Skip the first slot when it is empty.
                    int offset = ( tsrs[ 0 ] == null ) ? 1 : 0;
                    return new ExecutionCall( call.getDevice(), new Tsr[]{tsrs[offset], tsrs[1+offset]}, -1, OperationContext.get().instance("idy") );
                }
            )
            .build()
    )
    .setStringifier(
        children -> {
            StringBuilder reconstructed = new StringBuilder();
            for ( int i = 0; i < children.size(); ++i ) {
                reconstructed.append( children.get( i ) );
                if ( i < children.size() - 1 ) {
                    reconstructed.append(" p ");
                }
            }
            return "(" + reconstructed + ")";
        }
    );

    new AbstractOperation("", ((char) 171) + "p", 3, true, false, false, false) {
        @Override
        public double calculate( double[] inputs, int j, int d, List<Function> src ) {
            return src.get( 0 ).call( inputs, j );
        }
    };

    new AbstractOperation("", "p" + ((char) 187), 3, true, false, false, false) {
        @Override
        public double calculate( double[] inputs, int j, int d, List<Function> src ) {
            return src.get( 0 ).call( inputs, j );
        }
    };
}
/**
 * Evaluates the power chain {@code src[0] ^ src[1] ^ ... ^ src[n]} (applied left to
 * right) or its derivative for the input at index {@code j}.
 * <p>
 * The derivative follows
 * <pre>
 *   d/dx( f(x)^g(x) ) =
 *         f(x)^g(x)     * d/dx( g(x) ) * ln( f(x) )
 *       + f(x)^(g(x)-1) * g(x)         * d/dx( f(x) )
 * </pre>
 * where {@code f} is {@code src[0]} and {@code g} is the product of all further sources.
 *
 * @param inputs the input values the source functions are evaluated on
 * @param j      the input index passed on to the sources; a negative value delegates
 *               to {@link #calculate(double[], int, List)}
 * @param d      the index to derive by, or a negative value for plain evaluation
 * @param src    the source functions; element 0 is the base, the others are exponents
 * @return the value of the power chain or of its derivative
 */
@Contract(pure = true)
@Override
public double calculate( double[] inputs, int j, int d, List<Function> src ) {
    if ( j < 0 ) return calculate( inputs, d, src );
    if ( d < 0 ) {
        double result = src.get( 0 ).call( inputs, j );
        for ( int i = 1; i < src.size(); i++ ) {
            final double current = src.get( i ).call( inputs, j );
            result = Math.pow(result, current);
        }
        return result;
    }
    // b  : product of all exponents, i.e. g(x)
    // bd : derivative of that product (product rule), i.e. d/dx( g(x) )
    double b = 1;
    double bd = 0;
    for ( int i = 1; i < src.size(); i++ ) {
        double dd = 1;
        for ( int di = 1; di < src.size(); di++ ) {
            dd *= ( di == i )
                    ? src.get( di ).derive( inputs, d, j )
                    : src.get( di ).call( inputs, j );
        }
        bd += dd;
        b *= src.get( i ).call( inputs, j );
    }
    final double a = src.get( 0 ).call( inputs, j ); // f(x)
    double out = src.get( 0 ).derive( inputs, d, j ) * b * Math.pow(a, b - 1);
    // ln(a) is NaN for a < 0 and -Infinity for a == 0 (which turns 0 * ln(0) into NaN),
    // so the logarithmic term is only added for a strictly positive base.
    if ( a > 0 ) out += bd * Math.pow(a, b) * Math.log(a);
    return out;
}
/**
 * Evaluates the power chain {@code src[0] ^ src[1] ^ ... ^ src[n]} (applied left to
 * right) or its derivative, without an input index.
 * <p>
 * With {@code f = src[0]} and {@code g} the product of all further sources the
 * derivative is {@code f' * g * f^(g-1) + g' * f^g * ln(f)}, where {@code g'} is built
 * with the product rule over the exponents.
 *
 * @param inputs the input values the source functions are evaluated on
 * @param d      the index to derive by, or a negative value for plain evaluation
 * @param src    the source functions; element 0 is the base, the others are exponents
 * @return the value of the power chain or of its derivative
 */
@Contract(pure = true)
public static double calculate( double[] inputs, int d, List<Function> src ) {
    if ( d < 0 ) {
        double result = src.get( 0 ).call( inputs );
        for ( int i = 1; i < src.size(); i++ ) {
            final double current = src.get( i ).call( inputs );
            result = Math.pow(result, current);
        }
        return result;
    }
    // b  : product of all exponents, i.e. g(x)
    // bd : derivative of that product, i.e. the sum over i of
    //      ( derivative of exponent i ) * ( values of all other exponents )
    double b = 1;
    double bd = 0;
    for ( int i = 1; i < src.size(); i++ ) {
        double dd = 1;
        for ( int di = 1; di < src.size(); di++ ) {
            // every other factor contributes its own value, not the value of exponent i
            dd *= ( di == i )
                    ? src.get( di ).derive( inputs, d )
                    : src.get( di ).call( inputs );
        }
        bd += dd;
        b *= src.get( i ).call( inputs );
    }
    final double a = src.get( 0 ).call( inputs ); // f(x)
    double out = src.get( 0 ).derive( inputs, d ) * b * Math.pow(a, b - 1);
    // ln(a) is NaN for a < 0 and -Infinity for a == 0 (which turns 0 * ln(0) into NaN),
    // so the logarithmic term is only added for a strictly positive base.
    if ( a > 0 ) out += bd * Math.pow(a, b) * Math.log(a);
    return out;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy