org.deeplearning4j.streaming.routes.DL4jServeRouteBuilder Maven / Gradle / Ivy
package org.deeplearning4j.streaming.routes;
import kafka.serializer.StringEncoder;
import lombok.AllArgsConstructor;
import lombok.Builder;
import org.apache.camel.Exchange;
import org.apache.camel.Processor;
import org.apache.camel.builder.RouteBuilder;
import org.apache.commons.net.util.Base64;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
/**
 * Serve model predictions for records consumed from a Kafka topic.
 * <p>
 * The body of each consumed message may be either an {@link INDArray} itself, or a
 * Base64-encoded ndarray in the binary format read by
 * {@link Nd4j#read(java.io.DataInputStream)}, delivered as a {@code byte[]} or a
 * {@code String}. The decoded array is fed through the model restored from
 * {@link #modelUri}. The prediction becomes the message body: an {@code INDArray} for a
 * {@link MultiLayerNetwork}, or an {@code INDArray[]} for a {@link ComputationGraph}.
 * The message is then passed to {@link #finalProcessor} and sent to {@link #outputUri}.
 *
 * @author Adam Gibson
 */
@AllArgsConstructor
@Builder
public class DL4jServeRouteBuilder extends RouteBuilder {
    private static final String DEFAULT_GROUP_ID = "dl4j-serving";
    private static final String DEFAULT_ZOOKEEPER_HOST = "localhost";
    private static final int DEFAULT_ZOOKEEPER_PORT = 2181;

    /** Processor that leaves the exchange untouched; used when no hook is supplied. */
    private static final Processor NO_OP = new Processor() {
        @Override
        public void process(Exchange exchange) throws Exception {
            // intentionally empty
        }
    };

    /** Location of the serialized model, passed to {@link ModelSerializer}. */
    protected String modelUri;
    /** Broker part of the Kafka endpoint URI ({@code kafka:<kafkaBroker>?...}). */
    protected String kafkaBroker;
    /** Kafka topic to consume input arrays from. */
    protected String consumingTopic;
    /** {@code true} to load a {@link ComputationGraph}, {@code false} for a {@link MultiLayerNetwork}. */
    protected boolean computationGraph;
    /** Camel endpoint URI the prediction is sent to. */
    protected String outputUri;
    /** Runs after inference, before {@link #outputUri}; a no-op is used when {@code null}. */
    protected Processor finalProcessor;
    // NOTE: because of @AllArgsConstructor/@Builder these initializers are always
    // overwritten by the constructor; configure() re-applies the defaults instead.
    /** Kafka consumer group id; defaults to {@code "dl4j-serving"} when {@code null}. */
    protected String groupId = DEFAULT_GROUP_ID;
    /** ZooKeeper host; defaults to {@code "localhost"} when {@code null}. */
    protected String zooKeeperHost = DEFAULT_ZOOKEEPER_HOST;
    /** ZooKeeper port; defaults to {@code 2181} when not positive (e.g. unset via the builder). */
    protected int zooKeeperPort = DEFAULT_ZOOKEEPER_PORT;
    /** Runs before inference; a no-op is used when {@code null}. */
    protected Processor beforeProcessor;

    /**
     * Called on initialization to build the routes using the fluent builder syntax.
     * <p>
     * Builds the route {@code kafka -> beforeProcessor -> inference -> finalProcessor -> outputUri}.
     * Unset optional settings are replaced by their defaults (and written back to the fields).
     *
     * @throws Exception can be thrown during configuration
     */
    @Override
    public void configure() throws Exception {
        if (groupId == null) {
            groupId = DEFAULT_GROUP_ID;
        }
        if (zooKeeperHost == null) {
            zooKeeperHost = DEFAULT_ZOOKEEPER_HOST;
        }
        // The builder leaves primitive fields at 0, which is never a valid port.
        if (zooKeeperPort <= 0) {
            zooKeeperPort = DEFAULT_ZOOKEEPER_PORT;
        }
        if (beforeProcessor == null) {
            beforeProcessor = NO_OP;
        }
        if (finalProcessor == null) {
            finalProcessor = NO_OP;
        }

        String kafkaUri = String.format("kafka:%s?topic=%s&groupId=%s&zookeeperHost=%s&zookeeperPort=%d&serializerClass=%s&keySerializerClass=%s",
                kafkaBroker,
                consumingTopic,
                groupId,
                zooKeeperHost,
                zooKeeperPort,
                StringEncoder.class.getName(),
                StringEncoder.class.getName());

        from(kafkaUri).process(beforeProcessor)
                .process(new InferenceProcessor(modelUri, computationGraph))
                .process(finalProcessor)
                .to(outputUri);
    }

    /**
     * Convert an incoming message body into an {@link INDArray}.
     *
     * @param body an {@code INDArray}, or a Base64-encoded serialized ndarray as
     *             {@code byte[]} or {@code String}
     * @return the decoded array
     * @throws IllegalArgumentException if the body has an unsupported type
     * @throws IOException if the decoded bytes cannot be read as an ndarray
     */
    private static INDArray toNdArray(Object body) throws IOException {
        if (body instanceof INDArray) {
            return (INDArray) body;
        }
        byte[] decoded;
        if (body instanceof byte[]) {
            // Decode the raw bytes directly; no platform-charset String round trip.
            decoded = Base64.decodeBase64((byte[]) body);
        } else if (body instanceof String) {
            decoded = Base64.decodeBase64((String) body);
        } else {
            throw new IllegalArgumentException("Unsupported message body type: "
                    + (body == null ? "null" : body.getClass().getName())
                    + "; expected INDArray, byte[] or String");
        }
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(decoded))) {
            return Nd4j.read(in);
        }
    }

    /**
     * Runs each message body through the model, restoring the model once on first use and
     * reusing it afterwards (previously it was re-read from disk for every message).
     * {@code process} is synchronized so the single shared model instance is never used
     * by two threads at once.
     */
    private static final class InferenceProcessor implements Processor {
        private final String modelUri;
        private final boolean computationGraph;
        private ComputationGraph graph;
        private MultiLayerNetwork network;

        InferenceProcessor(String modelUri, boolean computationGraph) {
            this.modelUri = modelUri;
            this.computationGraph = computationGraph;
        }

        @Override
        public synchronized void process(Exchange exchange) throws Exception {
            INDArray input = toNdArray(exchange.getIn().getBody());
            Object output;
            if (computationGraph) {
                if (graph == null) {
                    graph = ModelSerializer.restoreComputationGraph(modelUri);
                }
                INDArray[] graphOutput = graph.output(input);
                output = graphOutput;
            } else {
                if (network == null) {
                    network = ModelSerializer.restoreMultiLayerNetwork(modelUri);
                }
                INDArray networkOutput = network.output(input);
                output = networkOutput;
            }
            // Set the result on both messages, as the original route did, so downstream
            // processors see it regardless of which message they read.
            exchange.getOut().setBody(output);
            exchange.getIn().setBody(output);
        }
    }
}