All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.ndarray.index.full.NDIndexFullSlice Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.ndarray.index.full;

import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexFixed;
import ai.djl.ndarray.index.dim.NDIndexSlice;
import ai.djl.ndarray.types.Shape;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/**
 * An index as a slice on all dimensions where some dimensions can be squeezed.
 *
 * <p>Every axis of the target shape gets a {@code min}, {@code max} and {@code step} entry. Axes
 * indexed by a fixed value are sliced to length one and recorded as axes to squeeze.
 */
public final class NDIndexFullSlice {
    private final long[] min;
    private final long[] max;
    private final long[] step;
    private final int[] toSqueeze;
    private final Shape shape;
    private final Shape squeezedShape;

    /**
     * Constructs a {@link NDIndexFullSlice}.
     *
     * @param min the min for each axis
     * @param max the max for each axis
     * @param step the step for each axis
     * @param toSqueeze the axes to squeeze after slicing
     * @param shape the result shape (without squeezing)
     * @param squeezedShape the result shape (with squeezing)
     */
    private NDIndexFullSlice(
            long[] min,
            long[] max,
            long[] step,
            int[] toSqueeze,
            Shape shape,
            Shape squeezedShape) {
        this.min = min;
        this.max = max;
        this.step = step;
        this.toSqueeze = toSqueeze;
        this.shape = shape;
        this.squeezedShape = squeezedShape;
    }

    /**
     * Returns (if possible) the {@link NDIndexFullSlice} representation of an {@link NDIndex}.
     *
     * <p>Only indices made up exclusively of {@link NDIndexAll}, {@link NDIndexFixed} and {@link
     * NDIndexSlice} elements can be represented. Axes of the target that the index does not
     * mention are selected in full; they are inserted at the position of the ellipsis, or at the
     * end when the index has no ellipsis.
     *
     * @param index the index to represent
     * @param target the shape of the array to index
     * @return the full slice representation or nothing if it can't represent the index
     * @throws IllegalArgumentException if the index has more dimensions than the target shape
     */
    public static Optional<NDIndexFullSlice> fromIndex(NDIndex index, Shape target) {
        if (!index.stream()
                .allMatch(
                        ie ->
                                ie instanceof NDIndexAll
                                        || ie instanceof NDIndexFixed
                                        || ie instanceof NDIndexSlice)) {
            return Optional.empty();
        }
        int ellipsisIndex = index.getEllipsisIndex();
        int indDimensions = index.getRank();
        int targetDimensions = target.dimension();
        if (indDimensions > targetDimensions) {
            throw new IllegalArgumentException(
                    "The index has too many dimensions - "
                            + indDimensions
                            + " dimensions for array with "
                            + targetDimensions
                            + " dimensions");
        }
        long[] min = new long[targetDimensions];
        long[] max = new long[targetDimensions];
        long[] step = new long[targetDimensions];
        List<Integer> toSqueeze = new ArrayList<>(targetDimensions);
        long[] shape = new long[targetDimensions];
        List<Long> squeezedShape = new ArrayList<>(targetDimensions);

        // Number of target axes that have no explicit index element and are selected in full.
        int paddingDim = targetDimensions - indDimensions;
        // Number of explicit index elements that come before the padded axes. An ellipsis index
        // of -1 means "no ellipsis", in which case the padding goes after every explicit element
        // (the same layout as a trailing ellipsis).
        int leading = ellipsisIndex == -1 ? indDimensions : ellipsisIndex;

        int axis = 0;
        // explicit elements before the ellipsis
        for (; axis < leading; ++axis) {
            addSliceInfo(
                    index.get(axis), axis, target, min, max, step, toSqueeze, shape, squeezedShape);
        }
        // axes covered by the ellipsis (or by the implicit trailing padding)
        for (; axis < leading + paddingDim; ++axis) {
            padIndexAll(axis, target, min, max, step, shape, squeezedShape);
        }
        // explicit elements after the ellipsis; shift back by the padding to find the element
        for (; axis < targetDimensions; ++axis) {
            addSliceInfo(
                    index.get(axis - paddingDim),
                    axis,
                    target,
                    min,
                    max,
                    step,
                    toSqueeze,
                    shape,
                    squeezedShape);
        }

        int[] squeeze = toSqueeze.stream().mapToInt(Integer::intValue).toArray();
        NDIndexFullSlice fullSlice =
                new NDIndexFullSlice(
                        min, max, step, squeeze, new Shape(shape), new Shape(squeezedShape));
        return Optional.of(fullSlice);
    }

    /**
     * Fills in the slice information of a single axis from its index element.
     *
     * <p>A fixed index becomes a length-one slice and the axis is added to {@code toSqueeze}. A
     * slice index uses 0, the axis size and 1 as defaults for a missing min, max and step. Negative
     * fixed indices and negative slice bounds are wrapped with {@link Math#floorMod(long, long)}
     * against the axis size.
     *
     * @param ie the index element for the axis
     * @param i the axis in the target shape
     * @param target the shape of the array to index
     * @param min the per-axis min array to write to
     * @param max the per-axis max array to write to
     * @param step the per-axis step array to write to
     * @param toSqueeze the list collecting the axes to squeeze
     * @param shape the per-axis unsqueezed result shape to write to
     * @param squeezedShape the list collecting the squeezed result shape
     */
    private static void addSliceInfo(
            NDIndexElement ie,
            int i,
            Shape target,
            long[] min,
            long[] max,
            long[] step,
            List<Integer> toSqueeze,
            long[] shape,
            List<Long> squeezedShape) {
        if (ie instanceof NDIndexFixed) {
            NDIndexFixed fixed = ((NDIndexFixed) ie);
            long rawIndex = fixed.getIndex();
            min[i] = rawIndex < 0 ? Math.floorMod(rawIndex, target.get(i)) : rawIndex;
            max[i] = min[i] + 1;
            step[i] = 1;
            // the axis is kept with length one in shape but dropped from squeezedShape
            toSqueeze.add(i);
            shape[i] = 1;
        } else if (ie instanceof NDIndexSlice) {
            NDIndexSlice slice = (NDIndexSlice) ie;
            long rawMin = Optional.ofNullable(slice.getMin()).orElse(0L);
            min[i] = rawMin < 0 ? Math.floorMod(rawMin, target.get(i)) : rawMin;
            long rawMax = Optional.ofNullable(slice.getMax()).orElse(target.size(i));
            max[i] = rawMax < 0 ? Math.floorMod(rawMax, target.get(i)) : rawMax;
            step[i] = Optional.ofNullable(slice.getStep()).orElse(1L);
            // number of elements visited when walking from min to max in increments of step
            shape[i] = (long) Math.ceil(((double) (max[i] - min[i])) / step[i]);
            squeezedShape.add(shape[i]);
        } else if (ie instanceof NDIndexAll) {
            padIndexAll(i, target, min, max, step, shape, squeezedShape);
        }
    }

    /**
     * Fills in the slice information of a single axis so that the whole axis is selected.
     *
     * @param i the axis in the target shape
     * @param target the shape of the array to index
     * @param min the per-axis min array to write to
     * @param max the per-axis max array to write to
     * @param step the per-axis step array to write to
     * @param shape the per-axis unsqueezed result shape to write to
     * @param squeezedShape the list collecting the squeezed result shape
     */
    private static void padIndexAll(
            int i,
            Shape target,
            long[] min,
            long[] max,
            long[] step,
            long[] shape,
            List<Long> squeezedShape) {
        long axisSize = target.size(i);
        min[i] = 0;
        max[i] = axisSize;
        step[i] = 1;
        shape[i] = axisSize;
        squeezedShape.add(axisSize);
    }

    /**
     * Returns the slice min for each axis.
     *
     * <p>The returned array is the internal array, not a copy.
     *
     * @return the slice min for each axis
     */
    public long[] getMin() {
        return min;
    }

    /**
     * Returns the slice max for each axis.
     *
     * <p>The returned array is the internal array, not a copy.
     *
     * @return the slice max for each axis
     */
    public long[] getMax() {
        return max;
    }

    /**
     * Returns the slice step for each axis.
     *
     * <p>The returned array is the internal array, not a copy.
     *
     * @return the slice step for each axis
     */
    public long[] getStep() {
        return step;
    }

    /**
     * Returns the squeeze array of axis.
     *
     * <p>The returned array is the internal array, not a copy.
     *
     * @return the squeeze array of axis
     */
    public int[] getToSqueeze() {
        return toSqueeze;
    }

    /**
     * Returns the slice shape without squeezing.
     *
     * @return the slice shape without squeezing
     */
    public Shape getShape() {
        return shape;
    }

    /**
     * Returns the slice shape with squeezing.
     *
     * @return the slice shape with squeezing
     */
    public Shape getSqueezedShape() {
        return squeezedShape;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy