All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.api.ops.impl.shape.CreateView Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.nd4j.linalg.api.ops.impl.shape;

import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

public class CreateView extends DynamicCustomOp  {

    public static int POINT_TYPE = 0;
    public static int INTERVAL_TYPE = 1;
    public static int ALL_TYPE = 2;
    public static int NEW_AXIS = 3;

    public static int DEFAULT_INCLUSIVE = 1;

    public CreateView() {
    }

    public CreateView(INDArray[] inputs) {
        super(inputs, null);
    }

    public CreateView(SameDiff sameDiff, SDVariable[] args) {
        super(sameDiff, args);
    }

    public CreateView(SameDiff sd, SDVariable input, SDVariable[] indices) {
        this(sd, ArrayUtil.combine(new SDVariable[]{input},indices));
    }

    public CreateView(INDArray input, INDArray[] indices) {
        this(ArrayUtil.combine(new INDArray[]{input},indices));
    }

    public static SDVariable createInterval(SameDiff sameDiff, SDVariable intervalInputBegin, SDVariable intervalInputEnd, SDVariable intervalStrideInput, SDVariable inclusive) {
        return createInterval(sameDiff,null,intervalInputBegin,intervalInputEnd,intervalStrideInput,inclusive);
    }

    @Override
    public List calculateOutputDataTypes(List dataTypes) {
        //Output type is same as input type
        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public int getNumOutputs() {
        return 1;
    }

    @Override
    public String opName() {
        return "create_view";
    }

    /**
     *  See
     * {@link #createPoint(SameDiff,String,long)}
     * for mroe information.
     * @param sameDiff
     * @param offset
     * @return
     */
    public static SDVariable createPoint(SameDiff sameDiff,long offset) {
        return createPoint(sameDiff,null,offset);
    }

    /**
     *  See
     * {@link #createPoint(SameDiff,String,long)}
     * for mroe information.
     * @param sameDiff
     * @param offset
     * @return
     */
    public static SDVariable createPoint(SameDiff sameDiff,SDVariable offset) {
        return createPoint(sameDiff,null,offset);
    }


    /**
     * Create a {@link SDVariable}
     * representing a point array with the specified name.
     * This is used for specifying the equivalent of a
     * {@link org.nd4j.autodiff.samediff.SDIndex#point(SDVariable)}
     * @param sameDiff the samediff instance fo use
     * @param name the name of the variable (maybe null)
     * @param offset the offset for the point
     * @return the created variable
     */
    public static SDVariable createPoint(SameDiff sameDiff,String name,long offset) {
        INDArray arr = Nd4j.createFromArray(new long[]{POINT_TYPE,1,1,offset, DEFAULT_INCLUSIVE});
        return sameDiff.var(name,arr);
    }


    /**
     * Create a {@link SDVariable}
     * representing a point array with the specified name.
     * This is used for specifying the equivalent of a
     * {@link org.nd4j.autodiff.samediff.SDIndex#point(SDVariable)}
     * @param sameDiff the samediff instance fo use
     * @param name the name of the variable (maybe null)
     * @param offset the offset for the point
     * @return the created variable
     */
    public static SDVariable createPoint(SameDiff sameDiff,String name,SDVariable offset) {
        return sameDiff.concat(name,0,
                sameDiff.constant(POINT_TYPE).reshape(1).castTo(DataType.INT64),
                sameDiff.constant(1).reshape(1).castTo(DataType.INT64),
                sameDiff.constant(1).reshape(1).castTo(DataType.INT64),
                offset.reshape(1).castTo(DataType.INT64),
                sameDiff.constant(DEFAULT_INCLUSIVE).reshape(1).castTo(DataType.INT64));
    }



    /**
     * See {@link #createAll(SameDiff, String)}
     * for more information
     * @param sameDiff
     * @return
     */
    public static SDVariable createAll(SameDiff sameDiff) {
        return createAll(sameDiff,null);
    }
    /**
     * Create an {@link SDVariable}
     * representing an {@link SDIndex#all()}
     * variable.
     * @param sameDiff the samediff instance to use
     * @param name the name of the variable (maybe null)
     * @return the created variable
     */
    public static SDVariable createAll(SameDiff sameDiff,String name) {
        INDArray arr = Nd4j.createFromArray(new long[]{ALL_TYPE,0,1, DEFAULT_INCLUSIVE});
        return sameDiff.var(name,arr);
    }

    /**
     * Create an {@link SDVariable}
     * representing a new axis which creates a new index
     * of length 1 in the specified input
     * @param sameDiff the samediff instance to use
     * @param name the name of the variable
     * @return the created variable
     */
    public static SDVariable createNewAxis(SameDiff sameDiff,String name) {
        INDArray arr = Nd4j.createFromArray(new long[]{NEW_AXIS,1,10, DEFAULT_INCLUSIVE});
        return sameDiff.var(name,arr);
    }

    /**
     * See {@link #createNewAxis(SameDiff,String)}
     * for more information.
     * @param sameDiff
     * @return
     */
    public static SDVariable createNewAxis(SameDiff sameDiff) {
        return createNewAxis(sameDiff,null);
    }

    /**
     * Create an interval representing {@link SDIndex#interval(Long, Long)}
     *
     * @param sameDiff the samediff instance to use
     * @param name the name of the variable
     * @param start the start of the interval
     * @param end the end of the interval
     * @param stride the stride
     * @param inclusive whether the interval is inclusive or not 0 for false 1 for true
     * @return
     */
    public static SDVariable createInterval(SameDiff sameDiff,String name,long start,long end,long stride,long inclusive) {
        INDArray arr = Nd4j.createFromArray(new long[]{INTERVAL_TYPE,2,1,start,end,stride,inclusive});
        return sameDiff.var(name,arr);
    }


    /**
     * Create an interval representing {@link SDIndex#interval(Long, Long)}
     *
     * @param sameDiff the samediff instance to use
     * @param name the name of the variable
     * @param start the start of the interval
     * @param end the end of the interval
     * @param stride the stride
     * @param inclusive whether the interval is inclusive or not 0 for false 1 for true
     * @return
     */
    public static SDVariable createInterval(SameDiff sameDiff,String name,SDVariable start,SDVariable end,SDVariable stride,SDVariable inclusive) {
       if(stride == null)
           stride = sameDiff.constant(1).castTo(DataType.INT64).reshape(1);
       if(inclusive == null)
           inclusive = sameDiff.constant(0).castTo(DataType.INT64).reshape(1);
        return sameDiff.concat(name,0,
                sameDiff.constant(INTERVAL_TYPE).reshape(1).castTo(DataType.INT64),
                sameDiff.constant(2).reshape(1).castTo(DataType.INT64),
                sameDiff.constant(1).reshape(1).castTo(DataType.INT64),
                start.reshape(1).castTo(DataType.INT64)
                ,end.reshape(1).castTo(DataType.INT64),
                stride.reshape(1).castTo(DataType.INT64),
                inclusive.castTo(DataType.INT64).reshape(1));
    }


    /**
     * See {@link #createInterval(SameDiff, String, long, long, long, long)}
     * for more information.
     * @param sameDiff
     * @param start
     * @param end
     * @param stride
     * @param inclusive
     * @return
     */
    public static SDVariable createInterval(SameDiff sameDiff,long start,long end,long stride,long inclusive) {
        return createInterval(sameDiff,null,start,end,stride,inclusive);
    }


    public static INDArray createFrom(INDArray input,INDArray...indices) {
       return input.get(indices(indices));
    }



   public static INDArrayIndex[] indices(INDArray...indexArrs) {
        return Arrays.stream(indexArrs).map(CreateView::fromIndexArr).collect(Collectors.toList())
                .toArray(new INDArrayIndex[indexArrs.length]);
   }
    public static INDArrayIndex fromIndexArr(INDArray index) {
        int idx = index.getInt(0);
        if(idx == POINT_TYPE) {
            int getPoint = index.getInt(3);
            return NDArrayIndex.point(getPoint);
        } else if(idx == INTERVAL_TYPE) {
            int start = index.getInt(3);
            int end = index.getInt(4);
            int stride = index.getInt(5);
            boolean inclusive =  index.getInt(6) > 0;
            return NDArrayIndex.interval(start,stride,end,inclusive);
        } else if(idx == NEW_AXIS) {
            return NDArrayIndex.newAxis();
        } else if(idx == ALL_TYPE) {
            return NDArrayIndex.all();
        } else {
            throw new IllegalArgumentException("Invalid type. Must be 1 of: " + POINT_TYPE + " (point type) " + INTERVAL_TYPE + " (interval type)" + NEW_AXIS + " (new axis) " + ALL_TYPE + " (all) ");
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy