org.datavec.image.loader.NativeImageLoader Maven / Gradle / Ivy
The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.image.loader;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.transform.ImageTransform;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.common.util.ArrayUtil;
import java.io.*;
import java.nio.ByteOrder;
import org.bytedeco.leptonica.*;
import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.leptonica.global.lept.*;
import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*;
public class NativeImageLoader extends BaseImageLoader {
/** Buffer step size of 64 KiB; no use of this constant is visible in this class. */
private static final int MIN_BUFFER_STEP_SIZE = 64 * 1024;
/** File extensions reported by {@link #getAllowedFormats()}, each listed in lower and upper case. */
public static final String[] ALLOWED_FORMATS = {"bmp", "gif", "jpg", "jpeg", "jp2", "pbm", "pgm", "ppm", "pnm",
"png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM",
"PNG", "TIF", "TIFF", "EXR", "WEBP"};
/** JavaCV converter used to go back and forth between {@link Frame} objects and OpenCV matrices. */
protected OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
/**
 * False when the JavaCPP platform name starts with "android", true otherwise.
 * The value is passed as the {@code direct} argument whenever an indexer is created.
 */
boolean direct = !Loader.getPlatform().startsWith("android");
/**
 * Creates a loader that leaves the inherited height, width and channel settings untouched.
 */
public NativeImageLoader() {}

/**
 * Creates a loader with a target image size.
 *
 * @param height target number of rows; images are resized when both height and width are positive
 * @param width  target number of columns
 */
public NativeImageLoader(long height, long width) {
    this.width = width;
    this.height = height;
}

/**
 * Creates a loader with a target image size and channel count.
 *
 * @param height   target number of rows
 * @param width    target number of columns
 * @param channels channel count images are converted to when positive
 */
public NativeImageLoader(long height, long width, long channels) {
    this(height, width);
    this.channels = channels;
}

/**
 * Creates a loader with a target size, a channel count and an optional crop step.
 *
 * @param height             target number of rows
 * @param width              target number of columns
 * @param channels           channel count images are converted to when positive
 * @param centerCropIfNeeded whether {@code centerCropIfNeeded(Mat)} runs before rescaling
 */
public NativeImageLoader(long height, long width, long channels, boolean centerCropIfNeeded) {
    this(height, width);
    this.channels = channels;
    this.centerCropIfNeeded = centerCropIfNeeded;
}

/**
 * Creates a loader with a target size, a channel count and an image transform.
 *
 * @param height         target number of rows
 * @param width          target number of columns
 * @param channels       channel count images are converted to when positive
 * @param imageTransform transform applied before channel conversion, cropping and rescaling
 */
public NativeImageLoader(long height, long width, long channels, ImageTransform imageTransform) {
    this(height, width);
    this.channels = channels;
    this.imageTransform = imageTransform;
}

/**
 * Creates a loader with a target size, a channel count and a multipage mode.
 *
 * @param height   target number of rows
 * @param width    target number of columns
 * @param channels channel count images are converted to when positive
 * @param mode     how multipage images are loaded; when non-null, stream input is read as multipage TIFF
 */
public NativeImageLoader(long height, long width, long channels, MultiPageMode mode) {
    this(height, width);
    this.channels = channels;
    this.multiPageMode = mode;
}

/**
 * Copy constructor: takes over size, channels, crop flag, transform and multipage mode of another loader.
 *
 * @param other loader whose settings are copied
 */
protected NativeImageLoader(NativeImageLoader other) {
    this(other.height, other.width, other.channels);
    this.multiPageMode = other.multiPageMode;
    this.imageTransform = other.imageTransform;
    this.centerCropIfNeeded = other.centerCropIfNeeded;
}
/**
 * Returns the file extensions this loader declares as supported.
 *
 * @return the shared {@link #ALLOWED_FORMATS} array itself (not a copy)
 */
@Override
public String[] getAllowedFormats() {
return ALLOWED_FORMATS;
}
/**
 * Loads the image stored under the given path and flattens it.
 *
 * @param filename path of the image file
 * @return the flattened image
 * @throws IOException if the file cannot be read or decoded
 */
public INDArray asRowVector(String filename) throws IOException {
    File imageFile = new File(filename);
    return asRowVector(imageFile);
}

/**
 * Loads an image file and flattens it with {@code ravel()}.
 *
 * @param f the image to convert
 * @return the flattened image
 * @throws IOException if the file cannot be read or decoded
 */
@Override
public INDArray asRowVector(File f) throws IOException {
    INDArray matrix = asMatrix(f);
    return matrix.ravel();
}

/**
 * Decodes an image from a stream and flattens it with {@code ravel()}.
 *
 * @param is stream holding the encoded image
 * @return the flattened image
 * @throws IOException if the stream cannot be read or decoded
 */
@Override
public INDArray asRowVector(InputStream is) throws IOException {
    INDArray matrix = asMatrix(is);
    return matrix.ravel();
}

/**
 * Returns {@code asMatrix(image).ravel()}.
 * @see #asMatrix(Object)
 */
public INDArray asRowVector(Object image) throws IOException {
    INDArray matrix = asMatrix(image);
    return matrix.ravel();
}

/**
 * Converts a JavaCV frame and flattens the result with {@code ravel()}.
 *
 * @param image frame to convert
 * @return the flattened image
 * @throws IOException if the conversion fails
 */
public INDArray asRowVector(Frame image) throws IOException {
    INDArray matrix = asMatrix(image);
    return matrix.ravel();
}

/**
 * Converts an OpenCV matrix and reshapes the result to a single row in 'c' order.
 *
 * @param image matrix to convert
 * @return array of shape {@code [1, length]}
 * @throws IOException if the conversion fails
 */
public INDArray asRowVector(Mat image) throws IOException {
    INDArray matrix = asMatrix(image);
    long total = matrix.length();
    return matrix.reshape('c', 1, total);
}

/**
 * Converts a matrix of the official OpenCV Java API and reshapes the result to a single row in 'c' order.
 *
 * @param image matrix to convert
 * @return array of shape {@code [1, length]}
 * @throws IOException if the conversion fails
 */
public INDArray asRowVector(org.opencv.core.Mat image) throws IOException {
    INDArray matrix = asMatrix(image);
    long total = matrix.length();
    return matrix.reshape('c', 1, total);
}
/**
 * Converts a Leptonica image into an OpenCV matrix that owns its own pixel buffer.
 * <p>
 * Colormapped images are expanded with {@code pixRemoveColormap} and returned as {@code CV_8UC4};
 * 1, 2 and 4 bpp images are first expanded to 8 bpp; 8 and 24 bpp images are copied channel by
 * channel (with the channel order reversed on little-endian hosts when there is more than one
 * channel); 16 bpp data is returned as {@code CV_16UC1} and 32 bpp data as {@code CV_32FC1}.
 * <p>
 * The returned matrix never aliases Leptonica memory, so it stays valid after the caller destroys
 * {@code pix}. Only temporary images created inside this method are destroyed here; {@code pix}
 * itself is left for the caller to release.
 *
 * @param pix image to convert; not destroyed by this method
 * @return a newly allocated matrix holding a copy of the pixels
 * @throws UnsupportedOperationException if the bit depth is none of 1, 2, 4, 8, 16, 24 or 32
 * @throws IllegalStateException if Leptonica fails to produce an intermediate image
 */
static Mat convert(PIX pix) {
    final int height = pix.h();
    final int width = pix.w();
    final int depth = pix.d();
    // Temporary image allocated by this method; released in the finally block on every path.
    PIX ownedPix = null;
    try {
        if (pix.colormap() != null) {
            ownedPix = pixRemoveColormap(pix, REMOVE_CMAP_TO_FULL_COLOR);
            return copyPixData(ownedPix, height, width, CV_8UC4);
        }
        switch (depth) {
            case 1:
                ownedPix = pixConvert1To8(null, pix, (byte) 0, (byte) 255);
                return copyByteChannels(ownedPix, height, width);
            case 2:
                ownedPix = pixConvert2To8(pix, (byte) 0, (byte) 85, (byte) 170, (byte) 255, 0);
                return copyByteChannels(ownedPix, height, width);
            case 4:
                ownedPix = pixConvert4To8(pix, 0);
                return copyByteChannels(ownedPix, height, width);
            case 8:
            case 24:
                return copyByteChannels(pix, height, width);
            case 16:
                return copyPixData(pix, height, width, CV_16UC(1));
            case 32:
                // NOTE(review): 32 bpp data is interpreted as single-channel float, as before - confirm intent.
                return copyPixData(pix, height, width, CV_32FC(1));
            default:
                throw new UnsupportedOperationException("Unsupported PIX depth: " + depth);
        }
    } finally {
        if (ownedPix != null) {
            pixDestroy(ownedPix);
        }
    }
}

/** Deep-copies the raster of {@code pix} into a new matrix of the given type, honouring the padded row length. */
private static Mat copyPixData(PIX pix, int height, int width, int type) {
    if (pix == null) {
        throw new IllegalStateException("Leptonica could not convert the image");
    }
    // Leptonica pads every row to a whole number of 32-bit words, so the row step is 4 * wpl bytes.
    Mat header = new Mat(height, width, type, pix.data(), 4L * pix.wpl());
    Mat copy = new Mat();
    header.copyTo(copy);
    header.deallocate();
    return copy;
}

/** Copies an 8-bit-per-channel {@code pix} into a new matrix, reversing the channel order on little-endian hosts. */
private static Mat copyByteChannels(PIX pix, int height, int width) {
    if (pix == null) {
        throw new IllegalStateException("Leptonica could not convert the image");
    }
    int channels = pix.d() / 8;
    Mat header = new Mat(height, width, CV_8UC(channels), pix.data(), 4L * pix.wpl());
    Mat copy = new Mat(height, width, CV_8UC(channels));
    // swap bytes if needed
    int[] swap = {0, channels - 1, 1, channels - 2, 2, channels - 3, 3, channels - 4},
          same = {0, 0, 1, 1, 2, 2, 3, 3},
          fromTo = channels > 1 && ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN) ? swap : same;
    mixChannels(header, 1, copy, 1, fromTo, Math.min(channels, fromTo.length / 2));
    header.deallocate();
    return copy;
}
/**
 * Loads the image stored under the given path.
 *
 * @param filename path of the image file
 * @return the image as returned by {@link #asMatrix(File)}
 * @throws IOException if the file cannot be read or decoded
 */
public INDArray asMatrix(String filename) throws IOException {
    File imageFile = new File(filename);
    return asMatrix(imageFile);
}

/**
 * Loads an image file in channels-first (NCHW) layout.
 *
 * @param f image file
 * @return the image array
 * @throws IOException if the file cannot be read or decoded
 */
@Override
public INDArray asMatrix(File f) throws IOException {
    final boolean channelsFirst = true;
    return asMatrix(f, channelsFirst);
}

/**
 * Loads an image file through a buffered stream that is closed before returning.
 *
 * @param f    image file
 * @param nchw true for NCHW output, false for NHWC
 * @return the image array
 * @throws IOException if the file cannot be read or decoded
 */
@Override
public INDArray asMatrix(File f, boolean nchw) throws IOException {
    try (InputStream buffered = new BufferedInputStream(new FileInputStream(f))) {
        return asMatrix(buffered, nchw);
    }
}

/**
 * Decodes an image from a stream in channels-first (NCHW) layout.
 *
 * @param is stream holding the encoded image
 * @return the image array
 * @throws IOException if the stream cannot be read or decoded
 */
@Override
public INDArray asMatrix(InputStream is) throws IOException {
    final boolean channelsFirst = true;
    return asMatrix(is, channelsFirst);
}
/**
 * Reads all bytes of the stream and decodes them into an array.
 * <p>
 * When a multipage mode is configured the bytes are handed to the multipage TIFF reader.
 * Otherwise OpenCV's {@code imdecode} is tried first and Leptonica's {@code pixReadMem} is used
 * as a fallback when OpenCV yields no image. Native objects created along the way are released
 * before returning, also when a later step throws.
 *
 * @param inputStream stream holding the encoded image; not closed by this method
 * @param nchw        true to return the array as produced (NCHW), false to permute it to NHWC
 * @return the decoded image
 * @throws IOException if the stream is empty or neither decoder can read the data
 */
@Override
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
    Mat mat = streamToMat(inputStream);
    INDArray a;
    try {
        if (this.multiPageMode != null) {
            a = asMatrix(mat.data(), mat.cols());
        } else {
            Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
            if (image == null || image.empty()) {
                if (image != null) {
                    // Drop the empty result of the failed OpenCV decode before replacing it.
                    image.deallocate();
                }
                PIX pix = pixReadMem(mat.data(), mat.cols());
                if (pix == null) {
                    throw new IOException("Could not decode image from input stream");
                }
                try {
                    image = convert(pix);
                } finally {
                    pixDestroy(pix);
                }
            }
            try {
                a = asMatrix(image);
            } finally {
                image.deallocate();
            }
        }
    } finally {
        // The encoded bytes are no longer needed once decoding has finished or failed.
        mat.deallocate();
    }
    if (nchw) {
        return a;
    }
    return a.permute(0, 2, 3, 1); //NCHW to NHWC
}
/**
 * Reads the whole stream into memory and wraps the bytes in a matrix.
 *
 * @param is Input stream to read
 * @return Mat with the buffer data as a row vector
 * @throws IOException if reading fails or the stream contains no data
 */
private Mat streamToMat(InputStream is) throws IOException {
    final byte[] encoded = IOUtils.toByteArray(is);
    if (encoded.length == 0) {
        throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
    }
    return new Mat(encoded);
}
/**
 * Loads the image stored under the given path together with its original dimensions.
 *
 * @param filename path of the image file
 * @return the image as returned by {@link #asImageMatrix(File)}
 * @throws IOException if the file cannot be read or decoded
 */
public Image asImageMatrix(String filename) throws IOException {
    File imageFile = new File(filename);
    return asImageMatrix(imageFile);
}

/**
 * Loads an image file in channels-first (NCHW) layout.
 *
 * @param f image file
 * @return the image and its dimensions
 * @throws IOException if the file cannot be read or decoded
 */
@Override
public Image asImageMatrix(File f) throws IOException {
    final boolean channelsFirst = true;
    return asImageMatrix(f, channelsFirst);
}

/**
 * Loads an image file through a buffered stream that is closed before returning.
 *
 * @param f    image file
 * @param nchw true for NCHW output, false for NHWC
 * @return the image and its dimensions
 * @throws IOException if the file cannot be read or decoded
 */
@Override
public Image asImageMatrix(File f, boolean nchw) throws IOException {
    try (InputStream buffered = new BufferedInputStream(new FileInputStream(f))) {
        return asImageMatrix(buffered, nchw);
    }
}

/**
 * Decodes an image from a stream in channels-first (NCHW) layout.
 *
 * @param is stream holding the encoded image
 * @return the image and its dimensions
 * @throws IOException if the stream cannot be read or decoded
 */
@Override
public Image asImageMatrix(InputStream is) throws IOException {
    final boolean channelsFirst = true;
    return asImageMatrix(is, channelsFirst);
}
/**
 * Decodes an image from a stream and returns it together with the channel count, row count
 * and column count of the decoded matrix.
 * <p>
 * OpenCV's {@code imdecode} is tried first; Leptonica's {@code pixReadMem} is the fallback when
 * OpenCV yields no image. Native objects created along the way are released before returning,
 * also when a later step throws.
 *
 * @param inputStream stream holding the encoded image; not closed by this method
 * @param nchw        true to keep the array as produced (NCHW), false to permute it to NHWC
 * @return the array plus the dimensions of the decoded matrix
 * @throws IOException if the stream is empty or neither decoder can read the data
 */
@Override
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
    Mat mat = streamToMat(inputStream);
    Mat image = null;
    try {
        image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
        if (image == null || image.empty()) {
            if (image != null) {
                // Drop the empty result of the failed OpenCV decode before replacing it.
                image.deallocate();
                image = null;
            }
            PIX pix = pixReadMem(mat.data(), mat.cols());
            if (pix == null) {
                throw new IOException("Could not decode image from input stream");
            }
            try {
                image = convert(pix);
            } finally {
                pixDestroy(pix);
            }
        }
        INDArray a = asMatrix(image);
        if (!nchw) {
            a = a.permute(0, 2, 3, 1); //NCHW to NHWC
        }
        // The dimensions are read before the matrix is released in the finally block.
        return new Image(a, image.channels(), image.rows(), image.cols());
    } finally {
        if (image != null) {
            image.deallocate();
        }
        mat.deallocate();
    }
}
/**
 * Calls {@link AndroidNativeImageLoader#asMatrix(android.graphics.Bitmap)} or
 * {@link Java2DNativeImageLoader#asMatrix(java.awt.image.BufferedImage)}.
 * The Android loader is tried first; a loader whose classes are missing at runtime is skipped.
 * @param image as a {@link android.graphics.Bitmap} or {@link java.awt.image.BufferedImage}
 * @return the matrix or null for unsupported object classes
 * @throws IOException
 */
public INDArray asMatrix(Object image) throws IOException {
    INDArray fromAndroid = null;
    try {
        fromAndroid = new AndroidNativeImageLoader(this).asMatrix(image);
    } catch (NoClassDefFoundError e) {
        // ignore
    }
    if (fromAndroid != null) {
        return fromAndroid;
    }
    try {
        return new Java2DNativeImageLoader(this).asMatrix(image);
    } catch (NoClassDefFoundError e) {
        // ignore
        return null;
    }
}
/**
 * Copies the pixels of an OpenCV matrix into an existing array, moving from the matrix's
 * (row, column, channel) indexing to (channel, row, column) indexing.
 * <p>
 * A typed fast path is used when the array is backed by a float or double pointer and the
 * matrix holds unsigned byte, unsigned short, int or float elements. A target of shape
 * {@code [1, channels, rows, cols]} is addressed through its last three strides; any other
 * target with at least three dimensions is addressed through its first three strides.
 * Everything else goes through the element-wise fallback, which accepts rank 2, 3 and 4 targets.
 *
 * @param image source matrix
 * @param ret   target array; its length must equal {@code rows * cols * channels}
 * @throws ND4JIllegalStateException if the lengths differ, or if the fallback meets a target
 *         whose rank is not 2, 3 or 4
 */
protected void fillNDArray(Mat image, INDArray ret) {
    long rows = image.rows();
    long cols = image.cols();
    long channels = image.channels();
    if (ret.length() != rows * cols * channels) {
        throw new ND4JIllegalStateException("INDArray provided to store image not equal to image: {channels: "
                        + channels + ", rows: " + rows + ", columns: " + cols + "}");
    }
    Indexer idx = image.createIndexer(direct);
    Pointer pointer = ret.data().pointer();
    long[] stride = ret.stride();
    boolean done = false;
    PagedPointer pagedPointer = new PagedPointer(pointer, rows * cols * channels,
                    ret.data().offset() * Nd4j.sizeOfDataType(ret.data().dataType()));

    // With a leading batch dimension of size one, the channel/row/column strides start at index 1.
    boolean batched = ret.rank() == 4 && ret.size(0) == 1 && ret.size(1) == channels
                    && ret.size(2) == rows && ret.size(3) == cols;
    int first = batched ? 1 : 0;
    // Targets with fewer than three strides (rank 1 or 2) cannot use the typed path.
    if (stride.length >= first + 3) {
        long[] sizes = {channels, rows, cols};
        long[] strides = {stride[first], stride[first + 1], stride[first + 2]};
        if (pointer instanceof FloatPointer) {
            FloatIndexer retidx = FloatIndexer.create((FloatPointer) pagedPointer.asFloatPointer(),
                            sizes, strides, direct);
            done = copyPixels(idx, retidx, channels, rows, cols);
            retidx.release();
        } else if (pointer instanceof DoublePointer) {
            DoubleIndexer retidx = DoubleIndexer.create((DoublePointer) pagedPointer.asDoublePointer(),
                            sizes, strides, direct);
            done = copyPixels(idx, retidx, channels, rows, cols);
            retidx.release();
        }
    }

    if (!done) {
        for (long k = 0; k < channels; k++) {
            for (long i = 0; i < rows; i++) {
                for (long j = 0; j < cols; j++) {
                    if (ret.rank() == 3) {
                        ret.putScalar(k, i, j, idx.getDouble(i, j, k));
                    } else if (ret.rank() == 4) {
                        // Batch index 0: the only valid one for a single-image target.
                        ret.putScalar(0, k, i, j, idx.getDouble(i, j, k));
                    } else if (ret.rank() == 2) {
                        ret.putScalar(i, j, idx.getDouble(i, j));
                    } else
                        throw new ND4JIllegalStateException("NativeImageLoader expects 2D, 3D or 4D output array, but " + ret.rank() + "D array was given");
                }
            }
        }
    }

    idx.release();
    image.data(); // dummy call so the matrix is still referenced after the copy
    Nd4j.getAffinityManager().tagLocation(ret, AffinityManager.Location.HOST);
}

/** Copies every element into a float target; returns false when the source element type has no typed path. */
private static boolean copyPixels(Indexer src, FloatIndexer dst, long channels, long rows, long cols) {
    if (src instanceof UByteIndexer) {
        UByteIndexer in = (UByteIndexer) src;
        for (long k = 0; k < channels; k++) {
            for (long i = 0; i < rows; i++) {
                for (long j = 0; j < cols; j++) {
                    dst.put(k, i, j, in.get(i, j, k));
                }
            }
        }
        return true;
    }
    if (src instanceof UShortIndexer) {
        UShortIndexer in = (UShortIndexer) src;
        for (long k = 0; k < channels; k++) {
            for (long i = 0; i < rows; i++) {
                for (long j = 0; j < cols; j++) {
                    dst.put(k, i, j, in.get(i, j, k));
                }
            }
        }
        return true;
    }
    if (src instanceof IntIndexer) {
        IntIndexer in = (IntIndexer) src;
        for (long k = 0; k < channels; k++) {
            for (long i = 0; i < rows; i++) {
                for (long j = 0; j < cols; j++) {
                    dst.put(k, i, j, in.get(i, j, k));
                }
            }
        }
        return true;
    }
    if (src instanceof FloatIndexer) {
        FloatIndexer in = (FloatIndexer) src;
        for (long k = 0; k < channels; k++) {
            for (long i = 0; i < rows; i++) {
                for (long j = 0; j < cols; j++) {
                    dst.put(k, i, j, in.get(i, j, k));
                }
            }
        }
        return true;
    }
    return false;
}

/** Copies every element into a double target; returns false when the source element type has no typed path. */
private static boolean copyPixels(Indexer src, DoubleIndexer dst, long channels, long rows, long cols) {
    if (src instanceof UByteIndexer) {
        UByteIndexer in = (UByteIndexer) src;
        for (long k = 0; k < channels; k++) {
            for (long i = 0; i < rows; i++) {
                for (long j = 0; j < cols; j++) {
                    dst.put(k, i, j, in.get(i, j, k));
                }
            }
        }
        return true;
    }
    if (src instanceof UShortIndexer) {
        UShortIndexer in = (UShortIndexer) src;
        for (long k = 0; k < channels; k++) {
            for (long i = 0; i < rows; i++) {
                for (long j = 0; j < cols; j++) {
                    dst.put(k, i, j, in.get(i, j, k));
                }
            }
        }
        return true;
    }
    if (src instanceof IntIndexer) {
        IntIndexer in = (IntIndexer) src;
        for (long k = 0; k < channels; k++) {
            for (long i = 0; i < rows; i++) {
                for (long j = 0; j < cols; j++) {
                    dst.put(k, i, j, in.get(i, j, k));
                }
            }
        }
        return true;
    }
    if (src instanceof FloatIndexer) {
        FloatIndexer in = (FloatIndexer) src;
        for (long k = 0; k < channels; k++) {
            for (long i = 0; i < rows; i++) {
                for (long j = 0; j < cols; j++) {
                    dst.put(k, i, j, in.get(i, j, k));
                }
            }
        }
        return true;
    }
    return false;
}
/**
 * Decodes an image from a stream and writes its pixels into an existing array.
 * <p>
 * OpenCV's {@code imdecode} is tried first; Leptonica's {@code pixReadMem} is the fallback when
 * OpenCV yields no image. Native objects created along the way are released before returning,
 * also when a later step throws.
 *
 * @param is   stream holding the encoded image; not closed by this method
 * @param view array that receives the pixels
 * @throws IOException if the stream is empty or neither decoder can read the data
 */
public void asMatrixView(InputStream is, INDArray view) throws IOException {
    Mat mat = streamToMat(is);
    Mat image = null;
    try {
        image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
        if (image == null || image.empty()) {
            if (image != null) {
                // Drop the empty result of the failed OpenCV decode before replacing it.
                image.deallocate();
                image = null;
            }
            PIX pix = pixReadMem(mat.data(), mat.cols());
            if (pix == null) {
                throw new IOException("Could not decode image from input stream");
            }
            try {
                image = convert(pix);
            } finally {
                pixDestroy(pix);
            }
        }
        asMatrixView(image, view);
    } finally {
        if (image != null) {
            image.deallocate();
        }
        mat.deallocate();
    }
}
/**
 * Loads the image stored under the given path into an existing array.
 *
 * @param filename path of the image file
 * @param view     array that receives the pixels
 * @throws IOException if the file cannot be read or decoded
 */
public void asMatrixView(String filename, INDArray view) throws IOException {
    File imageFile = new File(filename);
    asMatrixView(imageFile, view);
}

/**
 * Loads an image file into an existing array through a buffered stream that is closed afterwards.
 *
 * @param f    image file
 * @param view array that receives the pixels
 * @throws IOException if the file cannot be read or decoded
 */
public void asMatrixView(File f, INDArray view) throws IOException {
    try (InputStream buffered = new BufferedInputStream(new FileInputStream(f))) {
        asMatrixView(buffered, view);
    }
}

/**
 * Runs the loader's conversion pipeline on a matrix and stores the result in an existing array.
 *
 * @param image matrix to convert
 * @param view  array that receives the pixels
 * @throws IOException if the channel conversion is not possible
 */
public void asMatrixView(Mat image, INDArray view) throws IOException {
    INDArray target = view;
    transformImage(image, target);
}

/**
 * Same as {@link #asMatrixView(Mat, INDArray)} for a matrix of the official OpenCV Java API.
 *
 * @param image matrix to convert
 * @param view  array that receives the pixels
 * @throws IOException if the channel conversion is not possible
 */
public void asMatrixView(org.opencv.core.Mat image, INDArray view) throws IOException {
    INDArray target = view;
    transformImage(image, target);
}
/**
 * Converts a JavaCV frame to a matrix and then to an array.
 *
 * @param image frame to convert
 * @return the image array
 * @throws IOException if the channel conversion is not possible
 */
public INDArray asMatrix(Frame image) throws IOException {
    Mat asMat = converter.convert(image);
    return asMatrix(asMat);
}

/**
 * Converts a matrix of the official OpenCV Java API to an array with a leading dimension of size one.
 *
 * @param image matrix to convert
 * @return the converted image, reshaped to {@code [1, ...shape]}
 * @throws IOException if the channel conversion is not possible
 */
public INDArray asMatrix(org.opencv.core.Mat image) throws IOException {
    INDArray converted = transformImage(image, null);
    long[] withBatch = ArrayUtil.combine(new long[] {1}, converted.shape());
    return converted.reshape(withBatch);
}

/**
 * Converts an OpenCV matrix to an array with a leading dimension of size one.
 *
 * @param image matrix to convert
 * @return the converted image, reshaped to {@code [1, ...shape]}
 * @throws IOException if the channel conversion is not possible
 */
public INDArray asMatrix(Mat image) throws IOException {
    INDArray converted = transformImage(image, null);
    long[] withBatch = ArrayUtil.combine(new long[] {1}, converted.shape());
    return converted.reshape(withBatch);
}
/**
 * Converts a matrix of the official OpenCV Java API to a JavaCPP matrix (via a frame) and
 * delegates to {@link #transformImage(Mat, INDArray)}.
 *
 * @param image matrix to convert
 * @param ret   array to fill, or null to allocate one
 * @return the filled array
 * @throws IOException if the channel conversion is not possible
 */
protected INDArray transformImage(org.opencv.core.Mat image, INDArray ret) throws IOException {
    Frame asFrame = converter.convert(image);
    Mat nativeMat = converter.convert(asFrame);
    return transformImage(nativeMat, ret);
}

/**
 * Runs the full conversion pipeline on a matrix and copies the outcome into an array.
 * <p>
 * Steps, in order: the configured image transform (if any), colour conversion to the configured
 * channel count (when positive and different from the image's), the optional crop, rescaling to
 * the configured size, and finally the copy into {@code ret}. Intermediate matrices created by
 * the colour, crop and scale steps are released before returning.
 *
 * @param image matrix to convert
 * @param ret   array to fill, or null to allocate a {@code [channels, rows, cols]} array
 * @return {@code ret}, or the newly allocated array
 * @throws IOException if no colour conversion exists between the two channel counts
 */
protected INDArray transformImage(Mat image, INDArray ret) throws IOException {
    if (imageTransform != null && converter != null) {
        ImageWritable transformed = imageTransform.transform(new ImageWritable(converter.convert(image)));
        image = converter.convert(transformed.getFrame());
    }

    Mat recolored = null;
    Mat cropped = null;
    Mat resized = null;

    if (channels > 0 && image.channels() != channels) {
        final int from = image.channels();
        final int wanted = (int) channels;
        int code = -1;
        if (from == 1) {
            if (wanted == 3) {
                code = CV_GRAY2BGR;
            } else if (wanted == 4) {
                code = CV_GRAY2RGBA;
            }
        } else if (from == 3) {
            if (wanted == 1) {
                code = CV_BGR2GRAY;
            } else if (wanted == 4) {
                code = CV_BGR2RGBA;
            }
        } else if (from == 4) {
            if (wanted == 1) {
                code = CV_RGBA2GRAY;
            } else if (wanted == 3) {
                code = CV_RGBA2BGR;
            }
        }
        if (code < 0) {
            throw new IOException("Cannot convert from " + from + " to " + channels + " channels.");
        }
        recolored = new Mat();
        cvtColor(image, recolored, code);
        image = recolored;
    }

    if (centerCropIfNeeded) {
        Mat cropResult = centerCropIfNeeded(image);
        if (cropResult != image) {
            cropped = cropResult;
            image = cropResult;
        }
    }

    Mat scaleResult = scalingIfNeed(image);
    if (scaleResult != image) {
        resized = scaleResult;
        image = scaleResult;
    }

    if (ret == null) {
        ret = Nd4j.create(image.channels(), image.rows(), image.cols());
    }
    fillNDArray(image, ret);
    image.data(); // dummy call to make sure it does not get deallocated prematurely

    // Only matrices created inside this method are released; the caller's matrix is left alone.
    if (recolored != null) {
        recolored.deallocate();
    }
    if (cropped != null) {
        cropped.deallocate();
    }
    if (resized != null) {
        resized.deallocate();
    }
    return ret;
}
// TODO build flexibility on where to crop the image
/**
 * Returns a sub-region of a non-square image.
 * <p>
 * With {@code d = |cols - rows| / 2}: a wide image loses its first {@code d} columns, a tall
 * image loses its first {@code d} rows, and a square image is returned as a region covering
 * the whole image.
 * NOTE(review): the region is trimmed on one side only, so the result is in general neither
 * square nor centred - confirm whether callers rely on this.
 *
 * @param img image to crop
 * @return the matrix obtained by applying the computed rectangle to {@code img}
 */
protected Mat centerCropIfNeeded(Mat img) {
    final int rows = img.rows();
    final int cols = img.cols();
    final int trim = Math.abs(cols - rows) / 2;

    final Rect region;
    if (cols > rows) {
        region = new Rect(trim, 0, cols - trim, rows);
    } else if (rows > cols) {
        region = new Rect(0, trim, cols, rows - trim);
    } else {
        region = new Rect(0, 0, cols, rows);
    }
    return img.apply(region);
}
/**
 * Rescales an image to the loader's configured height and width when necessary.
 *
 * @param image image to rescale
 * @return {@code image} itself when no rescaling is needed, otherwise a new matrix
 */
protected Mat scalingIfNeed(Mat image) {
    final long targetHeight = height;
    final long targetWidth = width;
    return scalingIfNeed(image, targetHeight, targetWidth);
}

/**
 * Rescales an image to the given size when both dimensions are positive and differ from the
 * image's current size. Each dimension is capped at {@code Integer.MAX_VALUE}.
 *
 * @param image     image to rescale
 * @param dstHeight target number of rows
 * @param dstWidth  target number of columns
 * @return {@code image} itself when no rescaling is needed, otherwise a new matrix
 */
protected Mat scalingIfNeed(Mat image, long dstHeight, long dstWidth) {
    final boolean targetGiven = dstHeight > 0 && dstWidth > 0;
    if (!targetGiven || (image.rows() == dstHeight && image.cols() == dstWidth)) {
        return image;
    }
    Mat resized = new Mat();
    Size target = new Size(
                    (int) Math.min(dstWidth, Integer.MAX_VALUE),
                    (int) Math.min(dstHeight, Integer.MAX_VALUE));
    resize(image, resized, target);
    return resized;
}
/**
 * Loads the image stored under the given path as a writable.
 *
 * @param filename path of the image file
 * @return the decoded image wrapped in an {@link ImageWritable}
 * @throws IOException if the file cannot be read or decoded
 */
public ImageWritable asWritable(String filename) throws IOException {
    File imageFile = new File(filename);
    return asWritable(imageFile);
}
/**
 * Decodes an image file and wraps it, as a frame, in an {@link ImageWritable}.
 * <p>
 * OpenCV's {@code imdecode} is tried first; Leptonica's {@code pixReadMem} is the fallback when
 * OpenCV yields no image. No resizing, cropping or channel conversion is applied. The decoded
 * matrix is not released here because it is handed to the frame converter.
 *
 * @param f the image to convert
 * @return the decoded image as an ImageWritable
 * @throws IOException if the file is empty, unreadable, or neither decoder can read the data
 */
public ImageWritable asWritable(File f) throws IOException {
    try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
        Mat mat = streamToMat(bis);
        try {
            Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
            if (image == null || image.empty()) {
                if (image != null) {
                    // Drop the empty result of the failed OpenCV decode before replacing it.
                    image.deallocate();
                }
                PIX pix = pixReadMem(mat.data(), mat.cols());
                if (pix == null) {
                    throw new IOException("Could not decode image from input stream");
                }
                try {
                    image = convert(pix);
                } finally {
                    pixDestroy(pix);
                }
            }
            return new ImageWritable(converter.convert(image));
        } finally {
            // The encoded bytes are no longer needed once decoding has finished or failed.
            mat.deallocate();
        }
    }
}
/**
 * Converts the frame held by an {@link ImageWritable} into an array.
 *
 * @param writable ImageWritable to convert
 * @return the result of {@link #asMatrix(Mat)} on the writable's frame
 * @throws IOException if the channel conversion is not possible
 */
public INDArray asMatrix(ImageWritable writable) throws IOException {
    Frame frame = writable.getFrame();
    Mat asMat = converter.convert(frame);
    return asMatrix(asMat);
}
/**
 * Converts an array to a JavaCV frame, keeping the element type chosen by {@link #asMat(INDArray)}.
 *
 * @param array to convert
 * @return data copied to a Frame
 */
public Frame asFrame(INDArray array) {
    Mat copied = asMat(array);
    return converter.convert(copied);
}

/**
 * Converts an INDArray to a JavaCV Frame. Only intended for images with rank 3.
 *
 * @param array to convert
 * @param dataType from JavaCV (DEPTH_FLOAT, DEPTH_UBYTE, etc), or -1 to use same type as the INDArray
 * @return data copied to a Frame
 */
public Frame asFrame(INDArray array, int dataType) {
    int matDepth = OpenCVFrameConverter.getMatDepth(dataType);
    Mat copied = asMat(array, matDepth);
    return converter.convert(copied);
}
/** Returns {@code asMat(array, -1)}. */
public Mat asMat(INDArray array) {
    return asMat(array, -1);
}

/**
 * Converts an INDArray to an OpenCV Mat. Only intended for images with rank 3
 * (channels, rows, columns) or rank 4 with a leading dimension of size one.
 * <p>
 * Float arrays copied to {@code CV_32F} and double arrays copied to {@code CV_64F} use typed
 * indexers; every other combination is copied element by element through {@code double}.
 *
 * @param array to convert
 * @param dataType from OpenCV (CV_32F, CV_8U, etc), or -1 to use same type as the INDArray
 * @return data copied to a Mat
 * @throws UnsupportedOperationException if the array is not rank 3, or rank 4 with size(0) == 1
 */
public Mat asMat(INDArray array, int dataType) {
    final int rank = array.rank();
    // Ranks below 3 are rejected here too; they used to fail later inside size() with an unrelated error.
    if (rank < 3 || rank > 4 || (rank == 4 && array.size(0) != 1)) {
        throw new UnsupportedOperationException("Only rank 3 (or rank 4 with size(0) == 1) arrays supported");
    }
    // Index of the channel dimension: 0 for rank 3, 1 for rank 4 (after the batch dimension).
    final int first = rank == 3 ? 0 : 1;
    long[] stride = array.stride();
    long offset = array.data().offset();
    Pointer pointer = array.data().pointer().position(offset);

    long channels = array.size(first);
    long rows = array.size(first + 1);
    long cols = array.size(first + 2);
    boolean done = false;

    if (dataType < 0) {
        dataType = pointer instanceof DoublePointer ? CV_64F : CV_32F;
    }
    Mat mat = new Mat((int)Math.min(rows, Integer.MAX_VALUE), (int)Math.min(cols, Integer.MAX_VALUE),
                    CV_MAKETYPE(dataType, (int)Math.min(channels, Integer.MAX_VALUE)));
    Indexer matidx = mat.createIndexer(direct);

    Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.HOST);

    long[] sizes = {channels, rows, cols};
    long[] strides = {stride[first], stride[first + 1], stride[first + 2]};
    if (pointer instanceof FloatPointer && dataType == CV_32F) {
        FloatIndexer ptridx = FloatIndexer.create((FloatPointer)pointer, sizes, strides, direct);
        FloatIndexer idx = (FloatIndexer)matidx;
        for (long k = 0; k < channels; k++) {
            for (long i = 0; i < rows; i++) {
                for (long j = 0; j < cols; j++) {
                    idx.put(i, j, k, ptridx.get(k, i, j));
                }
            }
        }
        done = true;
        ptridx.release();
    } else if (pointer instanceof DoublePointer && dataType == CV_64F) {
        DoubleIndexer ptridx = DoubleIndexer.create((DoublePointer)pointer, sizes, strides, direct);
        DoubleIndexer idx = (DoubleIndexer)matidx;
        for (long k = 0; k < channels; k++) {
            for (long i = 0; i < rows; i++) {
                for (long j = 0; j < cols; j++) {
                    idx.put(i, j, k, ptridx.get(k, i, j));
                }
            }
        }
        done = true;
        ptridx.release();
    }

    if (!done) {
        for (long k = 0; k < channels; k++) {
            for (long i = 0; i < rows; i++) {
                for (long j = 0; j < cols; j++) {
                    if (rank == 3) {
                        matidx.putDouble(new long[] {i, j, k}, array.getDouble(k, i, j));
                    } else {
                        matidx.putDouble(new long[] {i, j, k}, array.getDouble(0, k, i, j));
                    }
                }
            }
        }
    }

    matidx.release();
    return mat;
}
/**
 * Reads a multipage TIFF from memory and loads its pages into one array.
 * <p>
 * In {@code MINIBATCH} mode every page is converted and stored at its own index of the first
 * dimension of a {@code [pages, 1, 1, h, w]} array; in {@code FIRST} mode only page 0 is
 * converted and stored in a {@code [1, 1, 1, h, w]} array. {@code h} and {@code w} are taken
 * from page 0. Each page that is converted is destroyed afterwards, also when its conversion fails.
 * NOTE(review): the PIXA container, and in FIRST mode the remaining pages, are not released here - confirm.
 *
 * @param bytes  encoded TIFF data
 * @param length number of bytes to read from {@code bytes}
 * @return INDArray holding the selected page(s)
 * @throws IOException if the data cannot be read as a multipage TIFF or contains no page
 * @throws UnsupportedOperationException if the configured multipage mode is not handled
 */
private INDArray asMatrix(BytePointer bytes, long length) throws IOException {
    PIXA pixa = pixaReadMemMultipageTiff(bytes, length);
    // An undecodable buffer yields no container (or no pages); report it like the other decode failures.
    if (pixa == null || pixa.n() <= 0) {
        throw new IOException("Could not decode image from input stream");
    }

    final boolean firstOnly;
    switch (this.multiPageMode) {
        case MINIBATCH:
            firstOnly = false;
            break;
//            case CHANNELS:
//                data = Nd4j.create(1, pixa.n(), 1, pixa.pix(0).h(), pixa.pix(0).w());
//                break;
        case FIRST:
            firstOnly = true;
            break;
        default: throw new UnsupportedOperationException("Unsupported MultiPageMode: " + multiPageMode);
    }

    final int pagesToRead = firstOnly ? 1 : pixa.n();
    INDArray data = Nd4j.create(pagesToRead, 1, 1, pixa.pix(0).h(), pixa.pix(0).w());

    for (int i = 0; i < pagesToRead; i++) {
        PIX pix = pixa.pix(i);
        Mat page = null;
        INDArray currentD;
        try {
            page = convert(pix);
            currentD = asMatrix(page);
        } finally {
            if (page != null) {
                page.deallocate();
            }
            pixDestroy(pix);
        }

        INDArrayIndex[] index;
        if (firstOnly) {
            index = new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.point(0),
                            NDArrayIndex.all(), NDArrayIndex.all()};
        } else {
            index = new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all(),
                            NDArrayIndex.all(), NDArrayIndex.all()};
        }
        data.put(index, currentD.get(NDArrayIndex.all(), NDArrayIndex.all(),
                        NDArrayIndex.all(), NDArrayIndex.all()));
    }
    return data;
}
}