/*-
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 *
 */

package org.nd4j.linalg.api.ops.executioner;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/**
 * Basic op executioner. Knows how to iterate over the buffers of each
 * respective ndarray and apply transformations.
 *
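 * <p>
 * Concrete executioners are normally obtained via {@code Nd4j.getExecutioner()} rather than constructed
 * directly. A minimal usage sketch (the transform op chosen here is purely illustrative):
 * <pre>{@code
 * INDArray x = Nd4j.linspace(1, 4, 4);
 * INDArray expX = Nd4j.getExecutioner().execAndReturn(
 *         new org.nd4j.linalg.api.ops.impl.transforms.Exp(x));
 * }</pre>
 *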
 * @author Adam Gibson
 */
@Slf4j
public class DefaultOpExecutioner implements OpExecutioner {

    private static final String SCOPE_PANIC_MSG = "For more details, see the ND4J User Guide: nd4j.org/userguide#workspaces-panic";

    protected ProfilingMode profilingMode = ProfilingMode.SCOPE_PANIC;
    protected ExecutionMode executionMode = ExecutionMode.JAVA;

    public DefaultOpExecutioner() {}

    protected void checkForCompression(Op op) {
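        // Any compressed X/Y/Z operand is decompressed in place before execution; INT operands are rejected first.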
        // check for INT datatype arrays
        interceptIntDataType(op);

        if (op.x() != null && op.x().isCompressed())
            Nd4j.getCompressor().decompressi(op.x());

        if (op.y() != null && op.y().isCompressed())
            Nd4j.getCompressor().decompressi(op.y());

        if (op.z() != null && op.z().isCompressed())
            Nd4j.getCompressor().decompressi(op.z());
    }

    @Override
    public String getLastOp() {
        return "UNKNOWN";
    }

    /**
     * This method checks whether any Op operand has an INT data type, and throws an exception if so.
     *
     * @param op the op whose operands are checked
     */
    protected void interceptIntDataType(Op op) {
        // FIXME: Remove this method once operations on the INT dtype are supported

        if (op.x() != null && op.x().data().dataType() == DataBuffer.Type.INT)
            throw new ND4JIllegalStateException(
                            "Op.X contains INT data. Operations on INT dataType are not supported yet");

        if (op.z() != null && op.z().data().dataType() == DataBuffer.Type.INT)
            throw new ND4JIllegalStateException(
                            "Op.Z contains INT data. Operations on INT dataType are not supported yet");

        if (op.y() != null && op.y().data().dataType() == DataBuffer.Type.INT)
            throw new ND4JIllegalStateException(
                            "Op.Y contains INT data. Operations on INT dataType are not supported yet.");
    }

    @Override
    public Op exec(Op op) {
        if (op.isPassThrough()) {
            op.exec();
            return op;
        }

        throw new IllegalStateException("Java computation no longer supported");
    }

    @Override
    public INDArray execAndReturn(Op op) {
        if (op instanceof TransformOp) {
            return execAndReturn((TransformOp) op);
        }
        if (op instanceof ScalarOp) {
            return execAndReturn((ScalarOp) op);
        }
        if (op instanceof Accumulation) {
            return Nd4j.scalar(execAndReturn((Accumulation) op).getFinalResult());
        }
        if (op instanceof IndexAccumulation) {
            return Nd4j.scalar(execAndReturn((IndexAccumulation) op).getFinalResult());
        }

        throw new IllegalArgumentException("Illegal opType of op: " + op.getClass());
    }

    @Override
    public void iterateOverAllRows(Op op) {
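        // Vectors are executed directly; matrices are processed one row at a time; higher-rank arrays
        // are decomposed slice by slice via a recursive call until a matrix or vector is reached.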
        //column and row vectors should be treated the same
        if (op.x().isVector()) {
            //reset the op's arguments in case they were changed by a previous call
            op.setX(op.x());
            if (op.y() != null)
                op.setY(op.y());
            op.setZ(op.z());
            exec(op);
        }
        //execute row wise
        else if (op.x().isMatrix()) {
            if (op.x() instanceof IComplexNDArray) {
                IComplexNDArray original = (IComplexNDArray) op.x();
                IComplexNDArray originalZ = (IComplexNDArray) op.z();
                IComplexNDArray y = (IComplexNDArray) op.y();

                for (int i = 0; i < original.rows(); i++) {
                    IComplexNDArray row = original.slice(i);
                    IComplexNDArray zRow = originalZ.slice(i);
                    op.setX(row.dup());
                    op.setZ(zRow.dup());
                    if (y != null)
                        op.setY(y.slice(i));
                    exec(op);
                    originalZ.slice(i).assign(op.z());

                }
            } else {
                INDArray original = op.x();
                INDArray originalZ = op.z();
                INDArray y = op.y();

                for (int i = 0; i < original.rows(); i++) {
                    INDArray row = original.getRow(i);
                    INDArray zRow = originalZ.getRow(i);
                    op.setX(row.dup());
                    op.setZ(zRow.dup());
                    if (y != null)
                        op.setY(y.getRow(i).dup());
                    exec(op);
                    zRow.assign(op.z());
                }
            }
        } else {
            INDArray originalX = op.x();
            INDArray originalZ = op.z();
            for (int i = 0; i < originalX.slices(); i++) {
                INDArray slice = originalX.slice(i);
                INDArray zSlice = originalZ.slice(i);
                op.setX(slice);
                op.setZ(zSlice);
                iterateOverAllRows(op);
            }
        }
    }

    @Override
    public void iterateOverAllColumns(Op op) {
        if (op.x().isVector()) {
            exec(op);
        }
        //execute row wise
        else if (op.x().isMatrix() || op.x().isColumnVector()) {
            exec(op, 1);
        } else {
            if (op.x() instanceof IComplexNDArray) {
                IComplexNDArray originalX = (IComplexNDArray) op.x();
                IComplexNDArray originalZ = (IComplexNDArray) op.z();
                IComplexNDArray y = (IComplexNDArray) op.y();
                for (int i = 0; i < op.x().slices(); i++) {
                    op.setX(originalX.getColumn(i));
                    op.setZ(originalZ.getColumn(i));
                    if (y != null)
                        op.setY(y.getColumn(i));
                    iterateOverAllColumns(op);
                }
            } else {
                INDArray originalX = op.x();
                INDArray originalZ = op.z();
                INDArray y = op.y();
                for (int i = 0; i < op.x().slices(); i++) {
                    op.setX(originalX.getColumn(i));
                    op.setZ(originalZ.getColumn(i));
                    if (y != null)
                        op.setY(y.getColumn(i));
                    iterateOverAllColumns(op);
                }
            }
        }
    }


    @Override
    public INDArray execAndReturn(TransformOp op) {
        Op result = exec(op);
        TransformOp t = (TransformOp) result;
        return t.z();
    }


    @Override
    public Accumulation execAndReturn(Accumulation op) {
        return (Accumulation) exec(op);
    }

    @Override
    public Accumulation execAndReturn(Variance op, boolean biasCorrected) {
        return null;
    }

    @Override
    public INDArray execAndReturn(ScalarOp op) {
        return exec(op).z();
    }

    @Override
    public IndexAccumulation execAndReturn(IndexAccumulation op) {
        return (IndexAccumulation) exec(op);
    }

    @Override
    public INDArray execAndReturn(BroadcastOp op) {
        return exec(op).z();
    }

    /**
     * Execute and return the result from a shape op
     *
     * @param op the shape op to execute
     */
    @Override
    public INDArray execAndReturn(ShapeOp op) {
        exec(op);
        return op.z();
    }

    @Override
    public Op exec(Op op, int... dimension) {
        //do op along all dimensions
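        // Integer.MAX_VALUE is ND4J's sentinel dimension meaning "apply along the whole array".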
        if (dimension.length == op.x().rank()) {
            dimension = new int[] {Integer.MAX_VALUE};
        }

        if (op.isPassThrough()) {
            op.exec(dimension);
            return op;
        }

        if (op instanceof Accumulation || op instanceof IndexAccumulation) {
            //Overloaded exec(Accumulation,int...) and exec(IndexAccumulation,int...) should always be called instead of this
            throw new IllegalStateException(
                            "exec(Op,int...) should never be invoked for Accumulation/IndexAccumulation");
        }
        if (op instanceof ScalarOp) {
            //Scalar op along dimension should be same as on the entire NDArray
            throw new IllegalStateException("Java computation no longer supported");
        }
        if (op instanceof TransformOp) {
            throw new UnsupportedOperationException(
                            "Executing transform ops along a dimension should be done via exec special");
        }
        throw new UnsupportedOperationException("Unknown op opType");
    }

    @Override
    public INDArray exec(Accumulation op, int... dimension) {

        throw new UnsupportedOperationException("Java computation no longer supported");
    }

    @Override
    public INDArray exec(Variance accumulation, boolean biasCorrected, int... dimension) {
        accumulation.setBiasCorrected(biasCorrected);
        return exec(accumulation, dimension);
    }

    @Override
    public INDArray exec(IndexAccumulation op, int... dimension) {
        throw new UnsupportedOperationException("Operation should use exec special");

    }

    @Override
    public ExecutionMode executionMode() {
        return executionMode;
    }

    @Override
    public void setExecutionMode(ExecutionMode executionMode) {
        this.executionMode = executionMode;
    }



    @Override
    public INDArray exec(BroadcastOp broadcast, int... dimension) {
        if (dimension.length == broadcast.x().rank()) {
            dimension = new int[] {Integer.MAX_VALUE};
        }

        if (broadcast.isPassThrough()) {
            broadcast.exec(dimension);
            return broadcast.z();
        }

        throw new IllegalStateException("Java computation no longer supported");

    }

    @Override
    public void exec(MetaOp op) {
        throw new UnsupportedOperationException("MetaOp execution isn't supported for this OpExecutioner yet");
    }

    @Override
    public void exec(GridOp op) {
        throw new UnsupportedOperationException("GridOp execution isn't supported for this OpExecutioner yet");
    }

    @Override
    public void exec(Batch batch) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void exec(Aggregate op) {
        throw new UnsupportedOperationException();
    }

    /**
     * @param op the shape op to execute; only special execution is supported
     */
    @Override
    public void exec(ShapeOp op) {
        if(!op.isExecSpecial()) {
            throw new IllegalArgumentException("Only special execution supported right now.");
        }

        op.exec();
    }

    @Override
    public void exec(List batch) {
        throw new UnsupportedOperationException();
    }

    /**
     * This method executes the specified RandomOp using the default RNG available via Nd4j.getRandom()
     *
     * @param op the random op to execute
     */
    @Override
    public INDArray exec(RandomOp op) {
        return exec(op, Nd4j.getRandom());
    }

    /**
     * This method executes the specified RandomOp against the specified RNG
     *
     * @param op  the random op to execute
     * @param rng the random number generator to use
     */
    @Override
    public INDArray exec(RandomOp op, Random rng) {
        throw new UnsupportedOperationException();
    }


    @Override
    public void setProfilingMode(ProfilingMode mode) {
        profilingMode = mode;
    }

    @Override
    public ProfilingMode getProfilingMode() {
        return profilingMode;
    }

    public long profilingHookIn(Op op, DataBuffer... tadBuffers) {
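        // Returns a start timestamp when profiling is enabled (0 otherwise). Backend executioners typically
        // pair it with profilingHookOut, e.g. (sketch):
        //   long start = profilingHookIn(op, tadBuffers);
        //   ... dispatch the op natively ...
        //   profilingHookOut(op, start);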
        switch (profilingMode) {
            case ALL:
                OpProfiler.getInstance().processOpCall(op, tadBuffers);
                break;
            case METHODS:
                break;
            case OPERATIONS:
                OpProfiler.getInstance().processOpCall(op, tadBuffers);
                break;
            case DISABLED:
            default:
                return 0L;
        }

        return System.nanoTime();
    }

    protected void checkWorkspace(String opName, INDArray array) {
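        // Scope-panic check: an array attached to a non-circular workspace must belong to a workspace that is
        // still open and to that workspace's current generation, otherwise its memory may already have been reused.
        // This validation is triggered when ProfilingMode.SCOPE_PANIC is active (see profilingHookIn).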
        if (array.isAttached()) {
            val ws = array.data().getParentWorkspace();

            if (ws.getWorkspaceType() != MemoryWorkspace.Type.CIRCULAR) {

                if (!ws.isScopeActive()) {
                    throw new ND4JIllegalStateException("Op [" + opName + "] X argument uses leaked workspace pointer from workspace ["
                            + ws.getId() + "]\nAll open workspaces: " + allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
                }

                if (ws.getGenerationId() != array.data().getGenerationId())
                    throw new ND4JIllegalStateException("Op [" + opName + "] X argument uses outdated workspace pointer from workspace ["
                            + ws.getId() + "]\nAll open workspaces: " + allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
            }
        }
    }

    protected void checkForWorkspaces(CustomOp op) {
        for (val input: op.inputArguments())
            checkWorkspace(op.opName(), input);

        for (val output: op.outputArguments())
            checkWorkspace(op.opName(), output);
    }

    protected void checkForWorkspaces(Op op) {
        val x = op.x();
        if (x != null)
            checkWorkspace(op.opName(), x);

        val y = op.y();
        if (y != null)
            checkWorkspace(op.opName(), y);

        val z = op.z();
        if (z != null)
            checkWorkspace(op.opName(), z);
    }

    private static List<String> allOpenWorkspaces() {
        List<MemoryWorkspace> l = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
        List<String> workspaces = new ArrayList<>(l.size());
        for (MemoryWorkspace ws : l) {
            if (ws.isScopeActive()) {
                workspaces.add(ws.getId());
            }
        }
        return workspaces;
    }

    public long profilingHookIn(Op op) {
        switch (profilingMode) {
            case ALL:
                OpProfiler.getInstance().processOpCall(op);
                break;
            case METHODS:
                break;
            case OPERATIONS:
                OpProfiler.getInstance().processOpCall(op);
                break;
            case SCOPE_PANIC:
                checkForWorkspaces(op);
                return 0L;
            case DISABLED:
            default:
                return 0L;
        }

        return System.nanoTime();
    }

    public long profilingHookIn(CustomOp op) {
        switch (profilingMode) {
            case ALL:
                OpProfiler.getInstance().processOpCall(op);
                break;
            case METHODS:
                break;
            case OPERATIONS:
                OpProfiler.getInstance().processOpCall(op);
                break;
            case SCOPE_PANIC:
                checkForWorkspaces(op);
                return 0L;
            case DISABLED:
            default:
                return 0L;
        }

        return System.nanoTime();
    }

    public void profilingHookOut(Op op, long timeStart) {
        switch (profilingMode) {
            case ALL:
                OpProfiler.getInstance().processStackCall(op, timeStart);
                OpProfiler.getInstance().timeOpCall(op, timeStart);
                break;
            case METHODS:
                OpProfiler.getInstance().processStackCall(op, timeStart);
                break;
            case OPERATIONS:
                OpProfiler.getInstance().timeOpCall(op, timeStart);
                break;
            case NAN_PANIC:
                OpExecutionerUtil.checkForNaN(op);
                break;
            case INF_PANIC:
                OpExecutionerUtil.checkForInf(op);
                break;
            case ANY_PANIC:
                OpExecutionerUtil.checkForNaN(op);
                OpExecutionerUtil.checkForInf(op);
                break;
            case DISABLED:
            default:
                break;
        }
    }


    public void profilingHookOut(CustomOp op, long timeStart) {
        switch (profilingMode) {
            case ALL:
                OpProfiler.getInstance().processStackCall(op, timeStart);
                OpProfiler.getInstance().timeOpCall(op, timeStart);
                break;
            case METHODS:
                OpProfiler.getInstance().processStackCall(op, timeStart);
                break;
            case OPERATIONS:
                OpProfiler.getInstance().timeOpCall(op, timeStart);
                break;
            case NAN_PANIC:
                OpExecutionerUtil.checkForNaN(op);
                break;
            case INF_PANIC:
                OpExecutionerUtil.checkForInf(op);
                break;
            case ANY_PANIC:
                OpExecutionerUtil.checkForNaN(op);
                OpExecutionerUtil.checkForInf(op);
                break;
            case DISABLED:
            default:
                break;
        }
    }


    /**
     * Validate the data types for the given operation.
     *
     * @param expectedType the expected data type of the op's operands
     * @param op           the op whose operands are validated
     */
    public static void validateDataType(DataBuffer.Type expectedType, Op op) {
        if (op.x() != null && op.x().data().dataType() == DataBuffer.Type.COMPRESSED) {
            Nd4j.getCompressor().decompressi(op.x());
        }

        if (op.y() != null && op.y().data().dataType() == DataBuffer.Type.COMPRESSED) {
            Nd4j.getCompressor().decompressi(op.y());
        }

        if (op.z() != null && op.z().data().dataType() == DataBuffer.Type.COMPRESSED) {
            Nd4j.getCompressor().decompressi(op.z());
        }


        if (op.x() != null && op.x().data().dataType() != expectedType
                        && op.x().data().dataType() != DataBuffer.Type.COMPRESSED)
            throw new ND4JIllegalStateException("op.X dataType is [" + op.x().data().dataType()
                            + "] instead of expected [" + expectedType + "]");

        if (op.z() != null && op.z().data().dataType() != expectedType
                        && op.z().data().dataType() != DataBuffer.Type.COMPRESSED)
            throw new ND4JIllegalStateException("op.Z dataType is [" + op.z().data().dataType()
                            + "] instead of expected [" + expectedType + "]");

        if (op.y() != null && op.y().data().dataType() != expectedType)
            throw new ND4JIllegalStateException("op.Y dataType is [" + op.y().data().dataType()
                            + "] instead of expected [" + expectedType + "]");


    }

    public static void validateDataType(DataBuffer.Type expectedType, INDArray... operands) {
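        // Usage sketch (operand names are illustrative):
        //   validateDataType(DataBuffer.Type.FLOAT, x, y, result);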
        if (operands == null || operands.length == 0)
            return;

        int cnt = 0;
        for (INDArray operand : operands) {
            if (operand == null)
                continue;

            if (operand.data().dataType() != expectedType)
                throw new ND4JIllegalStateException("INDArray [" + cnt++ + "] dataType is [" + operand.data().dataType()
                                + "] instead of expected [" + expectedType + "]");
        }
    }

    @Override
    public TADManager getTADManager() {
        throw new UnsupportedOperationException();
    }

    /**
     * This method returns a set of key/value (and key/key/value) properties describing the current environment
     *
     * @return properties describing the current environment
     */
    @Override
    public Properties getEnvironmentInformation() {
        Properties environment = new Properties();
        environment.put(Nd4jEnvironment.CPU_CORES_KEY, Runtime.getRuntime().availableProcessors());
        environment.put(Nd4jEnvironment.HOST_TOTAL_MEMORY_KEY, Runtime.getRuntime().maxMemory());
        environment.put(Nd4jEnvironment.OS_KEY, System.getProperty("os.name"));
        return environment;
    }

    @Override
    public void printEnvironmentInformation() {
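        // The base getEnvironmentInformation() above does not set the "backend" or "blas.vendor" entries;
        // backend-specific executioners are expected to populate those before this method logs them.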
        Properties env = getEnvironmentInformation();
        double memory = ((Long) env.get("memory.available")) / (double) 1024 / 1024 / 1024;
        String fm = String.format("%.1f", memory);
        log.info("Backend used: [{}]; OS: [{}]", env.get("backend"), env.get("os"));
        log.info("Cores: [{}]; Memory: [{}GB];", env.get("cores"), fm);
        log.info("Blas vendor: [{}]", env.get("blas.vendor"));
    }

    @Override
    public void push() {
        // no-op
    }

    @Override
    public void commit() {
        // no-op
    }


    @Override
    public INDArray thresholdEncode(INDArray input, double threshold) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public INDArray thresholdEncode(INDArray input, double threshold, Integer boundary) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public INDArray thresholdDecode(INDArray encoded, INDArray target) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public long bitmapEncode(INDArray indArray, INDArray target, double threshold) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public INDArray bitmapEncode(INDArray indArray, double threshold) {
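        // The int buffer is sized for 16 encoded elements per int plus a few header ints (hence length / 16 + 5);
        // the actual encoding is delegated to bitmapEncode(INDArray, INDArray, double), which backends implement.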
        DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(indArray.length() / 16 + 5);

        INDArray ret = Nd4j.createArrayFromShapeBuffer(buffer, indArray.shapeInfoDataBuffer());

        bitmapEncode(indArray, ret, threshold);

        return ret;
    }

    @Override
    public INDArray bitmapDecode(INDArray encoded, INDArray target) {
        throw new UnsupportedOperationException("Not yet implemented");
    }


    @Override
    public Map getCustomOperations() {
        throw new UnsupportedOperationException();
    }

    @Override
    public void exec(CustomOp op) {
        throw new UnsupportedOperationException();
    }

    @Override
    public List calculateOutputShape(CustomOp op) {
        throw new UnsupportedOperationException();
    }


    @Override
    public void enableDebugMode(boolean reallyEnable) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void enableVerboseMode(boolean reallyEnable) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void registerGraph(long id, Pointer graph) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public Map executeGraph(long id, Map map) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public void forgetGraph(long id) {
        throw new UnsupportedOperationException("Not yet implemented");
    }


    /**
     * This method allows you to set the desired number of elements per thread, for performance optimization purposes.
     * I.e. if an array contains 2048 elements and the threshold is set to 1024, 2 threads will be used for the given op execution.
     * <p>
     * Default value: 1024
     *
     * @param threshold the desired number of elements per thread
     */
    @Override
    public void setElementsThreshold(int threshold) {
        // no-op
    }

    /**
     * This method allows you to set the desired number of sub-arrays per thread, for performance optimization purposes.
     * I.e. if a matrix has a shape of 64 x 128 and the threshold is set to 8, each thread will process 8 sub-arrays
     * (assuming an 8-core CPU). If your CPU has, say, 4 cores, only 4 threads will be spawned, and each will process
     * 16 sub-arrays.
     * <p>
     * Default value: 8
     *
     * @param threshold the desired number of sub-arrays per thread
     */
    @Override
    public void setTadThreshold(int threshold) {
        // no-op
    }
}




