hex.deepwater.DeepWaterImageIterator Maven / Gradle / Ivy
package hex.deepwater;
import hex.genmodel.GenModel;
import water.*;
import water.util.Log;
import water.util.SBPrintStream;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.regex.Pattern;
/**
 * Iterates over a list of images (local file paths or URLs) in mini-batches, decoding each image
 * into {@code width * height * channels} floats of the current batch buffer.
 *
 * <p>Batch buffers are indexed by {@code which()} and swapped with {@code flip()} (both inherited
 * from {@link DeepWaterIterator}); {@link #Next(Futures)} fills the buffer selected before the flip,
 * and {@link #getFiles()} reads the buffer selected by {@code which() ^ 1} after it.
 */
class DeepWaterImageIterator extends DeepWaterIterator {

  /**
   * @param images     image locations, one per observation (file paths or URLs)
   * @param labels     one label per image, or {@code null} for unlabeled data (labels become NaN)
   * @param meanData   mean image forwarded to {@code GenModel.img2pixels}
   * @param batch_size number of images per mini-batch
   * @param width      target image width in pixels
   * @param height     target image height in pixels
   * @param channels   number of color channels per pixel
   * @param cache      whether decoded images are cached in / read from the DKV
   */
  DeepWaterImageIterator(ArrayList<String> images, ArrayList<Float> labels, float[] meanData, int batch_size, int width, int height, int channels, boolean cache) throws IOException {
    super(batch_size, width * height * channels, cache);
    _img_lst = images;
    _label_lst = labels;
    _meanData = meanData;
    _num_obs = images.size();
    _width = width;
    _height = height;
    _channels = channels;
    // One file-name array per batch buffer, mirroring the two data/label buffers.
    _file = new String[][] { new String[batch_size], new String[batch_size] };
  }

  /** Width, height and channel count of an image. */
  public static class Dimensions extends Iced<Dimensions> implements Comparable<Dimensions> {
    int _width;
    int _height;
    int _channels;

    /** @return number of floats per image, {@code width * height * channels} */
    public int len() { return _width * _height * _channels; }

    /**
     * Orders by total size, then width, height and channels. Returns 0 exactly when all three
     * dimensions are equal, and is antisymmetric for shapes of equal size but different layout.
     */
    @Override
    public int compareTo(Dimensions o) {
      int c = Integer.compare(len(), o.len());
      if (c != 0) return c;
      c = Integer.compare(_width, o._width);
      if (c != 0) return c;
      c = Integer.compare(_height, o._height);
      return c != 0 ? c : Integer.compare(_channels, o._channels);
    }
  }

  // Helper for image conversion
  // TODO: add cropping, distortion, rotation, etc.
  private static class Conversion {
    Conversion() { _dim = new Dimensions(); }
    Dimensions _dim;
    public int len() { return _dim.len(); }
  }

  /** DKV-storable decoded image together with the dimensions it was decoded at. */
  static class IcedImage extends Keyed<IcedImage> {
    public IcedImage() {}
    IcedImage(Dimensions dim, float[] data) { _dim = dim; _data = data; }
    Dimensions _dim;
    float[] _data;
  }

  /** Task that decodes one image into its slot of a mini-batch buffer and records its label. */
  static class ImageConverter extends H2O.H2OCountedCompleter {
    /** Recognizes http(s)/ftp/file URLs; compiled once rather than once per image. */
    private static final Pattern URL_PATTERN =
        Pattern.compile("^(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]");

    String _file;
    float _label;
    Conversion _conv;
    float[] _destData;
    float[] _meanData;
    float[] _destLabel;
    int _index;
    boolean _cache;

    ImageConverter(int index, String file, float label, Conversion conv, float[] destData, float[] meanData, float[] destLabel, boolean cache) {
      _index = index;
      _file = file;
      _label = label;
      _conv = conv;
      _destData = destData;
      _meanData = meanData;
      _destLabel = destLabel;
      _cache = cache;
    }

    /**
     * Writes the label into slot {@code _index} of {@code _destLabel}. Then fills the image slot
     * starting at {@code _index * len} of {@code _destData}, from the DKV cache when enabled and
     * the cached dimensions match, otherwise by decoding the file/URL.
     */
    @Override
    public void compute2() {
      _destLabel[_index] = _label;
      try {
        final int start = _index * _conv.len();
        final Key<IcedImage> imgKey = Key.make(_file + DeepWaterModel.CACHE_MARKER);
        if (!(_cache && loadFromCache(imgKey, start))) {
          decodeInto(imgKey, start);
        }
      } catch (Throwable e) {
        // Best effort: one unreadable image must not fail the whole batch. This also covers the
        // NPEs ImageIO's ICC_Profile handling is known to throw on some images.
        // NOTE(review): on failure this slot of _destData is left untouched and may still hold an
        // image from an earlier batch — consider zero-filling it.
        Log.warn("Failed to convert image " + _file + ": " + e);
      }
      tryComplete();
    }

    /** Copies a cached image into the batch; returns false on a miss or a dimension mismatch. */
    private boolean loadFromCache(Key<IcedImage> imgKey, int start) {
      final IcedImage cached = DKV.getGet(imgKey);
      if (cached == null || cached._dim.compareTo(_conv._dim) != 0) return false;
      System.arraycopy(cached._data, 0, _destData, start, cached._data.length);
      return true;
    }

    /** Reads the image from disk or a URL, converts it into the batch and optionally caches it. */
    private void decodeInto(Key<IcedImage> imgKey, int start) throws IOException {
      final String path = _file.trim();
      // A string is read as a URL only when it looks like one and is not an existing local file.
      final boolean isUrl = URL_PATTERN.matcher(path).matches() && !new File(path).exists();
      final BufferedImage img = isUrl ? ImageIO.read(new URL(path)) : ImageIO.read(new File(path));
      if (img == null) {
        // ImageIO.read returns null when no registered reader can decode the input.
        Log.warn("No image reader could decode " + _file);
        return;
      }
      final Dimensions dim = _conv._dim;
      GenModel.img2pixels(img, dim._width, dim._height, dim._channels, _destData, start, _meanData);
      if (_cache) {
        Value v = new Value(imgKey, new IcedImage(dim, Arrays.copyOfRange(_destData, start, start + _conv.len())));
        DKV.put(imgKey, v);
        v.freeMem();
      }
    }
  }

  /**
   * Decodes the next mini-batch in parallel into the current buffer, then flips buffers.
   *
   * <p>The last batch is shifted back so it is always full (overlapping the previous one). When
   * there are fewer images than {@code batch_size}, rows wrap around so the batch is still filled
   * instead of indexing below zero.
   *
   * @param fs futures collection used to wait for all conversion tasks
   * @return {@code true} if a batch was produced, {@code false} when all rows have been consumed
   */
  public boolean Next(Futures fs) throws IOException {
    if (_start_index >= _num_obs) return false;  // also handles an empty image list
    if (_start_index + _batch_size > _num_obs)
      _start_index = Math.max(0, _num_obs - _batch_size);
    // Multi-threaded data preparation
    Conversion conv = new Conversion();
    conv._dim._height = this._height;
    conv._dim._width = this._width;
    conv._dim._channels = this._channels;
    final int slot = which();
    for (int i = 0; i < _batch_size; i++) {
      final int row = (_start_index + i) % _num_obs;
      final String path = _img_lst.get(row);
      final float label = _label_lst == null ? Float.NaN : _label_lst.get(row);
      _file[slot][i] = path;  // lets getFiles() report which image fills each batch position
      fs.add(H2O.submitTask(new ImageConverter(i, path, label, conv, _data[slot], _meanData, _label[slot], _cache)));
    }
    fs.blockForPending();
    flip();
    _start_index += _batch_size;
    return true;
  }

  /** @return image locations of the batch most recently produced by {@link #Next(Futures)} (internal array; do not modify) */
  public String[] getFiles() { return _file[which() ^ 1]; }

  final private int _num_obs;
  private int _start_index;
  final private int _width, _height, _channels;
  final private float[] _meanData; // mean image
  final private String[][] _file;
  final private ArrayList<String> _img_lst;
  final private ArrayList<Float> _label_lst;
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy