All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.remote.DL4jServlet Maven / Gradle / Ivy

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.remote;

import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelInference;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.remote.clients.serde.BinaryDeserializer;
import org.nd4j.remote.clients.serde.BinarySerializer;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.remote.serving.SameDiffServlet;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;



/**
 *
 * @author astoyakin
 */
@Slf4j
@NoArgsConstructor
public class DL4jServlet extends SameDiffServlet {

    protected ParallelInference parallelInference;
    protected Model model;
    protected boolean parallelEnabled = true;

    public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter inferenceAdapter,
                       JsonSerializer serializer, JsonDeserializer deserializer) {
        super(inferenceAdapter, serializer, deserializer);
        this.parallelInference = parallelInference;
        this.model = null;
        this.parallelEnabled = true;
    }

    public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter inferenceAdapter,
                       JsonSerializer serializer, JsonDeserializer deserializer) {
        super(inferenceAdapter, serializer, deserializer);
        this.model = model;
        this.parallelInference = null;
        this.parallelEnabled = false;
    }

    public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter inferenceAdapter,
                       BinarySerializer serializer, BinaryDeserializer deserializer) {
        super(inferenceAdapter, serializer, deserializer);
        this.parallelInference = parallelInference;
        this.model = null;
        this.parallelEnabled = true;
    }

    public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter inferenceAdapter,
                       JsonSerializer jsonSerializer, JsonDeserializer jsonDeserializer,
                       BinarySerializer binarySerializer, BinaryDeserializer binaryDeserializer) {
        super(inferenceAdapter, jsonSerializer, jsonDeserializer, binarySerializer, binaryDeserializer);
        this.model = model;
        this.parallelInference = null;
        this.parallelEnabled = false;
    }

    public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter inferenceAdapter,
                       JsonSerializer jsonSerializer, JsonDeserializer jsonDeserializer,
                       BinarySerializer binarySerializer, BinaryDeserializer binaryDeserializer) {
        super(inferenceAdapter, jsonSerializer, jsonDeserializer, binarySerializer, binaryDeserializer);
        this.parallelInference = parallelInference;
        this.model = null;
        this.parallelEnabled = true;
    }

    private O process(MultiDataSet mds) {
        O result = null;
        if (parallelEnabled) {
            // process result
            result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays()));
        } else {
            synchronized (this) {
                if (model instanceof ComputationGraph)
                    result = inferenceAdapter.apply(((ComputationGraph) model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays()));
                else if (model instanceof MultiLayerNetwork) {
                    Preconditions.checkArgument(mds.getFeatures().length > 0 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 0),
                            "Input data for MultilayerNetwork is invalid!");
                    result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false,
                            mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null));
                }
            }
        }
        return result;
    }

    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
        String processorReturned = "";
        MultiDataSet mds = null;
        String path = request.getPathInfo();
        if (path.equals(SERVING_ENDPOINT)) {
            val contentType = request.getContentType();
            if (contentType.equals(typeJson)) {
                if (validateRequest(request, response)) {
                    val stream = request.getInputStream();
                    val bufferedReader = new BufferedReader(new InputStreamReader(stream));
                    char[] charBuffer = new char[128];
                    int bytesRead = -1;
                    val buffer = new StringBuilder();
                    while ((bytesRead = bufferedReader.read(charBuffer)) > 0) {
                        buffer.append(charBuffer, 0, bytesRead);
                    }
                    val requestString = buffer.toString();

                    mds = inferenceAdapter.apply(deserializer.deserialize(requestString));
                }
            }
            else if (contentType.equals(typeBinary)) {
                val stream = request.getInputStream();
                int available = request.getContentLength();
                if (available <= 0) {
                    response.sendError(411, "Content length is unavailable");
                }
                else {
                    byte[] data = new byte[available];
                    stream.read(data, 0, available);

                    mds = inferenceAdapter.apply(binaryDeserializer.deserialize(data));
                }
            }
            if (mds == null)
                log.error("InferenceAdapter failed");
            else {
                val result = process(mds);
                if (binarySerializer != null) {
                    byte[] serialized = binarySerializer.serialize(result);
                    response.setContentType(typeBinary);
                    response.setContentLength(serialized.length);
                    val out = response.getOutputStream();
                    out.write(serialized);
                }
                else {
                    processorReturned = serializer.serialize(result);
                    try {
                        val out = response.getWriter();
                        out.write(processorReturned);
                    } catch (IOException e) {
                        log.error(e.getMessage());
                    }
                }
            }
        } else {
            // we return error otherwise
            sendError(request.getRequestURI(), response);
        }
    }

    /**
     * Creates servlet to serve models
     *
     * @param  type of Input class
     * @param  type of Output class
     *
     * @author [email protected]
     * @author astoyakin
     */
    public static class Builder {

        private ParallelInference pi;
        private Model model;

        private InferenceAdapter inferenceAdapter;
        private JsonSerializer serializer;
        private JsonDeserializer deserializer;
        private BinarySerializer binarySerializer;
        private BinaryDeserializer binaryDeserializer;
        private int port;
        private boolean parallelEnabled = true;

        public Builder(@NonNull ParallelInference pi) {
            this.pi = pi;
        }

        public Builder(@NonNull Model model) {
            this.model = model;
        }

        public Builder inferenceAdapter(@NonNull InferenceAdapter inferenceAdapter) {
            this.inferenceAdapter = inferenceAdapter;
            return this;
        }

        /**
         * This method is required to specify serializer
         *
         * @param serializer
         * @return
         */
        public Builder serializer(JsonSerializer serializer) {
            this.serializer = serializer;
            return this;
        }

        /**
         * This method allows to specify deserializer
         *
         * @param deserializer
         * @return
         */
        public Builder deserializer(JsonDeserializer deserializer) {
            this.deserializer = deserializer;
            return this;
        }

        /**
         * This method is required to specify serializer
         *
         * @param serializer
         * @return
         */
        public Builder binarySerializer(BinarySerializer serializer) {
            this.binarySerializer = serializer;
            return this;
        }

        /**
         * This method allows to specify deserializer
         *
         * @param deserializer
         * @return
         */
        public Builder binaryDeserializer(BinaryDeserializer deserializer) {
            this.binaryDeserializer = deserializer;
            return this;
        }

        /**
         * This method allows to specify port
         *
         * @param port
         * @return
         */
        public Builder port(int port) {
            this.port = port;
            return this;
        }

        /**
         * This method activates parallel inference
         *
         * @param parallelEnabled
         * @return
         */
        public Builder parallelEnabled(boolean parallelEnabled) {
            this.parallelEnabled = parallelEnabled;
            return this;
        }

        public DL4jServlet build() {
            return parallelEnabled ? new DL4jServlet(pi, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer) :
                    new DL4jServlet(model, inferenceAdapter,  serializer, deserializer, binarySerializer, binaryDeserializer);
        }
    }
}