// ai.djl.modality.cv.util.NDImageUtils Maven / Gradle / Ivy
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.util;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;
/**
 * {@code NDImageUtils} is an image processing utility to resize, normalize, crop, and convert
 * images stored as {@link NDArray}s.
 *
 * <p>Every operation delegates to the array's internal implementation; this class only adapts
 * arguments and computes crop geometry.
 */
public final class NDImageUtils {

    private NDImageUtils() {}

    /**
     * Resizes an image to a square of the given size.
     *
     * @param image the image to resize
     * @param size the new size to use for both height and width
     * @return the resized NDArray
     */
    public static NDArray resize(NDArray image, int size) {
        return image.getNDArrayInternal().resize(size, size);
    }

    /**
     * Resizes an image to the given width and height.
     *
     * @param image the image to resize
     * @param width the desired width
     * @param height the desired height
     * @return the resized NDArray
     */
    public static NDArray resize(NDArray image, int width, int height) {
        return image.getNDArrayInternal().resize(width, height);
    }

    /**
     * Normalizes an image NDArray of shape CHW or NCHW with a single mean and standard deviation to
     * apply to all channels.
     *
     * <p>The scalar values are replicated into three-element arrays, so this overload targets
     * three-channel images.
     *
     * @param input the image to normalize
     * @param mean the mean to normalize with (for all channels)
     * @param std the standard deviation to normalize with (for all channels)
     * @return the normalized NDArray
     * @see NDImageUtils#normalize(NDArray, float[], float[])
     */
    public static NDArray normalize(NDArray input, float mean, float std) {
        return normalize(input, new float[] {mean, mean, mean}, new float[] {std, std, std});
    }

    /**
     * Normalizes an image NDArray of shape CHW or NCHW with mean and standard deviation.
     *
     * <p>Given mean {@code (m1, ..., mn)} and standard deviation {@code (s1, ..., sn)} for {@code
     * n} channels, this transform normalizes each channel {@code i} of the input tensor with:
     * {@code output[i] = (input[i] - mi) / si}.
     *
     * @param input the image to normalize
     * @param mean the mean to normalize with for each channel
     * @param std the standard deviation to normalize with for each channel
     * @return the normalized NDArray
     */
    public static NDArray normalize(NDArray input, float[] mean, float[] std) {
        return input.getNDArrayInternal().normalize(mean, std);
    }

    /**
     * Converts an image NDArray from preprocessing format to Neural Network format.
     *
     * <p>Converts an image NDArray of shape HWC in the range {@code [0, 255]} to a {@link
     * ai.djl.ndarray.types.DataType#FLOAT32} tensor NDArray of shape CHW in the range {@code [0,
     * 1]}.
     *
     * @param image the image to convert
     * @return the converted image
     */
    public static NDArray toTensor(NDArray image) {
        return image.getNDArrayInternal().toTensor();
    }

    /**
     * Crops an image to a centered square of size {@code min(width, height)}.
     *
     * <p>Dimension 0 of the image shape is read as the height and dimension 1 as the width. An
     * image that is already square is returned as-is (the same instance, not a copy).
     *
     * @param image the image to crop
     * @return the cropped image
     * @see NDImageUtils#centerCrop(NDArray, int, int)
     */
    public static NDArray centerCrop(NDArray image) {
        Shape shape = image.getShape();
        int w = (int) shape.get(1);
        int h = (int) shape.get(0);

        if (w == h) {
            return image;
        }

        int side = Math.min(w, h);
        return centerCrop(image, side, side);
    }

    /**
     * Crops an image to a given width and height from the center of the image.
     *
     * <p>Dimension 0 of the image shape is read as the height and dimension 1 as the width. If a
     * requested dimension exceeds the image's, the full image extent is kept along that axis. When
     * the surplus along an axis is odd, the extra pixel is dropped from the right/bottom edge.
     *
     * @param image the image to crop
     * @param width the desired width of the cropped image
     * @param height the desired height of the cropped image
     * @return the cropped image
     */
    public static NDArray centerCrop(NDArray image, int width, int height) {
        Shape shape = image.getShape();
        int imageWidth = (int) shape.get(1);
        int imageHeight = (int) shape.get(0);

        // Clamp the crop size first so that the result never exceeds the image bounds, and so
        // that a one-pixel surplus (whose halved offset rounds down to 0) is still cropped away.
        int cropWidth = Math.min(imageWidth, width);
        int cropHeight = Math.min(imageHeight, height);

        // The offsets are non-negative because the crop size is at most the image size.
        int x = (imageWidth - cropWidth) / 2;
        int y = (imageHeight - cropHeight) / 2;

        return crop(image, x, y, cropWidth, cropHeight);
    }

    /**
     * Crops an image with a given location and size.
     *
     * @param image the image to crop
     * @param x the x coordinate of the top-left corner of the crop
     * @param y the y coordinate of the top-left corner of the crop
     * @param width the width of the cropped image
     * @param height the height of the cropped image
     * @return the cropped image
     */
    public static NDArray crop(NDArray image, int x, int y, int width, int height) {
        return image.getNDArrayInternal().crop(x, y, width, height);
    }

    /** Flag indicates the color channel options for images. */
    public enum Flag {
        GRAYSCALE,
        COLOR
    }
}