All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.bytebuddy.shape.ShapeMapper Maven / Gradle / Ivy

package org.nd4j.bytebuddy.shape;

import net.bytebuddy.ByteBuddy;
import net.bytebuddy.dynamic.DynamicType;
import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
import net.bytebuddy.implementation.Implementation;
import net.bytebuddy.implementation.bytecode.StackManipulation;
import net.bytebuddy.implementation.bytecode.constant.IntegerConstant;
import net.bytebuddy.implementation.bytecode.member.MethodReturn;
import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess;
import net.bytebuddy.jar.asm.Label;
import net.bytebuddy.matcher.ElementMatchers;
import org.nd4j.bytebuddy.arithmetic.ByteBuddyIntArithmetic;
import org.nd4j.bytebuddy.arithmetic.stackmanipulation.OpStackManipulation;
import org.nd4j.bytebuddy.arrays.create.stackmanipulation.CreateIntArrayStackManipulation;
import org.nd4j.bytebuddy.arrays.stackmanipulation.ArrayStackManipulation;
import org.nd4j.bytebuddy.branching.stackmanipulation.IfeqNotEquals;
import org.nd4j.bytebuddy.frame.VisitFrameFullInt;
import org.nd4j.bytebuddy.frame.VisitFrameSameInt;
import org.nd4j.bytebuddy.gotoop.GoToOp;
import org.nd4j.bytebuddy.labelvisit.LabelVisitorStackManipulation;
import org.nd4j.bytebuddy.stackmanipulation.StackManipulationImplementation;
import org.nd4j.bytebuddy.storeint.stackmanipulation.StoreIntStackManipulation;
import org.nd4j.bytebuddy.storeref.stackmanipulation.StoreRefStackManipulation;

import java.util.ArrayList;
import java.util.List;

/**
 * Generates, via ByteBuddy, bytecode implementations of {@link IndexMapper} (ind2sub)
 * and {@link OffsetMapper} (linear offset computation) whose per-dimension work is
 * emitted once per dimension for a fixed rank, i.e. without a runtime loop.
 *
 * <p>Static utility class; not instantiable.
 *
 * @author Adam Gibson
 */
public class ShapeMapper {

    /*
     * Local-variable slots of the generated
     * int[] map(int[] shape, int index, int numIndices, char ordering).
     * Slot 0 is "this"; parameters follow in declaration order.
     */
    private static final int IND2SUB_SHAPE_SLOT = 1;
    private static final int IND2SUB_INDEX_SLOT = 2;
    /** numIndices is reused in place as the running denominator. */
    private static final int IND2SUB_DENOM_SLOT = 3;
    /**
     * Slot of the (last) ordering parameter. The generated body never reads it, so the
     * result array reference is stored there instead of in a fresh local.
     */
    private static final int IND2SUB_RET_SLOT = 4;

    /*
     * Local-variable slots of the generated
     * int getOffset(int baseOffset, int[] shape, int[] stride, int[] indices).
     */
    private static final int OFFSET_ACC_SLOT = 1;
    private static final int OFFSET_SHAPE_SLOT = 2;
    private static final int OFFSET_STRIDE_SLOT = 3;
    private static final int OFFSET_INDICES_SLOT = 4;

    private ShapeMapper() {}

    /**
     * Get an ind2sub instance based on the ordering and rank.
     *
     * @param ordering the ordering; {@code 'f'} walks dimensions from last to first,
     *                 any other value walks them from first to last
     * @param rank the rank
     * @return a newly generated and loaded ind2sub instance
     * @throws IllegalArgumentException if {@code rank} is negative
     * @throws IllegalStateException if the generated class cannot be instantiated
     */
    public static IndexMapper getInd2SubInstance(char ordering, int rank) {
        return instantiate(IndexMapper.class, getInd2Sub(ordering, rank), "index mapper", rank);
    }


    /**
     * Get an offset mapper instance for the given rank.
     *
     * @param rank the rank
     * @return a newly generated and loaded offset mapper instance
     * @throws IllegalArgumentException if {@code rank} is negative
     * @throws IllegalStateException if the generated class cannot be instantiated
     */
    public static OffsetMapper getOffsetMapperInstance(int rank) {
        return instantiate(OffsetMapper.class, getOffsetMapper(rank), "offset mapper", rank);
    }

    /**
     * Subclasses {@code type}, routes every method it declares to {@code impl},
     * loads the result with the WRAPPER strategy and instantiates it.
     *
     * @param type the type to subclass; its class loader is the parent loader
     * @param impl the generated method body
     * @param description human-readable name used in the error message
     * @param rank rank used in the error message
     * @return a new instance of the generated subclass
     * @throws IllegalStateException if instantiation fails (cause preserved)
     */
    private static <T> T instantiate(Class<T> type, Implementation impl, String description, int rank) {
        DynamicType.Unloaded<T> unloaded = new ByteBuddy().subclass(type)
                        .method(ElementMatchers.isDeclaredBy(type)).intercept(impl).make();

        Class<? extends T> loaded = unloaded
                        .load(type.getClassLoader(), ClassLoadingStrategy.Default.WRAPPER).getLoaded();
        try {
            return loaded.getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            // keep the reflective failure as the cause so it can be diagnosed
            throw new IllegalStateException("Unable to get " + description + " for rank " + rank, e);
        }
    }

    /**
     * Get the offset mapper bytecode for a particular rank.
     *
     * <p>The generated method has the signature
     * {@code int getOffset(int baseOffset, int[] shape, int[] stride, int[] indices)}.
     * For each dimension {@code i} it adds either {@code stride[i] * indices[i]} or 0
     * (when the {@link IfeqNotEquals} branch on {@code shape[i]} is taken) to
     * {@code baseOffset}, and finally returns the accumulated offset.
     *
     * @param rank the rank of array to generate the offset mapper byte code for
     * @return the implementation of the offset mapper bytecode
     * @throws IllegalArgumentException if {@code rank} is negative
     */
    public static Implementation getOffsetMapper(int rank) {
        checkRank(rank);
        List<StackManipulation> impls = new ArrayList<>();

        for (int i = 0; i < rank; i++) {
            Label skipLabel = new Label();
            Label addLabel = new Label();
            // the running offset stays on the stack as the left operand of the final add
            impls.add(MethodVariableAccess.INTEGER.loadOffset(OFFSET_ACC_SLOT));
            // push shape[i] and branch on it
            // NOTE(review): presumably the branch is taken when shape[i] == 1 — confirm in IfeqNotEquals
            impls.add(MethodVariableAccess.REFERENCE.loadOffset(OFFSET_SHAPE_SLOT));
            impls.add(IntegerConstant.forValue(i));
            impls.add(ArrayStackManipulation.load());
            impls.add(new IfeqNotEquals(skipLabel));

            // fall-through: push stride[i] * indices[i]
            impls.add(MethodVariableAccess.REFERENCE.loadOffset(OFFSET_STRIDE_SLOT));
            impls.add(IntegerConstant.forValue(i));
            impls.add(ArrayStackManipulation.load());
            impls.add(MethodVariableAccess.REFERENCE.loadOffset(OFFSET_INDICES_SLOT));
            impls.add(IntegerConstant.forValue(i));
            impls.add(ArrayStackManipulation.load());
            impls.add(ByteBuddyIntArithmetic.IntegerMultiplication.INSTANCE);
            impls.add(new GoToOp(addLabel));

            // branch target: this dimension must contribute nothing, so push 0
            // (not the loop counter, which would add i to the offset)
            impls.add(new LabelVisitorStackManipulation(skipLabel));
            impls.add(new VisitFrameSameInt(0, 1));
            impls.add(IntegerConstant.forValue(0));

            // join point: offset += <pushed term>
            impls.add(new LabelVisitorStackManipulation(addLabel));
            impls.add(new VisitFrameFullInt(5, 2));
            impls.add(ByteBuddyIntArithmetic.IntegerAddition.INSTANCE);
            impls.add(new StoreIntStackManipulation(OFFSET_ACC_SLOT));
        }

        impls.add(MethodVariableAccess.INTEGER.loadOffset(OFFSET_ACC_SLOT));
        impls.add(MethodReturn.INTEGER);
        return toImplementation(impls);
    }

    /**
     * Get an implementation of ind2sub.
     *
     * <p>The generated method has the signature
     * {@code int[] map(int[] shape, int index, int numIndices, char ordering)} and, for
     * each visited dimension {@code i}, performs
     * {@code denom /= shape[i]; ret[i] = index / denom; index %= denom;} starting from
     * {@code denom = numIndices}, then returns {@code ret}.
     *
     * @param ordering the order to iterate in; {@code 'f'} visits dimensions from
     *                 {@code rank - 1} down to 0, anything else from 0 up to {@code rank - 1}
     * @param rank the rank of the array
     * @return the implementation (in byte code) of ind2sub
     * @throws IllegalArgumentException if {@code rank} is negative
     */
    public static Implementation getInd2Sub(char ordering, int rank) {
        checkRank(rank);
        List<StackManipulation> impls = new ArrayList<>();

        // int[] ret = new int[rank];
        impls.add(IntegerConstant.forValue(rank));
        impls.add(new CreateIntArrayStackManipulation());
        impls.add(new StoreRefStackManipulation(IND2SUB_RET_SLOT));

        if (ordering == 'f') {
            for (int i = rank - 1; i >= 0; i--) {
                addInd2SubStep(impls, i);
            }
        } else {
            for (int i = 0; i < rank; i++) {
                addInd2SubStep(impls, i);
            }
        }

        impls.add(MethodVariableAccess.REFERENCE.loadOffset(IND2SUB_RET_SLOT));
        impls.add(MethodReturn.REFERENCE);
        return toImplementation(impls);
    }

    /** Appends one ind2sub step for dimension {@code i}: denom /= shape[i]; ret[i] = index / denom; index %= denom. */
    private static void addInd2SubStep(List<StackManipulation> impls, int i) {
        // denom /= shape[i]
        impls.add(MethodVariableAccess.INTEGER.loadOffset(IND2SUB_DENOM_SLOT));
        impls.add(MethodVariableAccess.REFERENCE.loadOffset(IND2SUB_SHAPE_SLOT));
        impls.add(IntegerConstant.forValue(i));
        impls.add(ArrayStackManipulation.load());
        impls.add(OpStackManipulation.div());
        impls.add(new StoreIntStackManipulation(IND2SUB_DENOM_SLOT));

        // ret[i] = index / denom
        impls.add(MethodVariableAccess.REFERENCE.loadOffset(IND2SUB_RET_SLOT));
        impls.add(IntegerConstant.forValue(i));
        impls.add(MethodVariableAccess.INTEGER.loadOffset(IND2SUB_INDEX_SLOT));
        impls.add(MethodVariableAccess.INTEGER.loadOffset(IND2SUB_DENOM_SLOT));
        impls.add(OpStackManipulation.div());
        impls.add(ArrayStackManipulation.store());

        // index %= denom
        impls.add(MethodVariableAccess.INTEGER.loadOffset(IND2SUB_INDEX_SLOT));
        impls.add(MethodVariableAccess.INTEGER.loadOffset(IND2SUB_DENOM_SLOT));
        impls.add(OpStackManipulation.mod());
        impls.add(new StoreIntStackManipulation(IND2SUB_INDEX_SLOT));
    }

    /** Wraps the collected stack manipulations as a single ByteBuddy implementation. */
    private static Implementation toImplementation(List<StackManipulation> impls) {
        return new StackManipulationImplementation(
                        new StackManipulation.Compound(impls.toArray(new StackManipulation[0])));
    }

    /** Rejects negative ranks before any bytecode is generated. */
    private static void checkRank(int rank) {
        if (rank < 0) {
            throw new IllegalArgumentException("Rank must be non-negative but was " + rank);
        }
    }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy