All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.streaming.routes.DL4jServeRouteBuilder Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta_spark_1
Show newest version
package org.deeplearning4j.streaming.routes;

import kafka.serializer.StringEncoder;
import lombok.AllArgsConstructor;
import lombok.Builder;
import org.apache.camel.Exchange;
import org.apache.camel.Processor;
import org.apache.camel.builder.RouteBuilder;
import org.apache.commons.net.util.Base64;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;

/**
 * Camel route that serves model predictions for records consumed from a Kafka topic.
 *
 * <p>Each message body must be either an {@link INDArray} or a {@code byte[]} holding the
 * Base64 text of an array serialized in ND4J's binary format (decoded with {@code Nd4j.read}).
 * The array is fed to the model stored at {@link #modelUri} (a {@link ComputationGraph} when
 * {@link #computationGraph} is {@code true}, otherwise a {@link MultiLayerNetwork}); the
 * resulting {@code INDArray[]} / {@code INDArray} replaces the message body, then
 * {@link #finalProcessor} runs and the exchange is sent to {@link #outputUri}.
 *
 * <p>The only constructor is the Lombok all-args one (also used by the generated builder), so
 * the field initializers below are always overwritten; {@link #configure()} therefore re-applies
 * the defaults for any option that was left unset.
 *
 * @author Adam Gibson
 */
@AllArgsConstructor
@Builder
public class DL4jServeRouteBuilder extends RouteBuilder {
    /** Kafka consumer group used when {@link #groupId} is null. */
    private static final String DEFAULT_GROUP_ID = "dl4j-serving";
    /** ZooKeeper host used when {@link #zooKeeperHost} is null. */
    private static final String DEFAULT_ZOOKEEPER_HOST = "localhost";
    /** ZooKeeper port used when {@link #zooKeeperPort} is unset (zero or negative). */
    private static final int DEFAULT_ZOOKEEPER_PORT = 2181;
    /** Stateless processor that leaves the exchange untouched; stands in for null hooks. */
    private static final Processor NO_OP = new Processor() {
        @Override
        public void process(Exchange exchange) throws Exception {
            // intentionally empty
        }
    };

    /** Location of the serialized model, passed to {@link ModelSerializer}. */
    protected String modelUri;
    /** Kafka broker address, inserted into the consumer endpoint URI. */
    protected String kafkaBroker;
    /** Kafka topic the inputs are consumed from. */
    protected String consumingTopic;
    /** {@code true} to restore a {@link ComputationGraph}, {@code false} for a {@link MultiLayerNetwork}. */
    protected boolean computationGraph;
    /** Camel endpoint URI that receives the predictions. */
    protected String outputUri;
    /** Runs on the prediction before it is sent to {@link #outputUri}; a no-op when null. */
    protected Processor finalProcessor;
    /** Kafka consumer group id; {@code "dl4j-serving"} when null. */
    protected String groupId = DEFAULT_GROUP_ID;
    /** ZooKeeper host for the Kafka consumer; {@code "localhost"} when null. */
    protected String zooKeeperHost = DEFAULT_ZOOKEEPER_HOST;
    /** ZooKeeper port for the Kafka consumer; 2181 when zero or negative. */
    protected int zooKeeperPort = DEFAULT_ZOOKEEPER_PORT;
    /** Runs on the raw message before inference; a no-op when null. */
    protected Processor beforeProcessor;

    /**
     * Builds the serving route: Kafka consumer, then {@link #beforeProcessor}, then model
     * inference, then {@link #finalProcessor}, then {@link #outputUri}.
     *
     * <p>Options left unset are replaced by their defaults (written back to the fields) before
     * the route is built.
     *
     * @throws Exception can be thrown during configuration
     */
    @Override
    public void configure() throws Exception {
        if (groupId == null) {
            groupId = DEFAULT_GROUP_ID;
        }
        if (zooKeeperHost == null) {
            zooKeeperHost = DEFAULT_ZOOKEEPER_HOST;
        }
        // An int the builder was not given arrives as 0, which is never a usable port.
        if (zooKeeperPort <= 0) {
            zooKeeperPort = DEFAULT_ZOOKEEPER_PORT;
        }
        if (beforeProcessor == null) {
            beforeProcessor = NO_OP;
        }
        // Camel cannot build a process() step around a null processor, so default it too.
        if (finalProcessor == null) {
            finalProcessor = NO_OP;
        }

        String kafkaUri = String.format(
                "kafka:%s?topic=%s&groupId=%s&zookeeperHost=%s&zookeeperPort=%d&serializerClass=%s&keySerializerClass=%s",
                kafkaBroker, consumingTopic, groupId, zooKeeperHost, zooKeeperPort,
                StringEncoder.class.getName(), StringEncoder.class.getName());

        from(kafkaUri)
                .process(beforeProcessor)
                .process(new ModelInferenceProcessor(modelUri, computationGraph))
                .process(finalProcessor)
                .to(outputUri);
    }

    /**
     * Converts the incoming message body to an {@link INDArray}, runs it through the model and
     * replaces the body with the prediction.
     *
     * <p>The model is restored from {@code modelUri} on first use and then reused, rather than
     * being deserialized from storage for every message. Inference is synchronized so that
     * concurrent exchanges never use the shared model instance at the same time.
     */
    private static final class ModelInferenceProcessor implements Processor {
        private final String modelUri;
        private final boolean computationGraph;
        // Lazily restored model; only the one matching computationGraph is ever set. Guarded by this.
        private ComputationGraph graph;
        private MultiLayerNetwork network;

        ModelInferenceProcessor(String modelUri, boolean computationGraph) {
            this.modelUri = modelUri;
            this.computationGraph = computationGraph;
        }

        @Override
        public void process(Exchange exchange) throws Exception {
            INDArray input = toNdArray(exchange.getIn().getBody());
            Object output = predict(input);
            // The prediction is placed on both the out and the in message.
            exchange.getOut().setBody(output);
            exchange.getIn().setBody(output);
        }

        /**
         * Returns the body as an array: a {@code byte[]} is treated as Base64 text of an
         * ND4J-serialized array and decoded; any other body is cast to {@link INDArray}.
         */
        private static INDArray toNdArray(Object body) throws Exception {
            if (!(body instanceof byte[])) {
                return (INDArray) body;
            }
            // Decode straight from the bytes; going through new String(byte[]) would make the
            // result depend on the platform default charset.
            byte[] serialized = Base64.decodeBase64((byte[]) body);
            try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(serialized))) {
                return Nd4j.read(in);
            }
        }

        /**
         * Runs inference, restoring the model on the first call.
         *
         * @return {@code INDArray[]} for a computation graph, {@code INDArray} for a multi-layer network
         */
        private synchronized Object predict(INDArray input) throws Exception {
            if (computationGraph) {
                if (graph == null) {
                    graph = ModelSerializer.restoreComputationGraph(modelUri);
                }
                return graph.output(input);
            }
            if (network == null) {
                network = ModelSerializer.restoreMultiLayerNetwork(modelUri);
            }
            return network.output(input);
        }
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy