All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.autodiff.samediff.serde.FlatBuffersMapper Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.autodiff.samediff.serde;

import com.google.flatbuffers.FlatBufferBuilder;
import lombok.NonNull;
import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.base.Preconditions;
import org.nd4j.graph.*;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.*;

import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

import java.nio.ByteOrder;
import java.util.*;

public class FlatBuffersMapper {

    /** Holder for static conversion helpers only; not meant to be instantiated. */
    private FlatBuffersMapper() {
        // no instances
    }

    /**
     * Converts an ND4J {@link org.nd4j.linalg.api.buffer.DataType} into the matching
     * FlatBuffers {@code DataType} byte code.
     *
     * @param type ND4J data type to encode; must not be null
     * @return FlatBuffers byte code for {@code type}
     * @throws ND4JIllegalStateException if {@code type} has no FlatBuffers equivalent
     */
    public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
        byte encoded;
        switch (type) {
            // floating point
            case FLOAT:
                encoded = DataType.FLOAT;
                break;
            case DOUBLE:
                encoded = DataType.DOUBLE;
                break;
            case HALF:
                encoded = DataType.HALF;
                break;
            // signed / unsigned integers
            case BYTE:
                encoded = DataType.INT8;
                break;
            case SHORT:
                encoded = DataType.INT16;
                break;
            case INT:
                encoded = DataType.INT32;
                break;
            case LONG:
                encoded = DataType.INT64;
                break;
            case UBYTE:
                encoded = DataType.UINT8;
                break;
            // non-numeric
            case BOOL:
                encoded = DataType.BOOL;
                break;
            case UTF8:
                encoded = DataType.UTF8;
                break;
            default:
                throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
        }
        return encoded;
    }

    /**
     * Decodes a FlatBuffers {@code DataType} byte code into the matching ND4J
     * {@link org.nd4j.linalg.api.buffer.DataType}. This is the inverse of
     * {@code getDataTypeAsByte} for the supported types.
     *
     * @param val FlatBuffers data type code
     * @return the matching ND4J data type
     * @throws RuntimeException if {@code val} is not one of the supported codes
     */
    public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
        org.nd4j.linalg.api.buffer.DataType decoded;
        switch (val) {
            case DataType.FLOAT:
                decoded = org.nd4j.linalg.api.buffer.DataType.FLOAT;
                break;
            case DataType.DOUBLE:
                decoded = org.nd4j.linalg.api.buffer.DataType.DOUBLE;
                break;
            case DataType.HALF:
                decoded = org.nd4j.linalg.api.buffer.DataType.HALF;
                break;
            case DataType.INT32:
                decoded = org.nd4j.linalg.api.buffer.DataType.INT;
                break;
            case DataType.INT64:
                decoded = org.nd4j.linalg.api.buffer.DataType.LONG;
                break;
            case DataType.INT8:
                decoded = org.nd4j.linalg.api.buffer.DataType.BYTE;
                break;
            case DataType.BOOL:
                decoded = org.nd4j.linalg.api.buffer.DataType.BOOL;
                break;
            case DataType.UINT8:
                decoded = org.nd4j.linalg.api.buffer.DataType.UBYTE;
                break;
            case DataType.INT16:
                decoded = org.nd4j.linalg.api.buffer.DataType.SHORT;
                break;
            case DataType.UTF8:
                decoded = org.nd4j.linalg.api.buffer.DataType.UTF8;
                break;
            default:
                throw new RuntimeException("Unknown datatype: " + val);
        }
        return decoded;
    }




    /**
     * Returns the op number stored in FlatBuffers for an op name/type pair.
     * <p>
     * How the number is chosen depends on the op type:
     * <ul>
     *   <li>Control-flow op types use fixed identifiers: LOOP 0, CONDITIONAL 10, IF 30,
     *       RETURN 40, MERGE 60, LOOP_COND 70, NEXT_ITERATION 80, EXIT 90, ENTER 100.</li>
     *   <li>Custom ops use the hash of their registered descriptor. The lookup tries the
     *       lower-cased name first, then the exact name. If neither is registered, {@code 0}
     *       is returned.</li>
     *   <li>Every other op type instantiates the function registered for {@code name} and
     *       returns its {@code opNum()}.</li>
     * </ul>
     *
     * @param name op name; used for custom and legacy op types
     * @param type op type
     * @return fixed control-flow identifier, custom-op hash, or legacy op number
     * @throws RuntimeException if no legacy op can be resolved for {@code name}
     */
    public static long getOpNum(String name, Op.Type type) {
        if (type == Op.Type.LOOP) {
            return 0;
        } else if (type == Op.Type.RETURN) {
            return 40;
        } else if (type == Op.Type.IF) {
            return 30;
        } else if (type == Op.Type.CONDITIONAL) {
            return 10;
        } else if (type == Op.Type.MERGE) {
            return 60L;
        } else if (type == Op.Type.LOOP_COND) {
            return 70L;
        } else if (type == Op.Type.NEXT_ITERATION) {
            return 80L;
        } else if (type == Op.Type.EXIT) {
            return 90L;
        } else if (type == Op.Type.ENTER) {
            return 100L;
        } else if (type == Op.Type.CUSTOM) {
            // Fetch the registry once and reuse it for both lookups.
            val customOps = Nd4j.getExecutioner().getCustomOperations();
            // Locale.ROOT keeps the lower-casing independent of the JVM default locale.
            // For example, a Turkish locale maps 'I' to dotless 'ı', which breaks lookups.
            val byLowerCase = customOps.get(name.toLowerCase(Locale.ROOT));
            if (byLowerCase != null) {
                return byLowerCase.getHash();
            }
            val byExactName = customOps.get(name);
            return byExactName == null ? 0 : byExactName.getHash();
        } else {
            try {
                DifferentialFunction op = DifferentialFunctionClassHolder.getInstance().getInstance(name);
                return op.opNum();
            } catch (Exception e) {
                throw new RuntimeException("Could not find op number for operation: [" + name + "]", e);
            }
        }
    }


    /**
     * Decodes a FlatBuffers {@code OpType} byte into an {@link Op.Type}.
     * <p>
     * This is not an exact inverse of {@code getFlatOpType}. A serialized {@code LOGIC} code
     * decodes to {@link Op.Type#META}, even though several control-flow types are encoded as
     * {@code LOGIC}.
     *
     * @param type FlatBuffers op type code
     * @return the matching op type
     * @throws UnsupportedOperationException if {@code type} is not a recognised code
     */
    public static Op.Type getTypeFromByte(byte type) {
        Op.Type decoded;
        switch (type) {
            // scalar / broadcast / pairwise
            case OpType.SCALAR:
                decoded = Op.Type.SCALAR;
                break;
            case OpType.SCALAR_BOOL:
                decoded = Op.Type.SCALAR_BOOL;
                break;
            case OpType.BROADCAST:
                decoded = Op.Type.BROADCAST;
                break;
            case OpType.BROADCAST_BOOL:
                decoded = Op.Type.BROADCAST_BOOL;
                break;
            case OpType.PAIRWISE:
                decoded = Op.Type.PAIRWISE;
                break;
            case OpType.PAIRWISE_BOOL:
                decoded = Op.Type.PAIRWISE_BOOL;
                break;
            // transforms
            case OpType.TRANSFORM_BOOL:
                decoded = Op.Type.TRANSFORM_BOOL;
                break;
            case OpType.TRANSFORM_FLOAT:
                decoded = Op.Type.TRANSFORM_FLOAT;
                break;
            case OpType.TRANSFORM_SAME:
                decoded = Op.Type.TRANSFORM_SAME;
                break;
            case OpType.TRANSFORM_ANY:
                decoded = Op.Type.TRANSFORM_ANY;
                break;
            case OpType.TRANSFORM_STRICT:
                decoded = Op.Type.TRANSFORM_STRICT;
                break;
            // reductions
            case OpType.REDUCE_BOOL:
                decoded = Op.Type.REDUCE_BOOL;
                break;
            case OpType.REDUCE_LONG:
                decoded = Op.Type.REDUCE_LONG;
                break;
            case OpType.REDUCE_FLOAT:
                decoded = Op.Type.REDUCE_FLOAT;
                break;
            case OpType.REDUCE_SAME:
                decoded = Op.Type.REDUCE_SAME;
                break;
            case OpType.REDUCE_3:
                decoded = Op.Type.REDUCE3;
                break;
            case OpType.INDEX_REDUCE:
                decoded = Op.Type.INDEXREDUCE;
                break;
            case OpType.SUMMARYSTATS:
                decoded = Op.Type.SUMMARYSTATS;
                break;
            // everything else
            case OpType.RANDOM:
                decoded = Op.Type.RANDOM;
                break;
            case OpType.LOGIC:
                decoded = Op.Type.META;
                break;
            case OpType.CUSTOM:
                decoded = Op.Type.CUSTOM;
                break;
            default:
                throw new UnsupportedOperationException("Unknown op type passed in: " + type);
        }
        return decoded;
    }

    /**
     * Encodes an {@link Op.Type} as the FlatBuffers {@code OpType} byte used for serialization.
     * <p>
     * Some op types share one serialized code:
     * <ul>
     *   <li>{@code SPECIAL} is written as {@code TRANSFORM_STRICT}.</li>
     *   <li>{@code VARIANCE} is written as {@code REDUCE_FLOAT}.</li>
     *   <li>All control-flow types (MERGE, CONDITIONAL, LOOP, RETURN, ENTER, EXIT,
     *       NEXT_ITERATION, LOOP_COND, IF) are written as {@code LOGIC}.</li>
     * </ul>
     *
     * @param type op type to encode
     * @return FlatBuffers byte code for {@code type}
     * @throws UnsupportedOperationException if {@code type} has no serialized form
     */
    public static byte getFlatOpType(Op.Type type) {
        byte code;
        switch (type) {
            // scalar / broadcast / pairwise
            case SCALAR:
                code = OpType.SCALAR;
                break;
            case SCALAR_BOOL:
                code = OpType.SCALAR_BOOL;
                break;
            case BROADCAST:
                code = OpType.BROADCAST;
                break;
            case BROADCAST_BOOL:
                code = OpType.BROADCAST_BOOL;
                break;
            case PAIRWISE:
                code = OpType.PAIRWISE;
                break;
            case PAIRWISE_BOOL:
                code = OpType.PAIRWISE_BOOL;
                break;
            // transforms
            case TRANSFORM_BOOL:
                code = OpType.TRANSFORM_BOOL;
                break;
            case TRANSFORM_FLOAT:
                code = OpType.TRANSFORM_FLOAT;
                break;
            case TRANSFORM_SAME:
                code = OpType.TRANSFORM_SAME;
                break;
            case TRANSFORM_ANY:
                code = OpType.TRANSFORM_ANY;
                break;
            case TRANSFORM_STRICT:
            case SPECIAL:
                code = OpType.TRANSFORM_STRICT;
                break;
            // reductions
            case REDUCE_FLOAT:
            case VARIANCE:
                code = OpType.REDUCE_FLOAT;
                break;
            case REDUCE_BOOL:
                code = OpType.REDUCE_BOOL;
                break;
            case REDUCE_SAME:
                code = OpType.REDUCE_SAME;
                break;
            case REDUCE_LONG:
                code = OpType.REDUCE_LONG;
                break;
            case REDUCE3:
                code = OpType.REDUCE_3;
                break;
            case INDEXREDUCE:
                code = OpType.INDEX_REDUCE;
                break;
            case SUMMARYSTATS:
                code = OpType.SUMMARYSTATS;
                break;
            // control flow
            case IF:
            case LOOP_COND:
            case NEXT_ITERATION:
            case EXIT:
            case ENTER:
            case RETURN:
            case LOOP:
            case CONDITIONAL:
            case MERGE:
                code = OpType.LOGIC;
                break;
            // everything else
            case RANDOM:
                code = OpType.RANDOM;
                break;
            case CUSTOM:
                code = OpType.CUSTOM;
                break;
            default:
                throw new UnsupportedOperationException("Unknown op type passed in: " + type);
        }
        return code;
    }


    /**
     * Decodes a libnd4j byte-order code into a {@link java.nio.ByteOrder}.
     * Only {@code org.nd4j.graph.ByteOrder.LE} maps to little-endian; every other value,
     * including unknown codes, maps to big-endian.
     *
     * @param val libnd4j byte-order code
     * @return the corresponding NIO byte order
     */
    public static ByteOrder getOrderFromByte(byte val) {
        boolean little = val == org.nd4j.graph.ByteOrder.LE;
        return little ? ByteOrder.LITTLE_ENDIAN : ByteOrder.BIG_ENDIAN;
    }

    /**
     * Encodes this JVM's native byte order as a libnd4j byte-order code.
     *
     * @return {@code org.nd4j.graph.ByteOrder.BE} on big-endian platforms,
     *         otherwise {@code org.nd4j.graph.ByteOrder.LE}
     */
    public static byte getOrderAsByte() {
        boolean bigEndian = ByteOrder.BIG_ENDIAN.equals(ByteOrder.nativeOrder());
        return bigEndian ? org.nd4j.graph.ByteOrder.BE : org.nd4j.graph.ByteOrder.LE;
    }

    /**
     * Reconstructs a {@link DifferentialFunction} from a serialized {@link FlatNode}.
     * <p>
     * {@link Op.Type#CUSTOM} nodes:
     * <ul>
     *   <li>The function class is resolved from the op hash and op name.</li>
     *   <li>The class is instantiated reflectively and named after the node.</li>
     *   <li>The node's integer, floating-point and boolean arguments are added.</li>
     * </ul>
     * All other (legacy) op types:
     * <ul>
     *   <li>The class is resolved through {@code LegacyOpMapper} from the op type and op
     *       number.</li>
     *   <li>Non-empty extra params are set as extra args.</li>
     *   <li>Scalar ops receive their scalar.</li>
     *   <li>Reduce, reduce3, summary-stats, variance and index-reduce ops receive their
     *       dimensions.</li>
     * </ul>
     * In both cases the node's serialized properties are applied last through
     * {@code setPropertiesForFunction}.
     *
     * @param fn serialized node to convert
     * @return a new function instance for the node
     * @throws RuntimeException if the resolved op class cannot be instantiated
     */
    public static DifferentialFunction fromFlatNode(FlatNode fn){

        // NOTE(review): id and input are not used anywhere in the visible portion of this method.
        int id = fn.id();               //ID of the node
        String name = fn.name();        //Name of the node, NOT the name of the op
        Op.Type opType = FlatBuffersMapper.getTypeFromByte(fn.opType());
        long opNum = fn.opNum();        //Op num: hash for custom, number for legacy
        int[] input = new int[fn.inputLength()];
        // NOTE(review): the next line is corrupted in this copy of the file. Text between a '<'
        // and a later '>' appears to have been stripped during extraction. The missing code must
        // declare extraInteger, extraParams, extraBools, scalar, dimensions, flatProperties and
        // the start of the props declaration, all of which are used below.
        // Restore this line from the upstream source.
        for( int i=0; i props = FlatBuffersMapper.mapFlatPropertiesToFunctionProperties(Arrays.asList(flatProperties));


        if(opType == Op.Type.CUSTOM) {
            String opName = fn.opName();
            // Custom ops are looked up by both the serialized hash and the op name.
            Class c = DifferentialFunctionClassHolder.getInstance().customOpClassForHashAndName(opNum, opName);

            Preconditions.checkNotNull(c, "Could not find class for hash %s", opNum);

            DifferentialFunction op;
            try {
                op = (DifferentialFunction) c.newInstance();
            } catch (IllegalAccessException | InstantiationException e) {
                // NOTE(review): 'e' is not passed as the cause, so the reflective failure's stack trace is lost.
                throw new RuntimeException("Error creating differential function instance of type " + c);
            }
            op.setOwnName(name);

            //Set input SDVariables:

            //Set args:
            //op.addTArgument();
            ((CustomOp) op).addIArgument(extraInteger);
            ((CustomOp) op).addTArgument(extraParams);
            ((CustomOp) op).addBArgument(extraBools);

            op.setPropertiesForFunction(props);
            return op;
        } else {
            // Legacy ops are identified by (op type, op number) rather than by hash.
            Class c = LegacyOpMapper.getLegacyOpClassForId(opType, (int)opNum);
            Op op;
            try {
                op = (Op) c.newInstance();
            } catch (IllegalAccessException | InstantiationException e) {
                // NOTE(review): 'e' is not passed as the cause, so the reflective failure's stack trace is lost.
                throw new RuntimeException("Error creating differential function (Op) instance of type " + c);
            }

            if(extraParams.length > 0) {
                //Assume that extraParams length 0 means extraArgs was originally null, NOT originally length 0
                Object[] extraParamsObj = new Object[extraParams.length];
                for (int i = 0; i < extraParams.length; i++) {
                    extraParamsObj[i] = extraParams[i];
                }
                op.setExtraArgs(extraParamsObj);
            }
            if(opType == Op.Type.SCALAR || opType == Op.Type.SCALAR_BOOL){
                ScalarOp sOp = (ScalarOp)op;
                sOp.setScalar(scalar);
            } else if(opType == Op.Type.REDUCE_FLOAT || opType == Op.Type.REDUCE3 || opType == Op.Type.SUMMARYSTATS || opType == Op.Type.VARIANCE
                    || opType == Op.Type.REDUCE_BOOL || opType == Op.Type.REDUCE_LONG || opType == Op.Type.REDUCE_SAME) {
                // Dimensions are set both as an int[] and as an INDArray form (via Shape.ndArrayDimFromInt).
                val ba = (BaseReduceOp) op; //Reduce3 ops are also all BaseAccumulations
                ba.setDimensions(dimensions);
                ba.setDimensionz(Shape.ndArrayDimFromInt(dimensions));
            } else if(opType == Op.Type.INDEXREDUCE){
                BaseIndexAccumulation bia = (BaseIndexAccumulation)op;
                bia.setDimensions(dimensions);
                bia.setDimensionz(Shape.ndArrayDimFromInt(dimensions));
            }
            /*
            Op types that don't need any extra/special mapping:
            TRANSFORM_BOOL - BooleanNot, IsFinite, IsInf, IsNaN, MatchConditionTransform
            TRANSFORM_ANY - IsMax, Assign
            TRANSFORM_FLOAT - Histogram, Sqrt
            TRANSFORM_STRICT - Cos, Log, Sigmoid, etc
            TRANSFORM_SAME - Abs, Ceil, etc
             */

            ((DifferentialFunction)op).setPropertiesForFunction(props);
            return (DifferentialFunction)op;
        }
    }

    // Shared zero-length primitive arrays (immutable in practice because they have no elements).
    private static final int[] EMPTY_INT = {};
    private static final long[] EMPTY_LONG = {};
    private static final double[] EMPTY_DOUBLE = {};
    private static final boolean[] EMPTY_BOOLEAN = {};

    /**
     * Serializes a function's properties into FlatBuffers.
     * <p>
     * Each entry's key is written as a string. Its value is classified as one of:
     * <ul>
     *   <li>null</li>
     *   <li>Boolean</li>
     *   <li>Number (only Double, Integer and Long are accepted)</li>
     *   <li>String</li>
     *   <li>ND4J DataType (written in its {@code toString()} form)</li>
     *   <li>INDArray</li>
     *   <li>primitive array (boolean, double, int or long)</li>
     *   <li>String[] or INDArray[]</li>
     * </ul>
     * The value is collected into the matching local array ({@code b}, {@code d}, {@code i},
     * {@code l}, {@code sIdx} or {@code aIdx}). Array values also record a shape.
     * NOTE(review): the code that writes these arrays into property tables and returns them
     * is not visible in this copy of the file.
     *
     * @param fbb builder that receives the strings and arrays
     * @param fnProps property name to value map
     * @return an int[] sized to {@code fnProps.size()}. Presumably it holds the offsets of the
     *         created property tables; confirm against the upstream source.
     * @throws UnsupportedOperationException for unsupported Number subtypes or primitive
     *         array component types
     */
    // NOTE(review): the raw types Map / Map.Entry below appear to have lost their generic
    // parameters during extraction. As written, e.getKey() is an Object, which
    // fbb.createString does not accept.
    public static int[] mapFunctionPropertiesToFlatProperties(FlatBufferBuilder fbb, Map fnProps){

        int[] outIdxs = new int[fnProps.size()];
        int count = 0;
        for(Map.Entry e : fnProps.entrySet()){
            //Possible types here: primitives (as Number objects), primitive arrays, Strings, String arrays, multi-dimensional string/primitives
            Object v = e.getValue();
            int iname = fbb.createString(e.getKey());

            int[] i = null;
            long[] l = null;
            double[] d = null;
            int[] aIdx = null;
            boolean[] b = null;
            int[] sIdx = null;
            int[] shape = null;



            if(v == null) {
                //No op
            } else if(v instanceof Boolean){
                b = new boolean[]{(Boolean)v};
            } else if(v instanceof Number) {
                if (v instanceof Double) {
                    d = new double[]{(Double) v};
                } else if (v instanceof Integer) {
                    i = new int[]{(Integer) v};
                } else if (v instanceof Long) {
                    l = new long[]{(Long) v};
                } else {
                    // Any other Number subtype (e.g. Float, Short, Byte) is rejected.
                    throw new UnsupportedOperationException("Unable to map property \"" + e.getKey() + "\" of type " + v.getClass());
                }
            } else if(v instanceof String) {
                String str = (String) v;
                int strOffset = fbb.createString(str);
                sIdx = new int[]{strOffset};
            } else if(v instanceof org.nd4j.linalg.api.buffer.DataType ){
                // Data types are stored as their string form, the same way as String values.
                String str = v.toString();
                int strOffset = fbb.createString(str);
                sIdx = new int[]{strOffset};
            } else if(v instanceof INDArray){
                INDArray arr = (INDArray)v;
                aIdx = new int[]{arr.toFlatArray(fbb)};
            } else if(v.getClass().isArray()){
                if(v.getClass().getComponentType().isPrimitive()){
                    if(v instanceof boolean[]) {
                        b = (boolean[])v;
                        shape = new int[]{b.length};
                    } else if(v instanceof double[]){
                        d = (double[])v;
                        shape = new int[]{d.length};
                    } else if(v instanceof int[]){
                        i = (int[])v;
                        shape = new int[]{i.length};
                    } else if(v instanceof long[]){
                        l = (long[])v;
                        shape = new int[]{l.length};
                    } else {
                        throw new UnsupportedOperationException("Unable to map property \"" + e.getKey() + "\" of type " + v.getClass());
                    }
                } else if (v instanceof String[]) {
                    //String[]
                    String[] strArr = (String[]) v;
                    sIdx = new int[strArr.length];
                    for (int j = 0; j < strArr.length; j++) {
                        sIdx[j] = fbb.createString(strArr[j]);
                    }
                    shape = new int[]{strArr.length};
                } else if (v instanceof INDArray[]){
                    INDArray[] arrArr = (INDArray[])v;
                    aIdx = new int[arrArr.length];
                    // NOTE(review): the next line is corrupted in this copy of the file.
                    // Text between a '<' and a later '>' appears to have been stripped. This
                    // removed the rest of mapFunctionPropertiesToFlatProperties (the INDArray[]
                    // loop, any remaining branches, the property-table creation and the return).
                    // It also removed the header (javadoc, modifiers, return type) of
                    // mapFlatPropertiesToFunctionProperties. Restore from the upstream source.
                    //
                    // mapFlatPropertiesToFunctionProperties(list) rebuilds a name -> value map
                    // from serialized FlatProperties:
                    //  - entries with a non-empty shape are decoded as arrays, reshaped for
                    //    rank 2 or 3;
                    //  - other entries are decoded as a single boolean, int, long, double,
                    //    String or INDArray;
                    //  - entries with no data are stored as null.
                    for( int j=0; j mapFlatPropertiesToFunctionProperties(Iterable list){
        Map out = new HashMap<>();
        for(FlatProperties p : list){

            String name = p.name();
            //Work out type:
            if(p.shapeLength() > 0){
                //Array type
                int[] shape = new int[p.shapeLength()];
                // NOTE(review): the lines below containing "for( int i=0; i" are corrupted in the
                // same way. Each has lost its loop body and the following "else if" header.
                for( int i=0; i 0){
                    int[] iArr = new int[p.iLength()];
                    for( int i=0; i 0){
                    double[] dArr = new double[p.dLength()];
                    for( int i=0; i 0) {
                    long[] lArr = new long[p.lLength()];
                    for (int i = 0; i < lArr.length; i++) {
                        lArr[i] = p.l(i);
                    }
                    // NOTE(review): a shape of rank > 3 matches no branch here, so the property
                    // is silently dropped from the output map.
                    if(shape.length == 0 || shape.length == 1) {
                        out.put(name, lArr);
                    } else if(shape.length == 2){
                        out.put(name, ArrayUtil.reshapeLong(lArr, shape[0], shape[1]));
                    } else if(shape.length == 3){
                        out.put(name, ArrayUtil.reshapeLong(lArr, shape[0], shape[1], shape[2]));
                    }
                } else if(p.bLength() > 0){
                    boolean[] bArr = new boolean[p.bLength()];
                    for( int i=0; i 0){
                    String[] sArr = new String[p.sLength()];
                    for( int i=0; i 0){
                    INDArray[] iArr = new INDArray[p.aLength()];
                    for( int i=0; i 0) {
                    out.put(name, p.b(0));
                } else if(p.iLength() > 0){
                    out.put(name, p.i(0));
                } else if(p.lLength() > 0){
                    out.put(name, p.l(0));
                } else if(p.dLength() > 0){
                    out.put(name, p.d(0));
                } else if(p.sLength() > 0){
                    out.put(name, p.s(0));
                } else if(p.aLength() > 0){
                    FlatArray fa = p.a(0);
                    out.put(name, Nd4j.createFromFlatArray(fa));
                } else {
                    //null property case
                    out.put(name, null);
                }
            }
        }
        return out;
    }

    /**
     * Encodes a SameDiff {@link VariableType} as its FlatBuffers {@code VarType} byte code.
     *
     * @param variableType variable type to encode
     * @return FlatBuffers byte code for {@code variableType}
     * @throws RuntimeException if {@code variableType} has no FlatBuffers equivalent
     */
    public static byte toVarType(VariableType variableType){
        byte encoded;
        switch (variableType) {
            case PLACEHOLDER:
                encoded = VarType.PLACEHOLDER;
                break;
            case ARRAY:
                encoded = VarType.ARRAY;
                break;
            case CONSTANT:
                encoded = VarType.CONSTANT;
                break;
            case VARIABLE:
                encoded = VarType.VARIABLE;
                break;
            default:
                throw new RuntimeException("Unknown variable type: " + variableType);
        }
        return encoded;
    }

    /**
     * Decodes a FlatBuffers {@code VarType} byte code into a SameDiff {@link VariableType}.
     *
     * @param varType FlatBuffers variable type code
     * @return the matching variable type
     * @throws IllegalStateException if {@code varType} is not a recognised code
     */
    public static VariableType fromVarType(byte varType){
        VariableType decoded;
        switch (varType) {
            case VarType.PLACEHOLDER:
                decoded = VariableType.PLACEHOLDER;
                break;
            case VarType.ARRAY:
                decoded = VariableType.ARRAY;
                break;
            case VarType.CONSTANT:
                decoded = VariableType.CONSTANT;
                break;
            case VarType.VARIABLE:
                decoded = VariableType.VARIABLE;
                break;
            default:
                throw new IllegalStateException("Unknown VarType byte value:" + varType);
        }
        return decoded;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy