hex.genmodel.algos.deepwater.caffe.DeepwaterCaffeModel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of h2o-genmodel-ext-deepwater Show documentation
Show all versions of h2o-genmodel-ext-deepwater Show documentation
H2O GenModel Deepwater Extension
package hex.genmodel.algos.deepwater.caffe;
import com.google.protobuf.nano.CodedInputByteBufferNano;
import com.google.protobuf.nano.CodedOutputByteBufferNano;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import deepwater.backends.BackendModel;
import hex.genmodel.algos.deepwater.caffe.nano.Deepwater;
import hex.genmodel.algos.deepwater.caffe.nano.Deepwater.Cmd;
public class DeepwaterCaffeModel implements BackendModel {
private int[] _input_shape = new int[0];
private int[] _sizes = new int[0]; // neurons per layer
private String[] _types = new String[0]; // layer types
private double[] _dropout_ratios = new double[0];
private long _seed;
private boolean _useGPU;
private String _graph = "";
private Process _process;
private static final ThreadLocal _buffer = new ThreadLocal<>();
public DeepwaterCaffeModel(int batch_size, int[] sizes,
String[] types, double[] dropout_ratios,
long seed, boolean useGPU) {
_input_shape = new int[] {batch_size, 1, 1, sizes[0]};
_sizes = sizes;
_types = types;
_dropout_ratios = dropout_ratios;
_seed = seed;
_useGPU = useGPU;
start();
}
public DeepwaterCaffeModel(String graph, int[] input_shape, long seed, boolean useGPU) {
_graph = graph;
_input_shape = input_shape;
_seed = seed;
_useGPU = useGPU;
start();
}
private void start() {
if (_process == null) {
try {
startRegular();
} catch (IOException e) {
throw new RuntimeException(e);
}
Cmd cmd = new Cmd();
cmd.type = Deepwater.Create;
cmd.graph = _graph;
cmd.inputShape = _input_shape;
cmd.solverType = "Adam";
cmd.sizes = _sizes;
cmd.types = _types;
cmd.dropoutRatios = _dropout_ratios;
cmd.learningRate = .01f;
cmd.momentum = .99f;
cmd.randomSeed = _seed;
cmd.useGpu = _useGPU;
call(cmd);
}
}
public void saveModel(String model_path) {
Cmd cmd = new Cmd();
cmd.type = Deepwater.SaveGraph;
cmd.path = model_path;
call(cmd);
}
public void saveParam(String param_path) {
Cmd cmd = new Cmd();
cmd.type = Deepwater.Save;
cmd.path = param_path;
call(cmd);
}
public void loadParam(String param_path) {
Cmd cmd = new Cmd();
cmd.type = Deepwater.Load;
cmd.path = param_path;
call(cmd);
}
private static void copy(float[] data, byte[] buff) {
if (data.length * 4 != buff.length)
throw new RuntimeException();
ByteBuffer buffer = _buffer.get();
if (buffer == null || buffer.capacity() < buff.length) {
_buffer.set(buffer = ByteBuffer.allocateDirect(buff.length));
buffer.order(ByteOrder.LITTLE_ENDIAN);
}
buffer.clear();
buffer.asFloatBuffer().put(data);
buffer.get(buff);
}
private static void copy(float[][] buffs, Cmd cmd) {
cmd.data = new byte[buffs.length][];
for (int i = 0; i < buffs.length; i++) {
cmd.data[i] = new byte[buffs[i].length * 4];
copy(buffs[i], cmd.data[i]);
}
}
public void train(float[] data, float[] label) {
Cmd cmd = new Cmd();
cmd.type = Deepwater.Train;
cmd.inputShape = _input_shape;
int len = _input_shape[0] * _input_shape[1] * _input_shape[2] * _input_shape[3];
if (data.length != len)
throw new RuntimeException();
if (label.length != _input_shape[0])
throw new RuntimeException();
float[][] buffs = new float[][] {data, label};
copy(buffs, cmd);
call(cmd);
}
public float[] predict(float[] data) {
Cmd cmd = new Cmd();
cmd.type = Deepwater.Predict;
cmd.inputShape = _input_shape;
// int len = _input_shape[0] * _input_shape[1] * _input_shape[2] * _input_shape[3];
// if (data.length != len)
// throw new RuntimeException(data.length + " vs " + len);
float[][] buffs = new float[][] {data};
copy(buffs, cmd);
cmd = call(cmd);
ByteBuffer buffer = _buffer.get();
if (buffer == null || buffer.capacity() < cmd.data[0].length) {
_buffer.set(buffer = ByteBuffer.allocateDirect(cmd.data[0].length));
buffer.order(ByteOrder.LITTLE_ENDIAN);
}
buffer.clear();
buffer.put(cmd.data[0]);
float[] res = new float[cmd.data[0].length / 4];
buffer.flip();
buffer.asFloatBuffer().get(res);
return res;
}
// Debug, or if wee find a way to package Caffe without Docker
private void startRegular() throws IOException {
String pwd = DeepwaterCaffeBackend.CAFFE_H2O_DIR;
ProcessBuilder pb = new ProcessBuilder("python3 backend.py".split(" "));
pb.environment().put("PYTHONPATH", DeepwaterCaffeBackend.CAFFE_DIR + "python");
pb.redirectError(ProcessBuilder.Redirect.INHERIT);
pb.directory(new File(pwd));
_process = pb.start();
}
void close() {
_process.destroy();
try {
_process.waitFor();
} catch (InterruptedException ex) {
// Ignore
}
}
private Cmd call(Cmd cmd) {
try {
OutputStream stdin = _process.getOutputStream();
int len = cmd.getSerializedSize();
ByteBuffer buffer = ByteBuffer.allocate(4 + len);
buffer.putInt(len);
CodedOutputByteBufferNano ou = CodedOutputByteBufferNano.newInstance(
buffer.array(), buffer.position(), buffer.remaining());
cmd.writeTo(ou);
buffer.position(buffer.position() + len);
stdin.write(buffer.array(), 0, buffer.position());
stdin.flush();
InputStream stdout = _process.getInputStream();
int read = stdout.read(buffer.array(), 0, 4);
if (read != 4)
throw new RuntimeException();
buffer.position(0);
buffer.limit(read);
len = buffer.getInt();
if (buffer.capacity() < len)
buffer = ByteBuffer.allocate(len);
buffer.position(0);
buffer.limit(len);
while (buffer.position() < buffer.limit()) {
read = stdout.read(buffer.array(), buffer.position(), buffer.limit());
buffer.position(buffer.position() + read);
}
Cmd res = new Cmd();
CodedInputByteBufferNano in = CodedInputByteBufferNano.newInstance(
buffer.array(), 0, buffer.position());
res.mergeFrom(in);
return res;
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy