org.datavec.image.loader.ImageLoader Maven / Gradle / Ivy
The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.image.loader;
import com.github.jaiimageio.impl.plugins.tiff.TIFFImageReaderSpi;
import com.github.jaiimageio.impl.plugins.tiff.TIFFImageWriterSpi;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.util.NDArrayUtil;
import javax.imageio.ImageIO;
import javax.imageio.spi.IIORegistry;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.awt.image.Raster;
import java.awt.image.WritableRaster;
import java.io.*;
import java.util.Arrays;
/**
 * Loads images through {@link ImageIO} and converts them into ND4J arrays.
 * <p>
 * When {@code channels == 3} images are converted through {@link #toINDArrayBGR(BufferedImage)},
 * which yields a {@code [bands, height, width]} array of unsigned byte values in the raster's byte
 * order (B, G, R for {@code TYPE_3BYTE_BGR}). For any other channel count, images become a
 * {@code [height, width]} matrix of {@link BufferedImage#getRGB(int, int)} values.
 */
public class ImageLoader extends BaseImageLoader {

    static {
        // Pick up any ImageIO plugins on the classpath, then explicitly register the TIFF
        // (jai-imageio) and TwelveMonkeys JPEG / PSD / BMP / CUR / ICO service providers.
        ImageIO.scanForPlugins();
        IIORegistry registry = IIORegistry.getDefaultInstance();
        registry.registerServiceProvider(new TIFFImageWriterSpi());
        registry.registerServiceProvider(new TIFFImageReaderSpi());
        registry.registerServiceProvider(new com.twelvemonkeys.imageio.plugins.jpeg.JPEGImageReaderSpi());
        registry.registerServiceProvider(new com.twelvemonkeys.imageio.plugins.jpeg.JPEGImageWriterSpi());
        registry.registerServiceProvider(new com.twelvemonkeys.imageio.plugins.psd.PSDImageReaderSpi());
        registry.registerServiceProvider(Arrays.asList(new com.twelvemonkeys.imageio.plugins.bmp.BMPImageReaderSpi(),
                        new com.twelvemonkeys.imageio.plugins.bmp.CURImageReaderSpi(),
                        new com.twelvemonkeys.imageio.plugins.bmp.ICOImageReaderSpi()));
    }

    /**
     * Create a loader that uses the defaults inherited from {@code BaseImageLoader}.
     */
    public ImageLoader() {
        super();
    }

    /**
     * Create a loader that rescales images to the given height and width.
     *
     * @param height the height to load
     * @param width  the width to load
     */
    public ImageLoader(long height, long width) {
        super();
        this.height = height;
        this.width = width;
    }

    /**
     * Create a loader that rescales images to the given height and width
     * and uses the given number of channels.
     *
     * @param height   the height to load
     * @param width    the width to load
     * @param channels the number of channels for the image
     */
    public ImageLoader(long height, long width, long channels) {
        super();
        this.height = height;
        this.width = width;
        this.channels = channels;
    }

    /**
     * Create a loader that rescales images to the given height and width, uses the given
     * number of channels, and optionally center-crops before rescaling.
     *
     * @param height             the height to load
     * @param width              the width to load
     * @param channels           the number of channels for the image
     * @param centerCropIfNeeded whether to center-crop to a square before rescaling and converting
     */
    public ImageLoader(long height, long width, long channels, boolean centerCropIfNeeded) {
        this(height, width, channels);
        this.centerCropIfNeeded = centerCropIfNeeded;
    }

    /**
     * Reject a {@code null} image, which is what {@link ImageIO#read} returns when no
     * registered reader understands the input.
     *
     * @param image the decoded image, possibly {@code null}
     * @return {@code image} itself when it is non-null
     * @throws IllegalStateException if {@code image} is {@code null}
     */
    private static BufferedImage requireImage(BufferedImage image) {
        if (image == null) {
            throw new IllegalStateException("Unable to load image");
        }
        return image;
    }

    /**
     * Convert a file to a row vector.
     *
     * @param f the image to convert
     * @return the flattened image
     * @throws IOException if the file cannot be read
     */
    public INDArray asRowVector(File f) throws IOException {
        return asRowVector(ImageIO.read(f));
    }

    /**
     * Convert an image read from a stream to a row vector.
     *
     * @param inputStream the stream holding the encoded image (not closed by this method)
     * @return the flattened image
     * @throws IOException if the stream cannot be read
     */
    public INDArray asRowVector(InputStream inputStream) throws IOException {
        return asRowVector(ImageIO.read(inputStream));
    }

    /**
     * Convert an image into a row vector.
     * <p>
     * The image is optionally center-cropped and then rescaled. With 3 channels the BGR tensor is
     * raveled; otherwise the per-pixel values from {@link #toIntArrayArray(BufferedImage)} are flattened.
     *
     * @param image the image to convert
     * @return the row vector based on a rastered representation of the image
     * @throws IllegalStateException if {@code image} is {@code null}
     */
    public INDArray asRowVector(BufferedImage image) {
        requireImage(image);
        if (centerCropIfNeeded) {
            image = centerCropIfNeeded(image);
        }
        image = scalingIfNeed(image, true);
        if (channels == 3) {
            return toINDArrayBGR(image).ravel();
        }
        int[][] ret = toIntArrayArray(image);
        return NDArrayUtil.toNDArray(ArrayUtil.flatten(ret));
    }

    /**
     * Read an image file into a raveled (flattened) BGR vector.
     *
     * @param file the image file to convert
     * @return the raveled BGR values of the image
     * @throws RuntimeException wrapping any {@link IOException} raised while reading or closing the file
     */
    public INDArray toRaveledTensor(File file) {
        // try-with-resources: the stream is closed even if decoding/conversion throws.
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(file))) {
            return toRaveledTensor(bis);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Read an image from a stream into a raveled (flattened) BGR vector.
     *
     * @param is the input stream to convert (not closed by this method)
     * @return the raveled BGR values for this input stream
     */
    public INDArray toRaveledTensor(InputStream is) {
        return toBgr(is).ravel();
    }

    /**
     * Convert an image into a raveled tensor of its BGR values.
     *
     * @param image the image to parse
     * @return the raveled tensor of BGR values
     * @throws RuntimeException if the image cannot be converted
     */
    public INDArray toRaveledTensor(BufferedImage image) {
        try {
            image = scalingIfNeed(image, false);
            return toINDArrayBGR(image).ravel();
        } catch (Exception e) {
            throw new RuntimeException("Unable to load image", e);
        }
    }

    /**
     * Read an image file into a BGR tensor.
     *
     * @param file the file to convert
     * @return the {@code [bands, height, width]} BGR tensor
     * @throws RuntimeException wrapping any {@link IOException} raised while reading or closing the file
     */
    public INDArray toBgr(File file) {
        // try-with-resources: the stream is closed even if decoding/conversion throws.
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(file))) {
            return toBgr(bis);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Read an image from a stream into a BGR tensor.
     *
     * @param inputStream the input stream to convert (not closed by this method)
     * @return the {@code [bands, height, width]} BGR tensor
     * @throws RuntimeException if the stream cannot be read
     */
    public INDArray toBgr(InputStream inputStream) {
        try {
            BufferedImage image = ImageIO.read(inputStream);
            return toBgr(image);
        } catch (IOException e) {
            throw new RuntimeException("Unable to load image", e);
        }
    }

    /**
     * Read an image from a stream into a BGR tensor wrapped with its original band count and size.
     *
     * @param inputStream the input stream to convert (not closed by this method)
     * @return the wrapped image; band count, height and width are those of the decoded (unscaled) image
     */
    private org.datavec.image.data.Image toBgrImage(InputStream inputStream) {
        try {
            BufferedImage image = ImageIO.read(inputStream);
            INDArray img = toBgr(image);
            return new org.datavec.image.data.Image(img, image.getData().getNumBands(), image.getHeight(), image.getWidth());
        } catch (IOException e) {
            throw new RuntimeException("Unable to load image", e);
        }
    }

    /**
     * Convert a BufferedImage into a BGR tensor, rescaling it first if needed.
     *
     * @param image the BufferedImage to convert
     * @return the {@code [bands, height, width]} BGR tensor
     * @throws IllegalStateException if {@code image} is {@code null}
     */
    public INDArray toBgr(BufferedImage image) {
        requireImage(image);
        image = scalingIfNeed(image, false);
        return toINDArrayBGR(image);
    }

    /**
     * Convert an image file into an NCHW matrix.
     *
     * @param f the file to convert
     * @return the image array (see {@link #asMatrix(InputStream, boolean)})
     * @throws IOException if the file cannot be read
     */
    public INDArray asMatrix(File f) throws IOException {
        return asMatrix(f, true);
    }

    /**
     * Convert an image file into a matrix in NCHW or NHWC layout.
     *
     * @param f    the file to convert
     * @param nchw {@code true} for NCHW, {@code false} for NHWC
     * @return the image array (see {@link #asMatrix(InputStream, boolean)})
     * @throws IOException if the file cannot be read
     */
    @Override
    public INDArray asMatrix(File f, boolean nchw) throws IOException {
        try (InputStream is = new BufferedInputStream(new FileInputStream(f))) {
            return asMatrix(is, nchw);
        }
    }

    /**
     * Convert an input stream into an NCHW matrix.
     *
     * @param inputStream the input stream to convert (not closed by this method)
     * @return the image array (see {@link #asMatrix(InputStream, boolean)})
     * @throws IOException if the stream cannot be read
     */
    public INDArray asMatrix(InputStream inputStream) throws IOException {
        return asMatrix(inputStream, true);
    }

    /**
     * Convert an input stream into a matrix.
     * <p>
     * Rank-3 results (the 3-channel BGR path) are given a leading batch dimension of 1.
     * NOTE(review): the non-3-channel path yields a rank-2 {@code [height, width]} matrix, which
     * is not reshaped, so the NHWC permute below would fail for it when {@code nchw == false}.
     *
     * @param inputStream the input stream to convert (not closed by this method)
     * @param nchw        {@code true} for NCHW, {@code false} to permute NCHW to NHWC
     * @return the image array
     * @throws IOException if the stream cannot be read (3-channel path wraps read errors in RuntimeException)
     */
    @Override
    public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
        INDArray ret;
        if (channels == 3) {
            ret = toBgr(inputStream);
        } else {
            try {
                BufferedImage image = ImageIO.read(inputStream);
                ret = asMatrix(image);
            } catch (IOException e) {
                throw new IOException("Unable to load image", e);
            }
        }
        if (ret.rank() == 3) {
            ret = ret.reshape(1, ret.size(0), ret.size(1), ret.size(2));
        }
        if (!nchw)
            ret = ret.permute(0, 2, 3, 1); //NCHW to NHWC
        return ret;
    }

    /**
     * Convert an image file into an NCHW {@link org.datavec.image.data.Image}.
     *
     * @param f the file to convert
     * @return the wrapped image array
     * @throws IOException if the file cannot be read
     */
    @Override
    public org.datavec.image.data.Image asImageMatrix(File f) throws IOException {
        return asImageMatrix(f, true);
    }

    /**
     * Convert an image file into an {@link org.datavec.image.data.Image} in NCHW or NHWC layout.
     *
     * @param f    the file to convert
     * @param nchw {@code true} for NCHW, {@code false} for NHWC
     * @return the wrapped image array
     * @throws IOException if the file cannot be read
     */
    @Override
    public org.datavec.image.data.Image asImageMatrix(File f, boolean nchw) throws IOException {
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
            return asImageMatrix(bis, nchw);
        }
    }

    /**
     * Convert an input stream into an NCHW {@link org.datavec.image.data.Image}.
     *
     * @param inputStream the input stream to convert (not closed by this method)
     * @return the wrapped image array
     * @throws IOException if the stream cannot be read
     */
    @Override
    public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException {
        return asImageMatrix(inputStream, true);
    }

    /**
     * Convert an input stream into an {@link org.datavec.image.data.Image}, recording the
     * decoded image's band count, height and width alongside the array.
     *
     * @param inputStream the input stream to convert (not closed by this method)
     * @param nchw        {@code true} for NCHW, {@code false} to permute NCHW to NHWC
     * @return the wrapped image array
     * @throws IOException if the stream cannot be read
     */
    @Override
    public org.datavec.image.data.Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
        org.datavec.image.data.Image ret;
        if (channels == 3) {
            ret = toBgrImage(inputStream);
        } else {
            try {
                BufferedImage image = ImageIO.read(inputStream);
                INDArray asMatrix = asMatrix(image);
                ret = new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
            } catch (IOException e) {
                throw new IOException("Unable to load image", e);
            }
        }
        if (ret.getImage().rank() == 3) {
            INDArray a = ret.getImage();
            ret.setImage(a.reshape(1, a.size(0), a.size(1), a.size(2)));
        }
        if (!nchw)
            ret.setImage(ret.getImage().permute(0, 2, 3, 1)); //NCHW to NHWC
        return ret;
    }

    /**
     * Convert a BufferedImage into a matrix.
     * <p>
     * With 3 channels this is {@link #toBgr(BufferedImage)}. Otherwise the rescaled image becomes
     * a {@code [height, width]} matrix of {@link BufferedImage#getRGB(int, int)} values.
     *
     * @param image the BufferedImage to convert
     * @return the image array
     * @throws IllegalStateException if {@code image} is {@code null}
     */
    public INDArray asMatrix(BufferedImage image) {
        requireImage(image);
        if (channels == 3) {
            return toBgr(image);
        } else {
            image = scalingIfNeed(image, true);
            int w = image.getWidth();
            int h = image.getHeight();
            INDArray ret = Nd4j.create(h, w);
            for (int i = 0; i < h; i++) {
                for (int j = 0; j < w; j++) {
                    ret.putScalar(new int[]{i, j}, image.getRGB(j, i));
                }
            }
            return ret;
        }
    }

    /**
     * Intended to slice an image into a mini batch.
     * <p>
     * NOTE(review): the pixel data loaded into {@code d} is never copied into the returned array.
     * Only its column count is used to size a newly created
     * {@code [numMiniBatches, numRowsPerSlice, d.columns()]} array. Also, {@code asMatrix(f)} now
     * returns a rank-4 array, for which {@code columns()} may not be valid. Confirm the intended
     * behaviour before relying on this method.
     *
     * @param f               the file to load from
     * @param numMiniBatches  the number of images in a mini batch
     * @param numRowsPerSlice the number of rows for each image
     * @return a tensor of shape {@code [numMiniBatches, numRowsPerSlice, d.columns()]}
     */
    public INDArray asImageMiniBatches(File f, int numMiniBatches, int numRowsPerSlice) {
        try {
            INDArray d = asMatrix(f);
            return Nd4j.create(numMiniBatches, numRowsPerSlice, d.columns());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Load an image file and flatten its rastered values row by row.
     *
     * @param f the file to load
     * @return the flattened values produced by {@link #fromFile(File)}
     * @throws IOException if the file cannot be read
     */
    public int[] flattenedImageFromFile(File f) throws IOException {
        return ArrayUtil.flatten(fromFile(f));
    }

    /**
     * Load a rastered image from file.
     *
     * @param file the file to load
     * @return the rastered image as {@code [height][width]}
     * @throws IOException           if the file cannot be read
     * @throws IllegalStateException if no registered reader can decode the file
     */
    public int[][] fromFile(File file) throws IOException {
        BufferedImage image = requireImage(ImageIO.read(file));
        image = scalingIfNeed(image, true);
        return toIntArrayArray(image);
    }

    /**
     * Load a rastered image from file with one plane per channel.
     * <p>
     * Reads the raw (unsigned) bytes of the rescaled image's byte-backed raster. A raster that is
     * not byte-backed throws ClassCastException. Only the first {@code min(channels, bands)}
     * planes are filled; any remaining planes stay 0.
     *
     * @param file the file to load
     * @return the rastered image as {@code [channels][height][width]}
     * @throws IOException           if the file cannot be read
     * @throws IllegalStateException if no registered reader can decode the file
     */
    public int[][][] fromFileMultipleChannels(File file) throws IOException {
        BufferedImage image = requireImage(ImageIO.read(file));
        image = scalingIfNeed(image, channels > 3);
        int w = image.getWidth();
        int h = image.getHeight();
        int bands = image.getSampleModel().getNumBands();
        int outChannels = (int) Math.min(channels, Integer.MAX_VALUE);
        int[][][] ret = new int[outChannels][h][w];
        byte[] pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
        int filled = Math.min(outChannels, bands);
        for (int i = 0; i < h; i++) {
            for (int j = 0; j < w; j++) {
                // Assumes a pixel-interleaved buffer: each pixel occupies `bands` consecutive
                // bytes. The stride must be the image's band count, not the requested channel count.
                int base = (i * w + j) * bands;
                for (int k = 0; k < filled; k++) {
                    // Mask to read the byte as an unsigned 0..255 intensity.
                    ret[k][i][j] = pixels[base + k] & 0xFF;
                }
            }
        }
        return ret;
    }

    /**
     * Convert a matrix into a buffered image.
     * <p>
     * Each element is truncated to an {@code int} and stored as a packed ARGB pixel. Row {@code r},
     * column {@code c} of the matrix becomes pixel {@code (x = c, y = r)}, so the image is
     * {@code columns()} wide and {@code rows()} high.
     *
     * @param matrix the 2d matrix of packed ARGB values
     * @return {@link java.awt.image.BufferedImage} of type {@code TYPE_INT_ARGB}
     */
    public static BufferedImage toImage(INDArray matrix) {
        int rows = matrix.rows();
        int cols = matrix.columns();
        // BufferedImage and setDataElements take (width, height): width = columns, height = rows.
        BufferedImage img = new BufferedImage(cols, rows, BufferedImage.TYPE_INT_ARGB);
        WritableRaster r = img.getRaster();
        int[] equiv = new int[(int) matrix.length()];
        for (int i = 0; i < equiv.length; i++) {
            equiv[i] = (int) matrix.getDouble(i);
        }
        r.setDataElements(0, 0, cols, rows, equiv);
        return img;
    }

    /**
     * Round every element of {@code matrix}, in linear index order, to an {@code int}.
     *
     * @param matrix the array to convert
     * @return the rounded values
     */
    private static int[] rasterData(INDArray matrix) {
        int[] ret = new int[(int) matrix.length()];
        for (int i = 0; i < ret.length; i++)
            // getDouble avoids casting element() (an Object, possibly a boxed Float) to double.
            ret[i] = (int) Math.round(matrix.getDouble(i));
        return ret;
    }

    /**
     * Write a BGR array into an RGB image.
     * <p>
     * Channel 0 of {@code arr} is written as blue, 1 as green and 2 as red; each value is masked
     * to 8 bits and every pixel is written fully opaque. The last two dimensions of {@code arr}
     * are taken as height and width.
     * NOTE(review): if {@code image} does not already have that size, it is rescaled into a new
     * image. The pixels are then written to that copy, which is not returned, so the caller's
     * image is left unchanged.
     *
     * @param arr   the array to use, with at least 3 dimensions
     * @param image the image to set
     * @throws IllegalArgumentException if {@code arr} has fewer than 3 dimensions
     */
    public void toBufferedImageRGB(INDArray arr, BufferedImage image) {
        if (arr.rank() < 3)
            throw new IllegalArgumentException("Arr must be 3d");
        image = scalingIfNeed(image, arr.size(-2), arr.size(-1), image.getType(), true);
        // Channel planes are loop-invariant; fetch them once instead of once per pixel.
        INDArray blue = arr.slice(0);
        INDArray green = arr.slice(1);
        INDArray red = arr.slice(2);
        int alpha = 0xFF; // fully opaque
        for (int i = 0; i < image.getHeight(); i++) {
            for (int j = 0; j < image.getWidth(); j++) {
                int r = red.getInt(i, j) & 0xFF;
                int g = green.getInt(i, j) & 0xFF;
                int b = blue.getInt(i, j) & 0xFF;
                int col = (alpha << 24) | (r << 16) | (g << 8) | b;
                image.setRGB(j, i, col);
            }
        }
    }

    /**
     * Convert a given Image into a BufferedImage.
     *
     * @param img  The Image to be converted
     * @param type The color model of BufferedImage
     * @return {@code img} itself if it is already a BufferedImage of {@code type}, otherwise a new
     *         BufferedImage of {@code type} with {@code img} drawn onto it
     */
    public static BufferedImage toBufferedImage(Image img, int type) {
        if (img instanceof BufferedImage && ((BufferedImage) img).getType() == type) {
            return (BufferedImage) img;
        }
        // Create a buffered image of the requested type (with or without alpha, depending on type)
        BufferedImage bimage = new BufferedImage(img.getWidth(null), img.getHeight(null), type);
        // Draw the image on to the buffered image
        Graphics2D bGr = bimage.createGraphics();
        bGr.drawImage(img, 0, 0, null);
        bGr.dispose();
        return bimage;
    }

    /**
     * Read an image into a {@code [height][width]} int array.
     * <p>
     * Rasters with one data element per pixel contribute their band-0 sample. All others
     * contribute the packed value from {@link BufferedImage#getRGB(int, int)}.
     *
     * @param image the image to read
     * @return the per-pixel values
     */
    protected int[][] toIntArrayArray(BufferedImage image) {
        int w = image.getWidth(), h = image.getHeight();
        int[][] ret = new int[h][w];
        if (image.getRaster().getNumDataElements() == 1) {
            Raster raster = image.getRaster();
            for (int i = 0; i < h; i++) {
                for (int j = 0; j < w; j++) {
                    ret[i][j] = raster.getSample(j, i, 0);
                }
            }
        } else {
            for (int i = 0; i < h; i++) {
                for (int j = 0; j < w; j++) {
                    ret[i][j] = image.getRGB(j, i);
                }
            }
        }
        return ret;
    }

    /**
     * Copy the raw bytes of a byte-backed raster into a {@code [bands, height, width]} array.
     * <p>
     * Bytes are read as unsigned (0..255), and channel order follows the raster's byte order
     * (B, G, R for {@code TYPE_3BYTE_BGR}). A raster that is not backed by {@link DataBufferByte}
     * throws ClassCastException. The reshape assumes the buffer holds exactly
     * {@code height * width * bands} bytes.
     *
     * @param image the image to read
     * @return the channel-first array
     */
    protected INDArray toINDArrayBGR(BufferedImage image) {
        int height = image.getHeight();
        int width = image.getWidth();
        int bands = image.getSampleModel().getNumBands();
        byte[] pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
        int[] shape = new int[]{height, width, bands};
        INDArray ret2 = Nd4j.create(1, pixels.length);
        for (int i = 0; i < ret2.length(); i++) {
            ret2.putScalar(i, ((int) pixels[i]) & 0xFF);
        }
        // Interleaved HWC bytes -> CHW.
        return ret2.reshape(shape).permute(2, 0, 1);
    }

    /**
     * Crop the largest centered square out of {@code img}.
     * <p>
     * The side of the square is {@code min(width, height)}. When the size difference is odd, the
     * extra pixel is dropped from the right or bottom edge. A square input is returned as a full
     * sub-image view.
     *
     * @param img the image to crop
     * @return a sub-image view (sharing pixel data with {@code img}) of the centered square
     */
    // TODO build flexibility on where to crop the image
    public BufferedImage centerCropIfNeeded(BufferedImage img) {
        int x = 0;
        int y = 0;
        int height = img.getHeight();
        int width = img.getWidth();
        int diff = Math.abs(width - height) / 2;
        if (width > height) {
            x = diff;
            width = height;
        } else if (height > width) {
            y = diff;
            height = width;
        }
        return img.getSubimage(x, y, width, height);
    }

    /**
     * Rescale to this loader's {@code height} / {@code width} and convert the image type.
     * <p>
     * NOTE(review): this passes {@code channels} as the destination BufferedImage type constant.
     * For example, 3 is {@code TYPE_INT_ARGB_PRE}, not a channel count. In practice most images
     * therefore end up as {@code TYPE_3BYTE_BGR}. Confirm this is intended.
     *
     * @param image     the image to convert
     * @param needAlpha whether an alpha-capable destination type may be used
     * @return the rescaled / converted image (possibly {@code image} itself)
     */
    protected BufferedImage scalingIfNeed(BufferedImage image, boolean needAlpha) {
        return scalingIfNeed(image, height, width, channels, needAlpha);
    }

    /**
     * Rescale {@code image} to {@code dstWidth x dstHeight} (when both are positive and differ
     * from the current size), then convert it to a BufferedImage.
     * <p>
     * The result is returned unchanged when it already has {@code dstImageType}. Otherwise it is
     * drawn into {@code TYPE_4BYTE_ABGR} (when alpha is needed and present and that is the
     * requested type), {@code TYPE_BYTE_GRAY} (when requested), or {@code TYPE_3BYTE_BGR}.
     *
     * @param image        the source image
     * @param dstHeight    target height; values {@code <= 0} disable rescaling
     * @param dstWidth     target width; values {@code <= 0} disable rescaling
     * @param dstImageType the desired {@code BufferedImage.TYPE_*} constant
     * @param needAlpha    whether to keep alpha when converting to {@code TYPE_4BYTE_ABGR}
     * @return the rescaled / converted image (possibly {@code image} itself)
     */
    protected BufferedImage scalingIfNeed(BufferedImage image, long dstHeight, long dstWidth, long dstImageType, boolean needAlpha) {
        Image scaled;
        // Scale width and height first if necessary
        if (dstHeight > 0 && dstWidth > 0 && (image.getHeight() != dstHeight || image.getWidth() != dstWidth)) {
            scaled = image.getScaledInstance((int) dstWidth, (int) dstHeight, Image.SCALE_SMOOTH);
        } else {
            scaled = image;
        }
        // Transfer imageType if necessary and transfer to BufferedImage.
        if (scaled instanceof BufferedImage && ((BufferedImage) scaled).getType() == dstImageType) {
            return (BufferedImage) scaled;
        }
        if (needAlpha && image.getColorModel().hasAlpha() && dstImageType == BufferedImage.TYPE_4BYTE_ABGR) {
            return toBufferedImage(scaled, BufferedImage.TYPE_4BYTE_ABGR);
        } else {
            if (dstImageType == BufferedImage.TYPE_BYTE_GRAY)
                return toBufferedImage(scaled, BufferedImage.TYPE_BYTE_GRAY);
            else
                return toBufferedImage(scaled, BufferedImage.TYPE_3BYTE_BGR);
        }
    }
}