All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.samediff.SameDiffOpExecutioner Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.autodiff.samediff;

import lombok.Getter;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;

import java.util.*;
import java.util.concurrent.atomic.AtomicReference;

/**
 *
 */
public class SameDiffOpExecutioner implements OpExecutioner,OpProfiler.OpProfilerListener {

    @Getter
    private Map variables;
    @Getter
    private SameDiff sameDiff;
    @Getter
    private AtomicReference opAtomicReference;
    @Getter
    private OpExecutioner backendExecutioner = Nd4j.getExecutioner();

    public SameDiffOpExecutioner() {
        variables = new IdentityHashMap<>();
        sameDiff = SameDiff.create();
        OpProfiler.getInstance().addListener(this);
    }

    private Op  processOp(Op op) {
        if(opAtomicReference == null) {
            opAtomicReference = new AtomicReference<>(op);
        }

        for(INDArray arr : new INDArray[] {op.x(),op.y(),op.z()}) {
            if(arr == null)
                continue;
            if(!variables.containsKey(arr)) {
                SDVariable sdVariable = sameDiff.var(UUID.randomUUID().toString(),arr);
                variables.put(arr,sdVariable);
            }
        }

        if(op.x() != null && op.y() != null) {
            SDVariable result = sameDiff.invoke(op, variables.get(op.x()), variables.get(op.y()));
            variables.put(op.z(),result);
        }
        else {
            SDVariable result = sameDiff.invoke(op, variables.get(op.x()));
            variables.put(op.z(),result);

        }

        return op;
    }




    /**
     * This method returns opName of the last invoked op
     *
     * @return
     */
    @Override
    public String getLastOp() {
        return opAtomicReference.get().opName();
    }

    /**
     * Execute the operation
     *
     * @param op the operation to execute
     */
    @Override
    public Op exec(Op op) {
        return processOp(op);
    }

    /**
     * Iterate over every row of every slice
     *
     * @param op the operation to apply
     */
    @Override
    public void iterateOverAllRows(Op op) {
        throw new UnsupportedOperationException();
    }

    /**
     * Iterate over every column of every slice
     *
     * @param op the operation to apply
     */
    @Override
    public void iterateOverAllColumns(Op op) {
        throw new UnsupportedOperationException();
    }

    /**
     * Execute a TransformOp and return the result
     *
     * @param op the operation to execute
     */
    @Override
    public INDArray execAndReturn(TransformOp op) {
        return processOp(op).z();

    }

    /**
     * Execute and return the result from an accumulation
     *
     * @param op the operation to execute
     * @return the accumulated result
     */
    @Override
    public Accumulation execAndReturn(Accumulation op) {
        return (Accumulation) processOp(op).z();
    }

    /**
     * Execute and return the result from an accumulation
     *
     * @param op            the operation to execute
     * @param biasCorrected
     * @return the accumulated result
     */
    @Override
    public Accumulation execAndReturn(Variance op, boolean biasCorrected) {
        return (Accumulation) processOp(op);
    }

    /**
     * Execute and return the result from an index accumulation
     *
     * @param op the index accumulation operation to execute
     * @return the accumulated index
     */
    @Override
    public IndexAccumulation execAndReturn(IndexAccumulation op) {
        return (IndexAccumulation) processOp(op);
    }

    /**
     * Execute and return the result from a scalar op
     *
     * @param op the operation to execute
     * @return the accumulated result
     */
    @Override
    public INDArray execAndReturn(ScalarOp op) {
        return processOp(op).z();
    }

    /**
     * Execute and return the result from a vector op
     *
     * @param op
     */
    @Override
    public INDArray execAndReturn(BroadcastOp op) {
        return processOp(op).z();
    }

    /**
     * Execute and return the result from a vector op
     *
     * @param op
     */
    @Override
    public INDArray execAndReturn(ShapeOp op) {
        return backendExecutioner.execAndReturn(op);
    }

    /**
     * Execute the operation along 1 or more dimensions
     *
     * @param op        the operation to execute
     * @param dimension
     */
    @Override
    public Op exec(Op op, int... dimension) {
        return processOp(op);
    }

    /**
     * Execute an accumulation along one or more dimensions
     *
     * @param accumulation the accumulation
     * @param dimension    the dimension
     * @return the accumulation op
     */
    @Override
    public INDArray exec(Accumulation accumulation, int... dimension) {
        return processOp(accumulation).z();
    }

    /**
     * Execute an broadcast along one or more dimensions
     *
     * @param broadcast the accumulation
     * @param dimension the dimension
     * @return the broadcast op
     */
    @Override
    public INDArray exec(BroadcastOp broadcast, int... dimension) {
        return processOp(broadcast).z();
    }

    /**
     * Execute an accumulation along one or more dimensions
     *
     * @param accumulation  the accumulation
     * @param biasCorrected
     * @param dimension     the dimension  @return the accmulation op
     */
    @Override
    public INDArray exec(Variance accumulation, boolean biasCorrected, int... dimension) {
        return processOp(accumulation).z();
    }

    /**
     * Execute an index accumulation along one or more dimensions
     *
     * @param indexAccum the index accumulation operation
     * @param dimension  the dimension/s to execute along
     * @return result
     */
    @Override
    public INDArray exec(IndexAccumulation indexAccum, int... dimension) {
        return processOp(indexAccum).z();
    }

    /**
     * Execute and return  a result
     * ndarray from the given op
     *
     * @param op the operation to execute
     * @return the result from the operation
     */
    @Override
    public INDArray execAndReturn(Op op) {
        return processOp(op).z();
    }

    /**
     * Get the execution mode for this
     * execuioner
     *
     * @return the execution mode for this executioner
     */
    @Override
    public ExecutionMode executionMode() {
        return backendExecutioner.executionMode();
    }

    /**
     * Set the execution mode
     *
     * @param executionMode the execution mode
     */
    @Override
    public void setExecutionMode(ExecutionMode executionMode) {
        backendExecutioner.setExecutionMode(executionMode);
    }

    /**
     * Execute MetaOp
     *
     * @param op
     */
    @Override
    public void exec(MetaOp op) {

    }

    /**
     * Execute GridOp
     *
     * @param op
     */
    @Override
    public void exec(GridOp op) {

    }

    @Override
    public void exec(Aggregate op) {

    }

    /**
     * @param op
     */
    @Override
    public void exec(ShapeOp op) {
        backendExecutioner.exec(op);
    }

    /**
     * This method executes previously built batch
     *
     * @param batch
     */
    @Override
    public  void exec(Batch batch) {

    }

    /**
     * This method takes arbitrary sized list of aggregates, and packs them into batches
     *
     * @param batch
     */
    @Override
    public void exec(List batch) {

    }

    /**
     * This method executes specified RandomOp using default RNG available via Nd4j.getRandom()
     *
     * @param op
     */
    @Override
    public INDArray exec(RandomOp op) {
        return processOp(op).z();
    }

    /**
     * This method executes specific RandomOp against specified RNG
     *
     * @param op
     * @param rng
     */
    @Override
    public INDArray exec(RandomOp op, Random rng) {
        return processOp(op).z();
    }

    /**
     * This method return set of key/value and
     * key/key/value objects,
     * describing current environment
     *
     * @return
     */
    @Override
    public Properties getEnvironmentInformation() {
        return backendExecutioner.getEnvironmentInformation();
    }

    /**
     * This method specifies desired profiling mode
     *
     * @param mode
     */
    @Override
    public void setProfilingMode(ProfilingMode mode) {
        backendExecutioner.setProfilingMode(mode);
    }

    /**
     * Ths method returns current profiling
     *
     * @return
     */
    @Override
    public ProfilingMode getProfilingMode() {
        return backendExecutioner.getProfilingMode();
    }

    /**
     * This method returns TADManager instance used for this OpExecutioner
     *
     * @return
     */
    @Override
    public TADManager getTADManager() {
        return backendExecutioner.getTADManager();
    }

    /**
     * This method prints out environmental information returned by getEnvironmentInformation() method
     */
    @Override
    public void printEnvironmentInformation() {
        backendExecutioner.printEnvironmentInformation();
    }

    /**
     * This method ensures all operations that supposed to be executed at this moment, are executed.
     */
    @Override
    public void push() {
        backendExecutioner.push();
    }

    /**
     * This method ensures all operations that supposed to be executed at this moment, are executed and finished.
     */
    @Override
    public void commit() {
        backendExecutioner.commit();
    }

    /**
     * This method encodes array as thresholds, updating input array at the same time
     *
     * @param input
     * @param threshold
     * @return encoded array is returned
     */
    @Override
    public INDArray thresholdEncode(INDArray input, double threshold) {
        return backendExecutioner.thresholdEncode(input,threshold);
    }

    /**
     * This method encodes array as thresholds, updating input array at the same time
     *
     * @param input
     * @param threshold
     * @param boundary  @return encoded array is returned
     */
    @Override
    public INDArray thresholdEncode(INDArray input, double threshold, Integer boundary) {
        return backendExecutioner.thresholdEncode(input,threshold,boundary);
    }

    /**
     * This method decodes thresholds array, and puts it into target array
     *
     * @param encoded
     * @param target
     * @return target is returned
     */
    @Override
    public INDArray thresholdDecode(INDArray encoded, INDArray target) {
        return backendExecutioner.thresholdDecode(encoded,target);
    }

    /**
     * This method returns number of elements affected by encoder
     *
     * @param indArray
     * @param target
     * @param threshold
     * @return
     */
    @Override
    public long bitmapEncode(INDArray indArray, INDArray target, double threshold) {
        return backendExecutioner.bitmapEncode(indArray,target,threshold);
    }

    @Override
    public INDArray bitmapEncode(INDArray indArray, double threshold) {
        return backendExecutioner.bitmapEncode(indArray,threshold);
    }

    @Override
    public INDArray bitmapDecode(INDArray encoded, INDArray target) {
        return backendExecutioner.bitmapDecode(encoded,target);
    }

    @Override
    public void invoke(Op op) {
        processOp(op);

    }

    @Override
    public Map getCustomOperations() {
        return backendExecutioner.getCustomOperations();
    }

    /**
     * This method executes given CustomOp
     * 

* PLEASE NOTE: You're responsible for input/output validation * * @param op */ @Override public void exec(CustomOp op) { backendExecutioner.exec(op); } @Override public List calculateOutputShape(CustomOp op) { return backendExecutioner.calculateOutputShape(op); } @Override public void registerGraph(long id, Pointer graph) { backendExecutioner.registerGraph(id, graph); } @Override public Map executeGraph(long id, Map map) { return backendExecutioner.executeGraph(id, map); } @Override public void forgetGraph(long id) { backendExecutioner.forgetGraph(id); } @Override public void enableDebugMode(boolean reallyEnable) { backendExecutioner.enableDebugMode(reallyEnable); } @Override public void enableVerboseMode(boolean reallyEnable) { backendExecutioner.enableVerboseMode(reallyEnable); } /** * This method allows to set desired number of elements per thread, for performance optimization purposes. * I.e. if array contains 2048 elements, and threshold is set to 1024, 2 threads will be used for given op execution. *

* Default value: 1024 * * @param threshold */ @Override public void setElementsThreshold(int threshold) { backendExecutioner.setElementsThreshold(threshold); } /** * This method allows to set desired number of sub-arrays per thread, for performance optimization purposes. * I.e. if matrix has shape of 64 x 128, and threshold is set to 8, each thread will be processing 8 sub-arrays (sure, if you have 8 core cpu). * If your cpu has, say, 4, cores, only 4 threads will be spawned, and each will process 16 sub-arrays *

* Default value: 8 * * @param threshold */ @Override public void setTadThreshold(int threshold) { backendExecutioner.setTadThreshold(threshold); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy