All downloads are free. The search and download functionality uses the official Maven repository.

org.deeplearning4j.ui.VertxUIServer Maven / Gradle / Ivy

/* ******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.ui;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.Vertx;
import io.vertx.core.http.HttpServer;
import io.vertx.core.http.impl.MimeMapping;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.handler.BodyHandler;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.api.storage.StatsStorageListener;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.i18n.I18NProvider;
import org.deeplearning4j.ui.module.SameDiffModule;
import org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule;
import org.deeplearning4j.ui.module.defaultModule.DefaultModule;
import org.deeplearning4j.ui.module.remote.RemoteReceiverModule;
import org.deeplearning4j.ui.module.train.TrainModule;
import org.deeplearning4j.ui.module.tsne.TsneModule;
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.storage.impl.QueueStatsStorageListener;
import org.deeplearning4j.util.DL4JFileUtils;
import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.primitives.Pair;

import java.io.File;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;

@Slf4j
public class VertxUIServer extends AbstractVerticle implements UIServer {
    /** Port used by {@link #start()} when no explicit port was requested. */
    public static final int DEFAULT_UI_PORT = 9000;
    /** Directory prefix prepended to non-webjar paths requested under {@code /assets/}. */
    public static final String ASSETS_ROOT_DIRECTORY = "deeplearning4jUiAssets/";

    /** Most recently constructed server; assigned by the constructor. */
    @Getter
    private static VertxUIServer instance;

    /** Multi-session flag; written by {@code getInstance(Integer, boolean, Function)} before a server is deployed. */
    @Getter
    private static AtomicBoolean multiSession = new AtomicBoolean(false);
    /** Function that returns the StatsStorage containing a given session ID; may be null. */
    @Getter
    @Setter
    private static Function<String, StatsStorage> statsStorageProvider;

    /** Port requested through {@code getInstance(Integer, boolean, Function)}; null means "use the default". */
    private static Integer instancePort;

    /** Training module; created in {@link #start()}, so it is null before the verticle has started. */
    private TrainModule trainModule;

    /**
     * Convenience overload: delegates to {@code getInstance(Integer, boolean, Function)} with no explicit
     * port, the currently configured multi-session flag, and no StatsStorage provider.
     *
     * @return the UI server instance
     */
    public static VertxUIServer getInstance() {
        final boolean currentMode = multiSession.get();
        return getInstance(null, currentMode, null);
    }

    /**
     * Returns the UI server, deploying a new one when none exists or the existing one has been stopped.
     *
     * @param port                 port to listen on, or null to use the default
     * @param multiSession         whether the caller wants a multi-session server
     * @param statsStorageProvider function that returns a StatsStorage containing a given session ID, or null
     * @return the UI server instance (may be null if the verticle was not constructed within the wait period)
     * @throws RuntimeException if a server is already running in the other session mode
     */
    public static VertxUIServer getInstance(Integer port, boolean multiSession, Function<String, StatsStorage> statsStorageProvider){
        if (instance == null || instance.isStopped()) {
            VertxUIServer.multiSession.set(multiSession);
            VertxUIServer.setStatsStorageProvider(statsStorageProvider);
            instancePort = port;
            Vertx vertx = Vertx.vertx();

            //Launch UI server verticle and wait for it to start
            CountDownLatch l = new CountDownLatch(1);
            vertx.deployVerticle(VertxUIServer.class.getName(), res -> l.countDown());
            try {
                if (!l.await(5000, TimeUnit.MILLISECONDS)) {
                    log.warn("Deeplearning4j UI server did not finish starting within 5000 ms");
                }
            } catch (InterruptedException e){
                //Do not swallow the interrupt: restore the flag so the caller can react to it
                Thread.currentThread().interrupt();
            }
        } else {
            //A server is already running: compare the REQUESTED mode (the parameter) with the running mode.
            //Comparing the static flag with isMultiSession() would compare a value with itself.
            if (multiSession && !instance.isMultiSession()) {
                throw new RuntimeException("Cannot return multi-session instance." +
                        " UIServer has already started in single-session mode at " + instance.getAddress() +
                        " You may stop the UI server instance, and start a new one.");
            } else if (!multiSession && instance.isMultiSession()) {
                throw new RuntimeException("Cannot return single-session instance." +
                        " UIServer has already started in multi-session mode at " + instance.getAddress() +
                        " You may stop the UI server instance, and start a new one.");
            }
        }

        return instance;
    }


    private List uiModules = new CopyOnWriteArrayList<>();
    private RemoteReceiverModule remoteReceiverModule;
    private StatsStorageLoader statsStorageLoader;

    //typeIDModuleMap: Records which modules are registered for which type IDs
    private Map> typeIDModuleMap = new ConcurrentHashMap<>();

    private HttpServer server;
    private AtomicBoolean shutdown = new AtomicBoolean(false);
    private long uiProcessingDelay = 500; //500ms. TODO make configurable


    private final BlockingQueue eventQueue = new LinkedBlockingQueue<>();
    private List> listeners = new CopyOnWriteArrayList<>();
    private List statsStorageInstances = new CopyOnWriteArrayList<>();

    private Thread uiEventRoutingThread;

    /**
     * Creates the verticle and publishes it as the static singleton, so that the static
     * {@code getInstance} methods can return the object that was deployed by class name.
     */
    public VertxUIServer() {
        VertxUIServer.instance = this;
    }

    /**
     * Stops the singleton server, if one has been created. Does nothing otherwise.
     */
    public static void stopInstance(){
        if (instance != null) {
            instance.stop();
        }
    }

    /**
     * Auto-attach StatsStorage if an unknown session ID is passed as URL path parameter in multi-session mode.
     * A null provider is ignored. The loader is always stored; it is pushed into the train module
     * immediately only if that module already exists.
     *
     * @param statsStorageProvider function that returns a StatsStorage containing the given session ID
     */
    public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
        if (statsStorageProvider == null) {
            return;
        }
        this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
        //trainModule is only created in start(); before that, the stored loader is handed to the
        //TrainModule constructor there, so dereferencing it here would just throw a NullPointerException
        if (this.trainModule != null) {
            this.trainModule.setSessionLoader(this.statsStorageLoader);
        }
    }

    /**
     * Verticle start hook: builds the HTTP router (static assets, language switch, module routes),
     * registers the UI modules, starts the HTTP server and launches the event routing thread.
     *
     * @throws Exception if any part of the setup fails
     */
    @Override
    public void start() throws Exception {
        //Create REST endpoints
        File uploadDir = new File(System.getProperty("java.io.tmpdir"), "DL4JUI_" + System.currentTimeMillis());
        uploadDir.mkdirs();
        Router r = Router.router(vertx);
        r.route().handler(BodyHandler.create()  //NOTE: Setting this is required to receive request body content at all
                .setUploadsDirectory(uploadDir.getAbsolutePath()));
        r.get("/assets/*").handler(rc -> {
            String path = rc.request().path();
            if (path.length() < 8) {
                //Path is shorter than the "/assets/" prefix: nothing to serve, and substring(8) would throw
                rc.response().setStatusCode(404).end();
                return;
            }
            path = path.substring(8);   //Remove "/assets/", which is 8 characters
            String mime;
            String newPath;
            if (path.contains("webjars")) {
                newPath = "META-INF/resources/" + path.substring(path.indexOf("webjars"));
            } else {
                newPath = ASSETS_ROOT_DIRECTORY + (path.startsWith("/") ? path.substring(1) : path);
            }
            mime = MimeMapping.getMimeTypeForFilename(FilenameUtils.getName(newPath));

            rc.response()
                    .putHeader("content-type", mime)
                    .sendFile(newPath);
        });


        if (isMultiSession()) {
            r.get("/setlang/:sessionId/:to").handler(
                    rc -> {
                        //Parameter name must match the ":sessionId" placeholder declared in the route above
                        String sid = rc.request().getParam("sessionId");
                        String to = rc.request().getParam("to");
                        I18NProvider.getInstance(sid).setDefaultLanguage(to);
                        rc.response().end();
                    });
        } else {
            r.get("/setlang/:to").handler(rc -> {
                String to = rc.request().getParam("to");
                I18NProvider.getInstance().setDefaultLanguage(to);
                rc.response().end();
            });
        }

        //The provider passed to getInstance(...) is only stored in a static field; turn it into a loader here
        //so that the train module below actually receives it
        if (statsStorageLoader == null && VertxUIServer.statsStorageProvider != null) {
            statsStorageLoader = new StatsStorageLoader(VertxUIServer.statsStorageProvider);
        }

        uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
        trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
        uiModules.add(trainModule);
        uiModules.add(new ConvolutionalListenerModule());
        uiModules.add(new TsneModule());
        uiModules.add(new SameDiffModule());
        remoteReceiverModule = new RemoteReceiverModule();
        uiModules.add(remoteReceiverModule);

        //Check service loader mechanism (Arbiter UI, etc) for modules
        modulesViaServiceLoader(uiModules);

        for (UIModule m : uiModules) {
            List<Route> routes = m.getRoutes();
            for (Route route : routes) {
                switch (route.getHttpMethod()) {
                    case GET:
                        r.get(route.getRoute()).handler(rc -> route.getConsumer().accept(extractArgsFromRoute(route.getRoute(), rc), rc));
                        break;
                    case PUT:
                        r.put(route.getRoute()).handler(rc -> route.getConsumer().accept(extractArgsFromRoute(route.getRoute(), rc), rc));
                        break;
                    case POST:
                        r.post(route.getRoute()).handler(rc -> route.getConsumer().accept(extractArgsFromRoute(route.getRoute(), rc), rc));
                        break;
                    default:
                        throw new IllegalStateException("Unknown or not supported HTTP method: " + route.getHttpMethod());
                }
            }

            //Determine which type IDs this module wants to receive:
            List<String> typeIDs = m.getCallbackTypeIDs();
            for (String typeID : typeIDs) {
                typeIDModuleMap
                        .computeIfAbsent(typeID, k -> Collections.synchronizedList(new ArrayList<>()))
                        .add(m);
            }
        }

        //Check port property. The key is a JVM system property name, so it is read with getProperty, not getenv
        int port = instancePort == null ? DEFAULT_UI_PORT : instancePort;
        String portProp = System.getProperty(DL4JSystemProperties.UI_SERVER_PORT_PROPERTY);
        if(portProp != null && !portProp.isEmpty()){
            try{
                port = Integer.parseInt(portProp);
            } catch (NumberFormatException e){
                log.warn("Error parsing port property {}={}", DL4JSystemProperties.UI_SERVER_PORT_PROPERTY, portProp);
            }
        }


        server = vertx.createHttpServer()
                .requestHandler(r)
                .listen(port);

        uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
        uiEventRoutingThread.setDaemon(true);
        uiEventRoutingThread.start();

        String address = UIServer.getInstance().getAddress();
        log.info("Deeplearning4j UI server started at: {}", address);
    }

    /**
     * Collects the values of the {@code :name} placeholders of a route, in the order they appear in the route.
     *
     * @param path route definition, e.g. a path containing segments that start with ':'
     * @param rc   routing context of the current request, used to look up each parameter value
     * @return one entry per placeholder (the value returned by {@code getParam}, possibly null);
     *         an empty list when the route contains no ':'
     */
    private List<String> extractArgsFromRoute(String path, RoutingContext rc) {
        if (!path.contains(":")) {
            return Collections.emptyList();
        }
        List<String> out = new ArrayList<>();
        for (String segment : path.split("/")) {
            if (segment.startsWith(":")) {
                //Strip the leading ':' to get the parameter name
                out.add(rc.request().getParam(segment.substring(1)));
            }
        }
        return out;
    }

    /**
     * Adds every {@link UIModule} found through {@link ServiceLoader} to the given list, skipping any
     * module whose exact class is already present in the list.
     *
     * @param uiModules list to extend in place
     */
    private void modulesViaServiceLoader(List<UIModule> uiModules) {
        for (UIModule m : ServiceLoader.load(UIModule.class)) {
            Class<?> c = m.getClass();
            boolean foundExisting = false;
            for (UIModule mExisting : uiModules) {
                if (mExisting.getClass() == c) {
                    foundExisting = true;
                    break;
                }
            }

            if (!foundExisting) {
                log.debug("Loaded UI module via service loader: {}", m.getClass());
                uiModules.add(m);
            }
        }
    }

    /**
     * Marks the server as stopped and closes the HTTP server if one was created.
     */
    @Override
    public void stop() {
        //Flag first, so isStopped() is accurate even if closing the server fails
        shutdown.set(true);
        //server is only assigned at the end of start(); it is still null if start() never completed
        if (server != null) {
            server.close();
        }
    }


    /**
     * @return true once {@code stop()} has set the shutdown flag
     */
    @Override
    public boolean isStopped() {
        final boolean stopped = shutdown.get();
        return stopped;
    }

    /**
     * @return the current value of the static multi-session flag
     */
    @Override
    public boolean isMultiSession() {
        final boolean multi = multiSession.get();
        return multi;
    }

    /**
     * @return the localhost URL of the UI, built from the port reported by the HTTP server
     */
    @Override
    public String getAddress() {
        final int boundPort = server.actualPort();
        return "http://localhost:" + boundPort;
    }

    /**
     * @return the port reported by the HTTP server
     */
    @Override
    public int getPort() {
        final int boundPort = server.actualPort();
        return boundPort;
    }

    /**
     * Attaches a StatsStorage to the UI: registers a queue-backed listener on it, records it as attached
     * and notifies every UI module. Attaching an already attached instance is a no-op.
     *
     * @param statsStorage storage to attach
     * @throws IllegalArgumentException if {@code statsStorage} is null
     */
    @Override
    public void attach(StatsStorage statsStorage) {
        if (statsStorage == null) {
            throw new IllegalArgumentException("StatsStorage cannot be null");
        }
        if (statsStorageInstances.contains(statsStorage)) {
            return; //Already attached
        }

        //The listener forwards storage events into the queue drained by the routing thread
        StatsStorageListener queueListener = new QueueStatsStorageListener(eventQueue);
        listeners.add(new Pair<>(statsStorage, queueListener));
        statsStorage.registerStatsStorageListener(queueListener);
        statsStorageInstances.add(statsStorage);

        for (UIModule module : uiModules) {
            module.onAttach(statsStorage);
        }

        log.info("StatsStorage instance attached to UI: {}", statsStorage);
    }

    /**
     * Detaches a StatsStorage from the UI: deregisters the listener(s) registered for it, removes it from
     * the attached instances, notifies every UI module and drops the I18N instances of its sessions.
     * Detaching an instance that is not attached is a no-op.
     *
     * @param statsStorage storage to detach
     * @throws IllegalArgumentException if {@code statsStorage} is null
     */
    @Override
    public void detach(StatsStorage statsStorage) {
        if (statsStorage == null)
            throw new IllegalArgumentException("StatsStorage cannot be null");
        if (!statsStorageInstances.contains(statsStorage))
            return; //No op
        boolean found = false;
        //listeners is a CopyOnWriteArrayList, so removing from it while iterating is safe
        for (Pair<StatsStorage, StatsStorageListener> p : listeners) {
            if (p.getFirst() == statsStorage) { //Same object, not equality
                statsStorage.deregisterStatsStorageListener(p.getSecond());
                listeners.remove(p);
                found = true;
            }
        }
        statsStorageInstances.remove(statsStorage);
        for (UIModule uiModule : uiModules) {
            uiModule.onDetach(statsStorage);
        }
        for (String sessionId : statsStorage.listSessionIDs()) {
            I18NProvider.removeInstance(sessionId);
        }
        if (found) {
            log.info("StatsStorage instance detached from UI: {}", statsStorage);
        }
    }

    /**
     * @param statsStorage storage to look up
     * @return true if the given storage is in the list of attached instances
     */
    @Override
    public boolean isAttached(StatsStorage statsStorage) {
        final boolean attached = statsStorageInstances.contains(statsStorage);
        return attached;
    }

    /**
     * @return a new list holding the currently attached StatsStorage instances; changes to the returned
     *         list do not affect the server
     */
    @Override
    public List<StatsStorage> getStatsStorageInstances() {
        return new ArrayList<>(statsStorageInstances);
    }

    /**
     * Enables the remote listener backed by a fresh in-memory StatsStorage that is also attached to the UI.
     * Does nothing if the remote listener is already enabled.
     */
    @Override
    public void enableRemoteListener() {
        if (remoteReceiverModule == null) {
            remoteReceiverModule = new RemoteReceiverModule();
        }
        if (!remoteReceiverModule.isEnabled()) {
            enableRemoteListener(new InMemoryStatsStorage(), true);
        }
    }

    /**
     * Enables the remote listener and directs received data to the given router.
     *
     * @param statsStorage destination for remotely received stats
     * @param attach       if true and the router is also a {@link StatsStorage}, attach it to the UI
     */
    @Override
    public void enableRemoteListener(StatsStorageRouter statsStorage, boolean attach) {
        remoteReceiverModule.setEnabled(true);
        remoteReceiverModule.setStatsStorage(statsStorage);
        if (!attach || !(statsStorage instanceof StatsStorage)) {
            return; //Attachment not requested, or the router cannot be attached
        }
        attach((StatsStorage) statsStorage);
    }

    /**
     * Switches the remote receiver module off.
     */
    @Override
    public void disableRemoteListener() {
        final boolean enabled = false;
        remoteReceiverModule.setEnabled(enabled);
    }

    /**
     * @return whether the remote receiver module reports itself as enabled
     */
    @Override
    public boolean isRemoteListenerEnabled() {
        final boolean enabled = remoteReceiverModule.isEnabled();
        return enabled;
    }


    /**
     * Background task that drains the event queue in batches and hands each UI module the events whose
     * type ID it subscribed to. Runs until the shutdown flag is set.
     */
    private class StatsEventRouterRunnable implements Runnable {

        @Override
        public void run() {
            try {
                runHelper();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                //An interrupt during shutdown is an expected way to end; anything else is worth reporting
                if (!shutdown.get()) {
                    log.error("Unexpected exception from Event routing runnable", e);
                }
            } catch (Exception e) {
                log.error("Unexpected exception from Event routing runnable", e);
            }
        }

        private void runHelper() throws Exception {
            log.trace("VertxUIServer.StatsEventRouterRunnable started");
            //Idea: collect all event stats, and route them to the appropriate modules
            while (!shutdown.get()) {

                //Timed poll instead of take(): an idle queue must not stop the loop from noticing shutdown
                StatsStorageEvent sse = eventQueue.poll(uiProcessingDelay, TimeUnit.MILLISECONDS);
                if (sse == null) {
                    continue; //Nothing arrived; re-check the shutdown flag
                }
                List<StatsStorageEvent> events = new ArrayList<>();
                events.add(sse);
                eventQueue.drainTo(events); //Non-blocking

                for (UIModule m : uiModules) {

                    List<String> callbackTypes = m.getCallbackTypeIDs();
                    List<StatsStorageEvent> out = new ArrayList<>();
                    for (StatsStorageEvent e : events) {
                        //Only forward events of a subscribed type that come from a still-attached storage
                        if (callbackTypes.contains(e.getTypeID())
                                && statsStorageInstances.contains(e.getStatsStorage())) {
                            out.add(e);
                        }
                    }

                    m.reportStorageEvents(out);
                }

                try {
                    Thread.sleep(uiProcessingDelay);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    if (!shutdown.get()) {
                        throw new RuntimeException("Unexpected interrupted exception", e);
                    }
                }
            }
        }
    }

    /**
     * Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
     */
    private class StatsStorageLoader implements Function<String, Boolean> {

        /** Supplies the StatsStorage expected to contain a requested session ID. */
        Function<String, StatsStorage> statsStorageProvider;

        StatsStorageLoader(Function<String, StatsStorage> statsStorageProvider) {
            this.statsStorageProvider = statsStorageProvider;
        }

        /**
         * Asks the provider for a StatsStorage and attaches it when it contains the session.
         *
         * @param sessionId session ID to load
         * @return true if a storage containing the session was attached; false if the provider returned
         *         null or the returned storage does not contain the session
         */
        @Override
        public Boolean apply(String sessionId) {
            log.info("Loading StatsStorage via StatsStorageProvider for session ID ({}).", sessionId);
            StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
            if (statsStorage == null) {
                log.info("Failed to load StatsStorage via StatsStorageProvider for session ID ({}). " +
                        "StatsStorageProvider returned null.", sessionId);
                return false;
            }
            if (statsStorage.sessionExists(sessionId)) {
                attach(statsStorage);
                return true;
            }
            log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
                    "Session ID ({}) does not exist in StatsStorage.", sessionId);
            return false;
        }
    }

    //==================================================================================================================
    // CLI Launcher

    /**
     * Command line options for the launcher, populated by JCommander from the program arguments.
     * Every option takes exactly one value (arity = 1), including the boolean ones.
     */
    @Data
    private static class CLIParams {
        // -r / --enableRemote: when true, main enables the remote listener with a file-backed stats storage
        @Parameter(names = {"-r", "--enableRemote"}, description = "Whether to enable remote or not", arity = 1)
        private boolean cliEnableRemote;

        // -p / --uiPort: HTTP port requested for the UI; defaults to DEFAULT_UI_PORT
        @Parameter(names = {"-p", "--uiPort"}, description = "Custom HTTP port for UI", arity = 1)
        private int cliPort = DEFAULT_UI_PORT;

        // -f / --customStatsFile: parsed here but not read anywhere else in this class
        @Parameter(names = {"-f", "--customStatsFile"}, description = "Path to create custom stats file (remote only)", arity = 1)
        private String cliCustomStatsFile;

        // -m / --multiSession: passed to the server as its multi-session flag
        @Parameter(names = {"-m", "--multiSession"}, description = "Whether to enable multiple separate browser sessions or not", arity = 1)
        private boolean cliMultiSession;
    }

    /**
     * Command line entry point: parses the options described by {@code CLIParams}, starts the UI server and,
     * if requested, enables the remote listener backed by a temporary stats file.
     * Exits the JVM with status 1 if the remote listener cannot be set up.
     *
     * @param args command line arguments
     */
    public static void main(String[] args){
        CLIParams d = new CLIParams();
        new JCommander(d).parse(args);
        //Pass the port straight to getInstance: that method assigns instancePort from its own argument,
        //so a value written to the static field beforehand would be overwritten
        VertxUIServer uiServer = getInstance(d.getCliPort(), d.isCliMultiSession(), null);
        if(d.isCliEnableRemote()){
            try {
                File tempStatsFile = DL4JFileUtils.createTempFile("dl4j", "UIstats");
                //Only the unique path is needed; presumably FileStatsStorage creates the file itself - TODO confirm
                tempStatsFile.delete();
                tempStatsFile.deleteOnExit();
                uiServer.enableRemoteListener(new FileStatsStorage(tempStatsFile), true);
            } catch(Exception e) {
                log.error("Failed to create temporary file for stats storage",e);
                System.exit(1);
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy