All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.ndarray.index.NDArrayIndexer Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.ndarray.index;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
import java.util.List;
import java.util.Optional;

/** A helper class for {@link NDArray} implementations for operations with an {@link NDIndex}. */
public abstract class NDArrayIndexer {

    /**
     * Returns a subarray by picking the elements.
     *
     * @param array the array to get from
     * @param fullPick the elements to pick
     * @return the subArray
     */
    public abstract NDArray get(NDArray array, NDIndexFullPick fullPick);

    /**
     * Returns a subarray at the slice.
     *
     * @param array the array to get from
     * @param fullSlice the fullSlice index of the array
     * @return the subArray
     */
    public abstract NDArray get(NDArray array, NDIndexFullSlice fullSlice);

    /**
     * Returns a subarray at the given index.
     *
     * @param array the array to get from
     * @param index the index to get
     * @return the subarray
     */
    public NDArray get(NDArray array, NDIndex index) {
        if (index.getRank() == 0 && array.getShape().isScalar()) {
            return array.duplicate();
        }

        // use booleanMask for NDIndexBooleans case
        List indices = index.getIndices();
        if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
            if (indices.size() != 1) {
                throw new IllegalArgumentException(
                        "get() currently didn't support more that one boolean NDArray");
            }
            return array.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex());
        }

        Optional fullPick = NDIndexFullPick.fromIndex(index, array.getShape());
        if (fullPick.isPresent()) {
            return get(array, fullPick.get());
        }

        Optional fullSlice = NDIndexFullSlice.fromIndex(index, array.getShape());
        if (fullSlice.isPresent()) {
            return get(array, fullSlice.get());
        }
        throw new UnsupportedOperationException(
                "get() currently supports all, fixed, and slices indices");
    }

    /**
     * Sets the values of the array at the fullSlice with an array.
     *
     * @param array the array to set
     * @param fullSlice the fullSlice of the index to set in the array
     * @param value the value to set with
     */
    public abstract void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value);

    /**
     * Sets the values of the array at the boolean locations with an array.
     *
     * @param array the array to set
     * @param indices a boolean array where true indicates values to update
     * @param value the value to set with when condition is true
     */
    public void set(NDArray array, NDIndexBooleans indices, NDArray value) {
        array.intern(NDArrays.where(indices.getIndex(), value, array));
    }

    /**
     * Sets the values of the array at the index locations with an array.
     *
     * @param array the array to set
     * @param index the index to set at in the array
     * @param value the value to set with
     */
    public void set(NDArray array, NDIndex index, NDArray value) {
        List indices = index.getIndices();
        if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
            if (indices.size() != 1) {
                throw new IllegalArgumentException(
                        "get() currently didn't support more that one boolean NDArray");
            }
            set(array, (NDIndexBooleans) indices.get(0), value);
        }

        NDIndexFullSlice fullSlice =
                NDIndexFullSlice.fromIndex(index, array.getShape()).orElse(null);
        if (fullSlice != null) {
            set(array, fullSlice, value);
            return;
        }
        throw new UnsupportedOperationException(
                "set() currently supports all, fixed, and slices indices");
    }

    /**
     * Sets the values of the array at the fullSlice with a number.
     *
     * @param array the array to set
     * @param fullSlice the fullSlice of the index to set in the array
     * @param value the value to set with
     */
    public abstract void set(NDArray array, NDIndexFullSlice fullSlice, Number value);

    /**
     * Sets the values of the array at the index locations with a number.
     *
     * @param array the array to set
     * @param index the index to set at in the array
     * @param value the value to set with
     */
    public void set(NDArray array, NDIndex index, Number value) {
        NDIndexFullSlice fullSlice =
                NDIndexFullSlice.fromIndex(index, array.getShape()).orElse(null);
        if (fullSlice != null) {
            set(array, fullSlice, value);
            return;
        }
        // use booleanMask for NDIndexBooleans case
        List indices = index.getIndices();
        if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
            if (indices.size() != 1) {
                throw new IllegalArgumentException(
                        "set() currently didn't support more that one boolean NDArray");
            }
            set(array, (NDIndexBooleans) indices.get(0), array.getManager().create(value));
            return;
        }
        throw new UnsupportedOperationException(
                "set() currently supports all, fixed, and slices indices");
    }

    /**
     * Sets a scalar value in the array at the indexed location.
     *
     * @param array the array to set
     * @param index the index to set at in the array
     * @param value the value to set with
     * @throws IllegalArgumentException if the index does not point to a scalar value in the array
     */
    public void setScalar(NDArray array, NDIndex index, Number value) {
        NDIndexFullSlice fullSlice =
                NDIndexFullSlice.fromIndex(index, array.getShape()).orElse(null);
        if (fullSlice != null) {
            if (fullSlice.getShape().size() != 1) {
                throw new IllegalArgumentException("The provided index does not set a scalar");
            }
            set(array, index, value);
            return;
        }
        throw new UnsupportedOperationException(
                "set() currently supports all, fixed, and slices indices");
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy