All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.ui.UiServer Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.ui;

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.google.common.collect.ImmutableMap;
import io.dropwizard.Application;
import io.dropwizard.assets.AssetsBundle;
import io.dropwizard.jersey.jackson.JsonProcessingExceptionMapper;
import io.dropwizard.jetty.HttpConnectorFactory;
import io.dropwizard.lifecycle.ServerLifecycleListener;
import io.dropwizard.server.DefaultServerFactory;
import io.dropwizard.setup.Bootstrap;
import io.dropwizard.setup.Environment;
import io.dropwizard.views.ViewBundle;
import org.apache.commons.io.IOUtils;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.ui.activation.ActivationsDropwiz;
import org.deeplearning4j.ui.api.ApiResource;
import org.deeplearning4j.ui.defaults.DefaultDropwiz;
import org.deeplearning4j.ui.exception.GenericExceptionMapper;
import org.deeplearning4j.ui.flow.FlowDropwiz;
import org.deeplearning4j.ui.renders.RendersDropwiz;
import org.deeplearning4j.ui.rl.RlDropwiz;
import org.deeplearning4j.ui.tsne.TsneDropwiz;
import org.deeplearning4j.ui.weights.WeightDropwiz;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlets.CrossOriginFilter;
import org.glassfish.jersey.media.multipart.MultiPartFeature;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.jackson.VectorDeSerializer;
import org.nd4j.serde.jackson.VectorSerializer;

import javax.servlet.DispatcherType;
import javax.servlet.FilterRegistration;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.EnumSet;


/**
 * @author Adam Gibson
 */
public class UiServer extends Application {
    private static UiServer INSTANCE;
    private UIConfiguration conf;
    private Environment env;
    public UiServer() {
        INSTANCE = this;
    }
    private int port;
    public int getPort(){ return port; }

    public static synchronized UiServer getInstance() throws Exception {
        if(INSTANCE == null) createServer();
        return INSTANCE;
    }

    public Environment getEnv() {
        return env;
    }


    public UiConnectionInfo getConnectionInfo() {
        UiConnectionInfo info = new UiConnectionInfo.Builder()
                .setPort(getPort())
                .setAddress("localhost")
                .enableHttps(false)
                .build();

        return info;
    }

    @Override
    public void run(UIConfiguration uiConfiguration, Environment environment) throws Exception {
        this.conf = uiConfiguration;
        this.env = environment;

        //Workaround to dropwizard sometimes ignoring ports specified in dropwizard.yml
        int[] portsFromYml = getApplicationPortFromYml();
        if(portsFromYml[0] != -1 ) {
            ((HttpConnectorFactory) ((DefaultServerFactory) uiConfiguration.getServerFactory())
                    .getApplicationConnectors().get(0)).setPort(portsFromYml[0]);
        }
        if(portsFromYml[1] != -1 ) {
            ((HttpConnectorFactory) ((DefaultServerFactory) uiConfiguration.getServerFactory())
                    .getAdminConnectors().get(0)).setPort(portsFromYml[1]);
        }

        //Read the port that actually got assigned (needed for random ports i.e. port: 0 setting in yml)
        environment.lifecycle().addServerLifecycleListener(new ServerLifecycleListener() {
            @Override
            public void serverStarted(Server server) {
                for (Connector connector : server.getConnectors()) {
                    if (connector instanceof ServerConnector) {
                        ServerConnector serverConnector = (ServerConnector) connector;
                        if(!serverConnector.getName().toLowerCase().contains("application")) continue;
                        int port = serverConnector.getLocalPort();
                        try{
                            UiServer.getInstance().port = port;
                        }catch( Exception e ){
                            e.printStackTrace();
                        }
                    }
                }
            }
        });

        environment.jersey().register(MultiPartFeature.class);
        environment.jersey().register(new GenericExceptionMapper());
        environment.jersey().register(new JsonProcessingExceptionMapper());


        environment.jersey().register(new TsneDropwiz(conf.getUploadPath()));
//        environment.jersey().register(new NearestNeighborsDropwiz());
        environment.jersey().register(new DefaultDropwiz());
        environment.jersey().register(new WeightDropwiz());
        environment.jersey().register(new ActivationsDropwiz());
        environment.jersey().register(new RendersDropwiz());
        environment.jersey().register(new ApiResource());
        environment.jersey().register(new GenericExceptionMapper());
        environment.jersey().register(new FlowDropwiz());
        environment.jersey().register(new RlDropwiz());
        environment.jersey().register(new org.deeplearning4j.ui.nearestneighbors.word2vec.NearestNeighborsDropwiz(conf.getUploadPath()));

        environment.getObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
        environment.getObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);

        environment.getObjectMapper().registerModule(module());

        configureCors(environment);
    }

    @Override
    public void initialize(Bootstrap bootstrap) {
        //custom serializers for the json serde

        bootstrap.getObjectMapper().registerModule(module());


        bootstrap.addBundle(new ViewBundle() {
            @Override
            public ImmutableMap> getViewConfiguration(
                    UIConfiguration arg0) {
                return ImmutableMap.of();
            }
        });
        bootstrap.addBundle(new AssetsBundle());
    }


    private SimpleModule module() {
        SimpleModule module = new SimpleModule();
        module.addSerializer(INDArray.class, new VectorSerializer());
        module.addDeserializer(INDArray.class, new VectorDeSerializer());
        return module;
    }

    public static void main(String[] args) throws Exception {
        createServer();
    }

    public static void createServer() throws Exception {
        ClassPathResource resource = new ClassPathResource("dropwizard.yml");
        File tmpConfig = resource.getFile();
        INSTANCE = new UiServer();
        INSTANCE.run("server", tmpConfig.getAbsolutePath());
    }



    private void configureCors(Environment environment) {
        FilterRegistration.Dynamic filter = environment.servlets().addFilter("CORS", CrossOriginFilter.class);
        filter.addMappingForUrlPatterns(EnumSet.allOf(DispatcherType.class), true, "/*");
        filter.setInitParameter(CrossOriginFilter.ALLOWED_METHODS_PARAM, "GET,PUT,POST,DELETE,OPTIONS");
        filter.setInitParameter(CrossOriginFilter.ALLOWED_ORIGINS_PARAM, "*");
        filter.setInitParameter(CrossOriginFilter.ACCESS_CONTROL_ALLOW_ORIGIN_HEADER, "*");
        filter.setInitParameter("allowedHeaders", "Content-Type,Authorization,X-Requested-With,Content-Length,Accept,Origin");
        filter.setInitParameter("allowCredentials", "true");
    }


    //Parse dropwizard.yml (if present on classpath) and parse the port specifications
    private int[] getApplicationPortFromYml(){
        int[] toReturn = {-1,-1};
        ClassPathResource resource = new ClassPathResource("dropwizard.yml");
        InputStream in = null;
        try {
             in = resource.getInputStream();
            if (in == null) return toReturn;   //Not found
        } catch (FileNotFoundException e) {
            return toReturn;
        }
        String s;
        try {
            s = IOUtils.toString(in);
        } catch(IOException e ){
            return toReturn;
        }

        String[] split = s.split("\n");
        int count = 0;
        for( String str : split ){
            if(!str.contains("port")) continue;
            System.out.println(str);
            String[] line = str.split("\\s+");
            for( String token : line ){
                try{
                    toReturn[count] = Integer.parseInt(token);
                    count++;
                }catch(NumberFormatException e ){ }
            }
            if(count == 2 ) return toReturn;
        }

        return toReturn;  //No port configuration?
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy