All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.ui.VertxUIServer Maven / Gradle / Ivy

The newest version!
    /*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.ui;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;
import io.vertx.core.http.HttpServer;
import io.vertx.core.http.impl.MimeMapping;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.handler.BodyHandler;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.config.DL4JClassLoading;
import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.common.util.ND4JFileUtils;
import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.core.storage.StatsStorageEvent;
import org.deeplearning4j.core.storage.StatsStorageListener;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.i18n.I18NProvider;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.model.storage.impl.QueueStatsStorageListener;
import org.deeplearning4j.ui.module.SameDiffModule;
import org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule;
import org.deeplearning4j.ui.module.defaultModule.DefaultModule;
import org.deeplearning4j.ui.module.remote.RemoteReceiverModule;
import org.deeplearning4j.ui.module.train.TrainModule;
import org.deeplearning4j.ui.module.tsne.TsneModule;
import org.nd4j.common.function.Function;
import org.nd4j.common.primitives.Pair;

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;

@Slf4j
public class VertxUIServer extends AbstractVerticle implements UIServer {
    public static final int DEFAULT_UI_PORT = 9000;
    public static final String ASSETS_ROOT_DIRECTORY = "deeplearning4jUiAssets/";

    @Getter
    private static VertxUIServer instance;

    @Getter
    private static AtomicBoolean multiSession = new AtomicBoolean(false);
    @Getter
    @Setter
    private static Function statsStorageProvider;

    private static Integer instancePort;
    @Getter
    private static Thread shutdownHook;

    /**
     * Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
     * @param port TCP socket port for {@link HttpServer} to listen
     * @param multiSession         in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
     *                             
URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId * @param statsStorageProvider function that returns a StatsStorage containing the given session ID. *
Use this to auto-attach StatsStorage if an unknown session ID is passed * as URL path parameter in multi-session mode, or leave it {@code null}. * @return UI instance for this JVM * @throws DL4JException if UI server failed to start; * if the instance has already started in a different mode (multi/single-session); * if interrupted while waiting for completion */ public static VertxUIServer getInstance(Integer port, boolean multiSession, Function statsStorageProvider) throws DL4JException { return getInstance(port, multiSession, statsStorageProvider, null); } /** * * Get (and, initialize if necessary) the UI server. This function will wait until the server started * (synchronous way), or pass the given callback to handle success or failure (asynchronous way). * @param port TCP socket port for {@link HttpServer} to listen * @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs. *
URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId * @param statsStorageProvider function that returns a StatsStorage containing the given session ID. *
Use this to auto-attach StatsStorage if an unknown session ID is passed * as URL path parameter in multi-session mode, or leave it {@code null}. * @param startCallback asynchronous deployment handler callback that will be notify of success or failure. * If {@code null} given, then this method will wait until deployment is complete. * If the deployment is successful the result will contain a String representing the * unique deployment ID of the deployment. * @return UI server instance * @throws DL4JException if UI server failed to start; * if the instance has already started in a different mode (multi/single-session); * if interrupted while waiting for completion */ public static VertxUIServer getInstance(Integer port, boolean multiSession, Function statsStorageProvider, Promise startCallback) throws DL4JException { if (instance == null || instance.isStopped()) { VertxUIServer.multiSession.set(multiSession); VertxUIServer.setStatsStorageProvider(statsStorageProvider); instancePort = port; if (startCallback != null) { //Launch UI server verticle and pass asynchronous callback that will be notified of completion deploy(startCallback); } else { //Launch UI server verticle and wait for it to start deploy(); } } else if (!instance.isStopped()) { if (multiSession && !instance.isMultiSession()) { throw new DL4JException("Cannot return multi-session instance." + " UIServer has already started in single-session mode at " + instance.getAddress() + " You may stop the UI server instance, and start a new one."); } else if (!multiSession && instance.isMultiSession()) { throw new DL4JException("Cannot return single-session instance." + " UIServer has already started in multi-session mode at " + instance.getAddress() + " You may stop the UI server instance, and start a new one."); } } return instance; } /** * Deploy (start) {@link VertxUIServer}, waiting until starting is complete. 
* @throws DL4JException if UI server failed to start; * if interrupted while waiting for completion */ private static void deploy() throws DL4JException { CountDownLatch l = new CountDownLatch(1); Promise promise = Promise.promise(); promise.future().compose( success -> Future.future(prom -> l.countDown()), failure -> Future.future(prom -> l.countDown()) ); deploy(promise); // synchronous function try { l.await(); } catch (InterruptedException e) { throw new DL4JException(e); } Future future = promise.future(); if (future.failed()) { throw new DL4JException("Deeplearning4j UI server failed to start.", future.cause()); } } /** * Deploy (start) {@link VertxUIServer}, * and pass callback to handle successful or failed completion of deployment. * @param startCallback promise that will handle success or failure of deployment. * If the deployment is successful the result will contain a String representing the unique deployment ID of the * deployment. */ private static void deploy(Promise startCallback) { log.debug("Deeplearning4j UI server is starting."); Promise promise = Promise.promise(); promise.future().compose( success -> Future.future(prom -> startCallback.complete(success)), failure -> Future.future(prom -> startCallback.fail(new RuntimeException(failure))) ); Vertx vertx = Vertx.vertx(); vertx.deployVerticle(VertxUIServer.class.getName(), promise); VertxUIServer.shutdownHook = new Thread(() -> { if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) { log.info("Deeplearning4j UI server is auto-stopping in shutdown hook."); try { instance.stop(); } catch (InterruptedException e) { log.error("Interrupted stopping of Deeplearning4j UI server in shutdown hook.", e); } } }); Runtime.getRuntime().addShutdownHook(shutdownHook); } private List uiModules = new CopyOnWriteArrayList<>(); private RemoteReceiverModule remoteReceiverModule; /** * Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID */ 
@Getter private Function statsStorageLoader; //typeIDModuleMap: Records which modules are registered for which type IDs private Map> typeIDModuleMap = new ConcurrentHashMap<>(); private HttpServer server; private AtomicBoolean shutdown = new AtomicBoolean(false); private long uiProcessingDelay = 500; //500ms. TODO make configurable private final BlockingQueue eventQueue = new LinkedBlockingQueue<>(); private List> listeners = new CopyOnWriteArrayList<>(); private List statsStorageInstances = new CopyOnWriteArrayList<>(); private Thread uiEventRoutingThread; public VertxUIServer() { instance = this; } public static void stopInstance() throws Exception { if(instance == null || instance.isStopped()) return; instance.stop(); VertxUIServer.reset(); } private static void reset() { VertxUIServer.instance = null; VertxUIServer.statsStorageProvider = null; VertxUIServer.instancePort = null; VertxUIServer.multiSession.set(false); } /** * Auto-attach StatsStorage if an unknown session ID is passed as URL path parameter in multi-session mode * @param statsStorageProvider function that returns a StatsStorage containing the given session ID */ public void autoAttachStatsStorageBySessionId(Function statsStorageProvider) { if (statsStorageProvider != null) { this.statsStorageLoader = (sessionId) -> { log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ")."); StatsStorage statsStorage = statsStorageProvider.apply(sessionId); if (statsStorage != null) { if (statsStorage.sessionExists(sessionId)) { attach(statsStorage); return true; } log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " + "Session ID (" + sessionId + ") does not exist in StatsStorage."); return false; } else { log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). 
" + "StatsStorageProvider returned null."); return false; } }; } } @Override public void start(Promise startCallback) throws Exception { //Create REST endpoints File uploadDir = new File(System.getProperty("java.io.tmpdir"), "DL4JUI_" + System.currentTimeMillis()); uploadDir.mkdirs(); Router r = Router.router(vertx); r.route().handler(BodyHandler.create() //NOTE: Setting this is required to receive request body content at all .setUploadsDirectory(uploadDir.getAbsolutePath())); r.get("/assets/*").handler(rc -> { String path = rc.request().path(); path = path.substring(8); //Remove "/assets/", which is 8 characters String mime; String newPath; if (path.contains("webjars")) { newPath = "META-INF/resources/" + path.substring(path.indexOf("webjars")); } else { newPath = ASSETS_ROOT_DIRECTORY + (path.startsWith("/") ? path.substring(1) : path); } mime = MimeMapping.getMimeTypeForFilename(FilenameUtils.getName(newPath)); //System.out.println("PATH: " + path + " - mime = " + mime); rc.response() .putHeader("content-type", mime) .sendFile(newPath); }); if (isMultiSession()) { r.get("/setlang/:sessionId/:to").handler( rc -> { String sid = rc.request().getParam("sessionID"); String to = rc.request().getParam("to"); I18NProvider.getInstance(sid).setDefaultLanguage(to); rc.response().end(); }); } else { r.get("/setlang/:to").handler(rc -> { String to = rc.request().getParam("to"); I18NProvider.getInstance().setDefaultLanguage(to); rc.response().end(); }); } if (VertxUIServer.statsStorageProvider != null) { autoAttachStatsStorageBySessionId(VertxUIServer.statsStorageProvider); } uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/" uiModules.add(new TrainModule()); uiModules.add(new ConvolutionalListenerModule()); uiModules.add(new TsneModule()); uiModules.add(new SameDiffModule()); remoteReceiverModule = new RemoteReceiverModule(); uiModules.add(remoteReceiverModule); //Check service loader mechanism (Arbiter UI, etc) for modules 
modulesViaServiceLoader(uiModules); for (UIModule m : uiModules) { List routes = m.getRoutes(); for (Route route : routes) { switch (route.getHttpMethod()) { case GET: r.get(route.getRoute()).handler(rc -> route.getConsumer().accept(extractArgsFromRoute(route.getRoute(), rc), rc)); break; case PUT: r.put(route.getRoute()).handler(rc -> route.getConsumer().accept(extractArgsFromRoute(route.getRoute(), rc), rc)); break; case POST: r.post(route.getRoute()).handler(rc -> route.getConsumer().accept(extractArgsFromRoute(route.getRoute(), rc), rc)); break; default: throw new IllegalStateException("Unknown or not supported HTTP method: " + route.getHttpMethod()); } } //Determine which type IDs this module wants to receive: List typeIDs = m.getCallbackTypeIDs(); for (String typeID : typeIDs) { List list = typeIDModuleMap.get(typeID); if (list == null) { list = Collections.synchronizedList(new ArrayList<>()); typeIDModuleMap.put(typeID, list); } list.add(m); } } //Check port property int port = instancePort == null ? DEFAULT_UI_PORT : instancePort; String portProp = System.getProperty(DL4JSystemProperties.UI_SERVER_PORT_PROPERTY); if(portProp != null && !portProp.isEmpty()){ try{ port = Integer.parseInt(portProp); } catch (NumberFormatException e){ log.warn("Error parsing port property {}={}", DL4JSystemProperties.UI_SERVER_PORT_PROPERTY, portProp); } } if (port < 0 || port > 0xFFFF) { throw new IllegalStateException("Valid port range is 0 <= port <= 65535. 
The given port was " + port); } uiEventRoutingThread = new Thread(new StatsEventRouterRunnable()); uiEventRoutingThread.setDaemon(true); uiEventRoutingThread.start(); server = vertx.createHttpServer() .requestHandler(r) .listen(port, result -> { if (result.succeeded()) { String address = UIServer.getInstance().getAddress(); log.info("Deeplearning4j UI server started at: {}", address); startCallback.complete(); } else { startCallback.fail(new RuntimeException("Deeplearning4j UI server failed to listen on port " + server.actualPort(), result.cause())); } }); } private List extractArgsFromRoute(String path, RoutingContext rc) { if (!path.contains(":")) { return Collections.emptyList(); } String[] split = path.split("/"); List out = new ArrayList<>(); for (String s : split) { if (s.startsWith(":")) { String s2 = s.substring(1); out.add(rc.request().getParam(s2)); } } return out; } private void modulesViaServiceLoader(List uiModules) { ServiceLoader sl = DL4JClassLoading.loadService(UIModule.class); Iterator iter = sl.iterator(); if (!iter.hasNext()) { return; } while (iter.hasNext()) { UIModule module = iter.next(); Class moduleClass = module.getClass(); boolean foundExisting = false; for (UIModule mExisting : uiModules) { if (mExisting.getClass() == moduleClass) { foundExisting = true; break; } } if (!foundExisting) { log.debug("Loaded UI module via service loader: {}", module.getClass()); uiModules.add(module); } } } @Override public void stop() throws InterruptedException { CountDownLatch l = new CountDownLatch(1); Promise promise = Promise.promise(); promise.future().compose( successEvent -> Future.future(prom -> l.countDown()), failureEvent -> Future.future(prom -> l.countDown()) ); stopAsync(promise); // synchronous function should wait until the server is stopped l.await(); } @Override public void stopAsync(Promise stopCallback) { /** * Stop Vertx instance and release any resources held by it. * Pass promise to {@link #stop(Promise)}. 
*/ vertx.close(ar -> stopCallback.handle(ar)); } @Override public void stop(Promise stopCallback) { shutdown.set(true); stopCallback.complete(); log.info("Deeplearning4j UI server stopped."); } @Override public boolean isStopped() { return shutdown.get(); } @Override public boolean isMultiSession() { return multiSession.get(); } @Override public String getAddress() { return "http://localhost:" + server.actualPort(); } @Override public int getPort() { return server.actualPort(); } @Override public void attach(StatsStorage statsStorage) { if (statsStorage == null) throw new IllegalArgumentException("StatsStorage cannot be null"); if (statsStorageInstances.contains(statsStorage)) return; StatsStorageListener listener = new QueueStatsStorageListener(eventQueue); listeners.add(new Pair<>(statsStorage, listener)); statsStorage.registerStatsStorageListener(listener); statsStorageInstances.add(statsStorage); for (UIModule uiModule : uiModules) { uiModule.onAttach(statsStorage); } log.info("StatsStorage instance attached to UI: {}", statsStorage); } @Override public void detach(StatsStorage statsStorage) { if (statsStorage == null) throw new IllegalArgumentException("StatsStorage cannot be null"); if (!statsStorageInstances.contains(statsStorage)) return; //No op boolean found = false; for (Pair p : listeners) { if (p.getFirst() == statsStorage) { //Same object, not equality statsStorage.deregisterStatsStorageListener(p.getSecond()); listeners.remove(p); found = true; } } statsStorageInstances.remove(statsStorage); for (UIModule uiModule : uiModules) { uiModule.onDetach(statsStorage); } for (String sessionId : statsStorage.listSessionIDs()) { I18NProvider.removeInstance(sessionId); } if (found) { log.info("StatsStorage instance detached from UI: {}", statsStorage); } } @Override public boolean isAttached(StatsStorage statsStorage) { return statsStorageInstances.contains(statsStorage); } @Override public List getStatsStorageInstances() { return new 
ArrayList<>(statsStorageInstances); } @Override public void enableRemoteListener() { if (remoteReceiverModule == null) remoteReceiverModule = new RemoteReceiverModule(); if (remoteReceiverModule.isEnabled()) return; enableRemoteListener(new InMemoryStatsStorage(), true); } @Override public void enableRemoteListener(StatsStorageRouter statsStorage, boolean attach) { remoteReceiverModule.setEnabled(true); remoteReceiverModule.setStatsStorage(statsStorage); if (attach && statsStorage instanceof StatsStorage) { attach((StatsStorage) statsStorage); } } @Override public void disableRemoteListener() { remoteReceiverModule.setEnabled(false); } @Override public boolean isRemoteListenerEnabled() { return remoteReceiverModule.isEnabled(); } private class StatsEventRouterRunnable implements Runnable { @Override public void run() { try { runHelper(); } catch (Exception e) { log.error("Unexpected exception from Event routing runnable", e); } } private void runHelper() throws Exception { log.trace("VertxUIServer.StatsEventRouterRunnable started"); //Idea: collect all event stats, and route them to the appropriate modules while (!shutdown.get()) { List events = new ArrayList<>(); StatsStorageEvent sse = eventQueue.take(); //Blocking operation events.add(sse); eventQueue.drainTo(events); //Non-blocking for (UIModule m : uiModules) { List callbackTypes = m.getCallbackTypeIDs(); List out = new ArrayList<>(); for (StatsStorageEvent e : events) { if (callbackTypes.contains(e.getTypeID()) && statsStorageInstances.contains(e.getStatsStorage())) { out.add(e); } } m.reportStorageEvents(out); } events.clear(); try { Thread.sleep(uiProcessingDelay); } catch (InterruptedException e) { Thread.currentThread().interrupt(); if (!shutdown.get()) { throw new RuntimeException("Unexpected interrupted exception", e); } } } } } //================================================================================================================== // CLI Launcher @Data private static class CLIParams { 
@Parameter(names = {"-r", "--enableRemote"}, description = "Whether to enable remote or not", arity = 1) private boolean cliEnableRemote; @Parameter(names = {"-p", "--uiPort"}, description = "Custom HTTP port for UI", arity = 1) private int cliPort = DEFAULT_UI_PORT; @Parameter(names = {"-f", "--customStatsFile"}, description = "Path to create custom stats file (remote only)", arity = 1) private String cliCustomStatsFile; @Parameter(names = {"-m", "--multiSession"}, description = "Whether to enable multiple separate browser sessions or not", arity = 1) private boolean cliMultiSession; } public void main(String[] args){ CLIParams d = new CLIParams(); new JCommander(d).parse(args); instancePort = d.getCliPort(); UIServer.getInstance(d.isCliMultiSession(), null); if(d.isCliEnableRemote()){ try { File tempStatsFile = ND4JFileUtils.createTempFile("dl4j", "UIstats"); tempStatsFile.delete(); tempStatsFile.deleteOnExit(); enableRemoteListener(new FileStatsStorage(tempStatsFile), true); } catch(Exception e) { log.error("Failed to create temporary file for stats storage",e); System.exit(1); } } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy