All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.ndarray.index.NDIndex Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.ndarray.index;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

/**
 * The {@code NDIndex} allows you to specify a subset of an NDArray that can be used for fetching or
 * updating.
 *
 * 

It accepts a different index option for each dimension, given in the order of the dimensions. * Each dimension has options corresponding to: * *

    *
  • Return all dimensions - Pass null to addIndices *
  • A single value in the dimension - Pass the value to addIndices with a negative index -i * corresponding to [dimensionLength - i] *
  • A range of values - Use addSliceDim *
* *

We recommend creating the NDIndex using {@link #NDIndex(String)}. * * @see #NDIndex(String) */ public class NDIndex { private static final Pattern ITEM_PATTERN = Pattern.compile("(\\*)|((-?\\d+)?:(-?\\d+)?(:(-?\\d+))?)|(-?\\d+)"); private int rank; private List indices; /** Creates an empty {@link NDIndex} to append values to. */ public NDIndex() { rank = 0; indices = new ArrayList<>(); } /** * Creates a {@link NDIndex} given the index values. * *

Here are some examples of the indices format. * *

     *     NDArray a = manager.ones(new Shape(5, 4, 3));
     *
     *     // Gets a subsection of the NDArray in the first axis.
     *     assertEquals(a.get(new NDIndex("2")).getShape(), new Shape(4, 3));
     *
     *     // Gets a subsection of the NDArray indexing from the end (-i == length - i).
     *     assertEquals(a.get(new NDIndex("-1")).getShape(), new Shape(4, 3));
     *
     *     // Gets everything in the first axis and a subsection in the second axis.
     *     // You can use either : or * to represent everything
     *     assertEquals(a.get(new NDIndex(":, 2")).getShape(), new Shape(5, 3));
     *     assertEquals(a.get(new NDIndex("*, 2")).getShape(), new Shape(5, 3));
     *
     *     // Gets a range of values along the second axis that is inclusive on the bottom and exclusive on the top.
     *     assertEquals(a.get(new NDIndex(":, 1:3")).getShape(), new Shape(5, 2, 3));
     *
     *     // Excludes either the min or the max of the range to go all the way to the beginning or end.
     *     assertEquals(a.get(new NDIndex(":, :3")).getShape(), new Shape(5, 3, 3));
     *     assertEquals(a.get(new NDIndex(":, 1:")).getShape(), new Shape(5, 4, 3));
     *
     *     // Uses the value after the second colon in a slicing range, the step, to get every other result.
     *     assertEquals(a.get(new NDIndex(":, 1::2")).getShape(), new Shape(5, 2, 3));
     *
     *     // Uses a negative step to reverse along the dimension.
     *     assertEquals(a.get(new NDIndex("-1")).getShape(), new Shape(5, 4, 3));
     * 
* * @param indices a comma separated list of indices corresponding to either subsections, * everything, or slices on a particular dimension * @see Numpy * Indexing */ public NDIndex(String indices) { this(); addIndices(indices); } /** * Creates an NDIndex with the given indices as specified values on the NDArray. * * @param indices the indices with each index corresponding to the dimensions and negative * indices starting from the end */ public NDIndex(long... indices) { this(); addIndices(indices); } /** * Creates an {@link NDIndex} that just has one slice in the given axis. * * @param axis the axis to slice * @param min the min of the slice * @param max the max of the slice * @return a new {@link NDIndex} with the given slice. */ public static NDIndex sliceAxis(int axis, long min, long max) { NDIndex ind = new NDIndex(); for (int i = 0; i < axis; i++) { ind.addAllDim(); } ind.addSliceDim(min, max); return ind; } /** * Returns the number of dimensions specified in the Index. * * @return the number of dimensions specified in the Index */ public int getRank() { return rank; } /** * Returns the index affecting the given dimension. * * @param dimension the affected dimension * @return the index affecting the given dimension */ public NDIndexElement get(int dimension) { return indices.get(dimension); } /** * Returns the indices. * * @return the indices */ public List getIndices() { return indices; } /** * Updates the NDIndex by appending indices to the array. * * @param indices the indices to add similar to {@link #NDIndex(String)} * @return the updated {@link NDIndex} * @see #NDIndex(String) */ public final NDIndex addIndices(String indices) { String[] indexItems = indices.split(","); rank += indexItems.length; for (String indexItem : indexItems) { addIndexItem(indexItem); } return this; } /** * Updates the NDIndex by appending indices as specified values on the NDArray. * * @param indices with each index corresponding to the dimensions and negative indices starting * from the end * @return the updated {@link NDIndex} */ public final NDIndex addIndices(long... indices) { rank += indices.length; for (long i : indices) { this.indices.add(new NDIndexFixed(i)); } return this; } /** * Updates the NDIndex by appending a boolean NDArray. * *

The NDArray should have a matching shape to the dimensions being fetched and will return * where the values in NDIndex do not equal zero. * * @param index a boolean NDArray where all nonzero elements correspond to elements to return * @return the updated {@link NDIndex} */ public NDIndex addBooleanIndex(NDArray index) { rank += index.getShape().dimension(); indices.add(new NDIndexBooleans(index)); return this; } /** * Appends a new index to get all values in the dimension. * * @return the updated {@link NDIndex} */ public NDIndex addAllDim() { rank++; indices.add(new NDIndexAll()); return this; } /** * Appends a new index to slice the dimension and returns a range of values. * * @param min the minimum of the range * @param max the maximum of the range * @return the updated {@link NDIndex} */ public NDIndex addSliceDim(long min, long max) { rank++; indices.add(new NDIndexSlice(min, max, null)); return this; } /** * Appends a new index to slice the dimension and returns a range of values. * * @param min the minimum of the range * @param max the maximum of the range * @param step the step of the slice * @return the updated {@link NDIndex} */ public NDIndex addSliceDim(long min, long max, long step) { rank++; indices.add(new NDIndexSlice(min, max, step)); return this; } /** * Returns a stream of the NDIndexElements. * * @return a stream of the NDIndexElements */ public Stream stream() { return indices.stream(); } private void addIndexItem(String indexItem) { indexItem = indexItem.trim(); Matcher m = ITEM_PATTERN.matcher(indexItem); if (!m.matches()) { throw new IllegalArgumentException("Invalid argument index: " + indexItem); } String star = m.group(1); if (star != null) { indices.add(new NDIndexAll()); return; } String digit = m.group(7); if (digit != null) { indices.add(new NDIndexFixed(Long.parseLong(digit))); return; } // Slice Long min = m.group(3) != null ? Long.parseLong(m.group(3)) : null; Long max = m.group(4) != null ? Long.parseLong(m.group(4)) : null; Long step = m.group(6) != null ? Long.parseLong(m.group(6)) : null; if (min == null && max == null && step == null) { indices.add(new NDIndexAll()); } else { indices.add(new NDIndexSlice(min, max, step)); } } /** * Returns this index as a full slice if it can be represented as one. * * @param target the shape to index * @return the full slice if it can be represented as one */ public Optional getAsFullSlice(Shape target) { if (!stream().allMatch( ie -> ie instanceof NDIndexAll || ie instanceof NDIndexFixed || ie instanceof NDIndexSlice)) { return Optional.empty(); } int indDimensions = getRank(); int targetDimensions = target.dimension(); if (indDimensions > target.dimension()) { throw new IllegalArgumentException( "The index has too many dimensions - " + indDimensions + " dimensions for array with " + targetDimensions + " dimensions"); } long[] min = new long[targetDimensions]; long[] max = new long[targetDimensions]; long[] step = new long[targetDimensions]; List toSqueeze = new ArrayList<>(targetDimensions); long[] shape = new long[targetDimensions]; List squeezedShape = new ArrayList<>(targetDimensions); for (int i = 0; i < indDimensions; i++) { NDIndexElement ie = get(i); if (ie instanceof NDIndexFixed) { min[i] = ((NDIndexFixed) ie).getIndex(); max[i] = ((NDIndexFixed) ie).getIndex() + 1; step[i] = 1; toSqueeze.add(i); shape[i] = 1; } else if (ie instanceof NDIndexSlice) { NDIndexSlice slice = (NDIndexSlice) ie; long rawMin = Optional.ofNullable(slice.getMin()).orElse(0L); min[i] = rawMin < 0 ? Math.floorMod(rawMin, target.get(i)) : rawMin; long rawMax = Optional.ofNullable(slice.getMax()).orElse(target.size(i)); max[i] = rawMax < 0 ? Math.floorMod(rawMax, target.get(i)) : rawMax; step[i] = Optional.ofNullable(slice.getStep()).orElse(1L); if (step[i] > 0) { shape[i] = (max[i] - min[i] - 1) / (step[i] + 1); } else { shape[i] = (min[i] - max[i]) / (-step[i] + 1); } squeezedShape.add(shape[i]); } else if (ie instanceof NDIndexAll) { min[i] = 0; max[i] = target.size(i); step[i] = 1; shape[i] = target.size(i); squeezedShape.add(target.size(i)); } } for (int i = indDimensions; i < target.dimension(); i++) { min[i] = 0; max[i] = target.size(i); step[i] = 1; shape[i] = target.size(i); squeezedShape.add(target.size(i)); } NDIndexFullSlice fullSlice = new NDIndexFullSlice( min, max, step, toSqueeze, new Shape(shape), new Shape(squeezedShape)); return Optional.of(fullSlice); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy