// org.deeplearning4j.ui.VertxUIServer Maven / Gradle / Ivy
// The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.ui;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;
import io.vertx.core.http.HttpServer;
import io.vertx.core.http.impl.MimeMapping;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.handler.BodyHandler;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.config.DL4JClassLoading;
import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.common.util.ND4JFileUtils;
import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.core.storage.StatsStorageEvent;
import org.deeplearning4j.core.storage.StatsStorageListener;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.i18n.I18NProvider;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.model.storage.impl.QueueStatsStorageListener;
import org.deeplearning4j.ui.module.SameDiffModule;
import org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule;
import org.deeplearning4j.ui.module.defaultModule.DefaultModule;
import org.deeplearning4j.ui.module.remote.RemoteReceiverModule;
import org.deeplearning4j.ui.module.train.TrainModule;
import org.deeplearning4j.ui.module.tsne.TsneModule;
import org.nd4j.common.function.Function;
import org.nd4j.common.primitives.Pair;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
@Slf4j
public class VertxUIServer extends AbstractVerticle implements UIServer {
/** Port used when neither the caller nor the port system property supplies one. */
public static final int DEFAULT_UI_PORT = 9000;
/** Path prefix prepended to non-webjar asset requests before the file is sent. */
public static final String ASSETS_ROOT_DIRECTORY = "deeplearning4jUiAssets/";
/** Most recently constructed server; assigned by the constructor and cleared by {@code reset()}. */
@Getter
private static VertxUIServer instance;
/** Session mode of the server; written by {@code getInstance(...)} before deployment. */
@Getter
private static final AtomicBoolean multiSession = new AtomicBoolean(false);
/** Optional lookup from session ID to the {@link StatsStorage} holding it; may be {@code null}. */
@Getter
@Setter
private static Function<String, StatsStorage> statsStorageProvider;
/** Port requested through {@code getInstance(...)}; {@code null} selects {@link #DEFAULT_UI_PORT}. */
private static Integer instancePort;
/** JVM shutdown hook registered by the latest deployment. */
@Getter
private static Thread shutdownHook;
/**
 * Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
 *
 * @param port                 TCP socket port for {@link HttpServer} to listen
 * @param multiSession         in multi-session mode, multiple training sessions can be visualized in separate
 *                             browser tabs. URL path will include session ID as a parameter,
 *                             i.e.: /train becomes /train/:sessionId
 * @param statsStorageProvider function that returns a StatsStorage containing the given session ID.
 *                             Use this to auto-attach StatsStorage if an unknown session ID is passed
 *                             as URL path parameter in multi-session mode, or leave it {@code null}.
 * @return UI instance for this JVM
 * @throws DL4JException if UI server failed to start;
 *                       if the instance has already started in a different mode (multi/single-session);
 *                       if interrupted while waiting for completion
 */
public static VertxUIServer getInstance(Integer port, boolean multiSession,
        Function<String, StatsStorage> statsStorageProvider) throws DL4JException {
    // A null start callback selects the blocking deployment path of the four-argument overload.
    return getInstance(port, multiSession, statsStorageProvider, null);
}
/**
 * Get (and, initialize if necessary) the UI server. This function will wait until the server started
 * (synchronous way), or pass the given callback to handle success or failure (asynchronous way).
 *
 * @param port                 TCP socket port for {@link HttpServer} to listen
 * @param multiSession         in multi-session mode, multiple training sessions can be visualized in separate
 *                             browser tabs. URL path will include session ID as a parameter,
 *                             i.e.: /train becomes /train/:sessionId
 * @param statsStorageProvider function that returns a StatsStorage containing the given session ID.
 *                             Use this to auto-attach StatsStorage if an unknown session ID is passed
 *                             as URL path parameter in multi-session mode, or leave it {@code null}.
 * @param startCallback        asynchronous deployment handler callback that will be notified of success or
 *                             failure. If {@code null} given, then this method will wait until deployment is
 *                             complete. If the deployment is successful the result will contain a String
 *                             representing the unique deployment ID of the deployment.
 * @return UI server instance. NOTE(review): with a non-null {@code startCallback} this is whatever the static
 *         {@code instance} field holds when the method returns, which presumably may still be {@code null}
 *         or a stopped server until the deployment finishes - confirm against callers.
 * @throws DL4JException if UI server failed to start;
 *                       if the instance has already started in a different mode (multi/single-session);
 *                       if interrupted while waiting for completion
 */
public static VertxUIServer getInstance(Integer port, boolean multiSession,
        Function<String, StatsStorage> statsStorageProvider, Promise<String> startCallback)
        throws DL4JException {
    if (instance == null || instance.isStopped()) {
        // No usable server: record the requested configuration, then deploy a fresh verticle.
        VertxUIServer.multiSession.set(multiSession);
        VertxUIServer.setStatsStorageProvider(statsStorageProvider);
        instancePort = port;
        if (startCallback == null) {
            //Launch UI server verticle and wait for it to start
            deploy();
        } else {
            //Launch UI server verticle and pass asynchronous callback that will be notified of completion
            deploy(startCallback);
        }
    } else if (multiSession != instance.isMultiSession()) {
        // A running server cannot switch session mode; the caller has to stop it first.
        String requestedMode = multiSession ? "multi-session" : "single-session";
        String runningMode = multiSession ? "single-session" : "multi-session";
        throw new DL4JException("Cannot return " + requestedMode + " instance." +
                " UIServer has already started in " + runningMode + " mode at " + instance.getAddress() +
                " You may stop the UI server instance, and start a new one.");
    }
    return instance;
}
/**
 * Deploy (start) {@link VertxUIServer}, waiting until starting is complete.
 *
 * @throws DL4JException if UI server failed to start;
 *                       if interrupted while waiting for completion (the thread's interrupt flag is
 *                       restored before the exception is thrown)
 */
private static void deploy() throws DL4JException {
    CountDownLatch deployed = new CountDownLatch(1);
    Promise<String> promise = Promise.promise();
    // Release the latch on either outcome; success/failure is inspected after the wait.
    promise.future().onComplete(result -> deployed.countDown());
    deploy(promise);
    // synchronous function
    try {
        deployed.await();
    } catch (InterruptedException e) {
        // Keep the interruption visible to callers further up the stack.
        Thread.currentThread().interrupt();
        throw new DL4JException(e);
    }
    Future<String> future = promise.future();
    if (future.failed()) {
        throw new DL4JException("Deeplearning4j UI server failed to start.", future.cause());
    }
}
/**
 * Deploy (start) {@link VertxUIServer},
 * and pass callback to handle successful or failed completion of deployment.
 * Also registers a JVM shutdown hook that stops the current instance if it is still running;
 * a hook left over from an earlier deployment is removed first so hooks do not accumulate.
 *
 * @param startCallback promise that will handle success or failure of deployment.
 *                      If the deployment is successful the result will contain a String representing the
 *                      unique deployment ID of the deployment. A failure is reported wrapped in a
 *                      {@link RuntimeException}.
 */
private static void deploy(Promise<String> startCallback) {
    log.debug("Deeplearning4j UI server is starting.");
    Promise<String> promise = Promise.promise();
    promise.future().onComplete(result -> {
        if (result.succeeded()) {
            startCallback.complete(result.result());
        } else {
            startCallback.fail(new RuntimeException(result.cause()));
        }
    });
    Vertx vertx = Vertx.vertx();
    vertx.deployVerticle(VertxUIServer.class.getName(), promise);
    Thread previousHook = VertxUIServer.shutdownHook;
    VertxUIServer.shutdownHook = new Thread(() -> {
        if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) {
            log.info("Deeplearning4j UI server is auto-stopping in shutdown hook.");
            try {
                instance.stop();
            } catch (InterruptedException e) {
                log.error("Interrupted stopping of Deeplearning4j UI server in shutdown hook.", e);
                Thread.currentThread().interrupt();
            }
        }
    });
    if (previousHook != null) {
        // Every hook acts on the static instance, so an older one is redundant once replaced.
        // removeShutdownHook simply returns false if the hook was already removed elsewhere.
        Runtime.getRuntime().removeShutdownHook(previousHook);
    }
    Runtime.getRuntime().addShutdownHook(shutdownHook);
}
private List uiModules = new CopyOnWriteArrayList<>();
private RemoteReceiverModule remoteReceiverModule;
/**
* Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
*/
@Getter
private Function statsStorageLoader;
//typeIDModuleMap: Records which modules are registered for which type IDs
private Map> typeIDModuleMap = new ConcurrentHashMap<>();
private HttpServer server;
private AtomicBoolean shutdown = new AtomicBoolean(false);
private long uiProcessingDelay = 500; //500ms. TODO make configurable
private final BlockingQueue eventQueue = new LinkedBlockingQueue<>();
private List> listeners = new CopyOnWriteArrayList<>();
private List statsStorageInstances = new CopyOnWriteArrayList<>();
private Thread uiEventRoutingThread;
/**
 * Creates the verticle and publishes it through the static {@code instance} reference.
 * {@code deploy(...)} only hands the class name to Vert.x, so this assignment is how the
 * static reference gets populated.
 */
public VertxUIServer() {
    VertxUIServer.instance = this;
}
/**
 * Stops the current server, if there is one and it is still running, and then clears the static
 * configuration. Does nothing when no server exists or it has already stopped.
 *
 * @throws Exception if stopping the server fails
 */
public static void stopInstance() throws Exception {
    final boolean running = instance != null && !instance.isStopped();
    if (running) {
        instance.stop();
        VertxUIServer.reset();
    }
}
/**
 * Returns the static configuration to its initial state: single-session mode, no requested port,
 * no stats storage provider and no instance reference.
 */
private static void reset() {
    VertxUIServer.multiSession.set(false);
    VertxUIServer.instancePort = null;
    VertxUIServer.statsStorageProvider = null;
    VertxUIServer.instance = null;
}
/**
 * Auto-attach StatsStorage if an unknown session ID is passed as URL path parameter in multi-session mode.
 * Installs {@code statsStorageLoader}: given a session ID it asks the provider for a storage, attaches it
 * when the storage actually contains that session, and reports whether it did so.
 *
 * @param statsStorageProvider function that returns a StatsStorage containing the given session ID;
 *                             when {@code null} the method does nothing and any loader installed
 *                             earlier stays in place
 */
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
    if (statsStorageProvider == null) {
        return;
    }
    this.statsStorageLoader = sessionId -> {
        log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
        StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
        if (statsStorage == null) {
            log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
                    "StatsStorageProvider returned null.");
            return false;
        }
        if (!statsStorage.sessionExists(sessionId)) {
            log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
                    "Session ID (" + sessionId + ") does not exist in StatsStorage.");
            return false;
        }
        attach(statsStorage);
        return true;
    };
}
/**
 * Verticle start-up: builds the router (asset, language and module routes), indexes the modules by
 * the type IDs they want to be called back for, starts the event routing thread and finally binds
 * the HTTP server.
 *
 * @param startCallback completed once the server is listening, failed if binding the port fails
 * @throws IllegalStateException if the resolved port is outside 0..65535 or a module declares an
 *                               unsupported HTTP method
 * @throws Exception             declared by the overridden verticle method
 */
@Override
public void start(Promise<Void> startCallback) throws Exception {
    //Create REST endpoints
    File uploadDir = new File(System.getProperty("java.io.tmpdir"), "DL4JUI_" + System.currentTimeMillis());
    if (!uploadDir.mkdirs() && !uploadDir.isDirectory()) {
        log.warn("Could not create upload directory {}", uploadDir.getAbsolutePath());
    }
    Router r = Router.router(vertx);
    r.route().handler(BodyHandler.create() //NOTE: Setting this is required to receive request body content at all
            .setUploadsDirectory(uploadDir.getAbsolutePath()));
    registerAssetRoute(r);
    registerLanguageRoutes(r);
    if (VertxUIServer.statsStorageProvider != null) {
        autoAttachStatsStorageBySessionId(VertxUIServer.statsStorageProvider);
    }
    uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
    uiModules.add(new TrainModule());
    uiModules.add(new ConvolutionalListenerModule());
    uiModules.add(new TsneModule());
    uiModules.add(new SameDiffModule());
    remoteReceiverModule = new RemoteReceiverModule();
    uiModules.add(remoteReceiverModule);
    //Check service loader mechanism (Arbiter UI, etc) for modules
    modulesViaServiceLoader(uiModules);
    for (UIModule m : uiModules) {
        registerModuleRoutes(r, m);
        //Determine which type IDs this module wants to receive:
        for (String typeID : m.getCallbackTypeIDs()) {
            typeIDModuleMap
                    .computeIfAbsent(typeID, key -> Collections.synchronizedList(new ArrayList<UIModule>()))
                    .add(m);
        }
    }
    // Resolved before the routing thread starts so that an invalid port aborts without side effects.
    final int port = resolvePort();
    uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
    uiEventRoutingThread.setDaemon(true);
    uiEventRoutingThread.start();
    server = vertx.createHttpServer()
            .requestHandler(r)
            .listen(port, result -> {
                if (result.succeeded()) {
                    String address = UIServer.getInstance().getAddress();
                    log.info("Deeplearning4j UI server started at: {}", address);
                    startCallback.complete();
                } else {
                    // Report the port that was requested; the server never bound one in this branch.
                    startCallback.fail(new RuntimeException("Deeplearning4j UI server failed to listen on port "
                            + port, result.cause()));
                }
            });
}

/** Registers the handler that serves webjar resources and the UI's own static assets under /assets/. */
private void registerAssetRoute(Router router) {
    final String assetsPrefix = "/assets/";
    router.get("/assets/*").handler(rc -> {
        String requestPath = rc.request().path();
        if (requestPath.length() < assetsPrefix.length()) {
            // Too short to carry an asset name; avoids an out-of-range substring below.
            rc.response().setStatusCode(404).end();
            return;
        }
        String path = requestPath.substring(assetsPrefix.length());
        String newPath;
        if (path.contains("webjars")) {
            newPath = "META-INF/resources/" + path.substring(path.indexOf("webjars"));
        } else {
            newPath = ASSETS_ROOT_DIRECTORY + (path.startsWith("/") ? path.substring(1) : path);
        }
        String mime = MimeMapping.getMimeTypeForFilename(FilenameUtils.getName(newPath));
        if (mime != null) {
            // No MIME type is known for this file name otherwise, so no content-type is set.
            rc.response().putHeader("content-type", mime);
        }
        rc.response().sendFile(newPath);
    });
}

/** Registers the /setlang route, per session in multi-session mode and global otherwise. */
private void registerLanguageRoutes(Router router) {
    if (isMultiSession()) {
        router.get("/setlang/:sessionId/:to").handler(rc -> {
            // The parameter name must match the ":sessionId" placeholder declared in the route.
            String sid = rc.request().getParam("sessionId");
            String to = rc.request().getParam("to");
            I18NProvider.getInstance(sid).setDefaultLanguage(to);
            rc.response().end();
        });
    } else {
        router.get("/setlang/:to").handler(rc -> {
            String to = rc.request().getParam("to");
            I18NProvider.getInstance().setDefaultLanguage(to);
            rc.response().end();
        });
    }
}

/** Registers every route of one module on the router, by HTTP method. */
private void registerModuleRoutes(Router router, UIModule module) {
    List<Route> routes = module.getRoutes();
    for (Route route : routes) {
        switch (route.getHttpMethod()) {
            case GET:
                router.get(route.getRoute()).handler(rc -> dispatchToRoute(route, rc));
                break;
            case PUT:
                router.put(route.getRoute()).handler(rc -> dispatchToRoute(route, rc));
                break;
            case POST:
                router.post(route.getRoute()).handler(rc -> dispatchToRoute(route, rc));
                break;
            default:
                throw new IllegalStateException("Unknown or not supported HTTP method: " + route.getHttpMethod());
        }
    }
}

/** Invokes a module route's consumer with the path arguments extracted from the request. */
private void dispatchToRoute(Route route, RoutingContext rc) {
    route.getConsumer().accept(extractArgsFromRoute(route.getRoute(), rc), rc);
}

/** Picks the port: requested port or default, overridden by the port system property when it parses. */
private static int resolvePort() {
    int port = instancePort == null ? DEFAULT_UI_PORT : instancePort;
    String portProp = System.getProperty(DL4JSystemProperties.UI_SERVER_PORT_PROPERTY);
    if (portProp != null && !portProp.isEmpty()) {
        try {
            port = Integer.parseInt(portProp);
        } catch (NumberFormatException e) {
            log.warn("Error parsing port property {}={}", DL4JSystemProperties.UI_SERVER_PORT_PROPERTY, portProp);
        }
    }
    if (port < 0 || port > 0xFFFF) {
        throw new IllegalStateException("Valid port range is 0 <= port <= 65535. The given port was " + port);
    }
    return port;
}
/**
 * Collects the values of a route's path parameters, in the order they appear in the route template.
 *
 * @param path route template, e.g. one containing {@code :name} segments
 * @param rc   routing context of the current request, used to look up each parameter by name
 * @return one entry per {@code :name} segment (an entry is whatever the request returns for that
 *         name); an empty list when the template contains no {@code ':'}
 */
private List<String> extractArgsFromRoute(String path, RoutingContext rc) {
    if (!path.contains(":")) {
        return Collections.emptyList();
    }
    List<String> args = new ArrayList<>();
    for (String segment : path.split("/")) {
        if (segment.startsWith(":")) {
            // Drop the leading ':' to get the parameter name.
            args.add(rc.request().getParam(segment.substring(1)));
        }
    }
    return args;
}
/**
 * Appends modules discovered through the service loader to the given list, skipping any module
 * whose exact class is already represented in it.
 *
 * @param uiModules list to extend in place
 */
private void modulesViaServiceLoader(List<UIModule> uiModules) {
    ServiceLoader<UIModule> loader = DL4JClassLoading.loadService(UIModule.class);
    for (UIModule module : loader) {
        Class<?> moduleClass = module.getClass();
        boolean alreadyRegistered = false;
        for (UIModule existing : uiModules) {
            // Exact class identity: a subclass of a registered module still counts as new.
            if (existing.getClass() == moduleClass) {
                alreadyRegistered = true;
                break;
            }
        }
        if (!alreadyRegistered) {
            log.debug("Loaded UI module via service loader: {}", module.getClass());
            uiModules.add(module);
        }
    }
}
/**
 * Stops the server and blocks the calling thread until the asynchronous stop has finished,
 * whether it succeeded or failed.
 *
 * @throws InterruptedException if interrupted while waiting
 */
@Override
public void stop() throws InterruptedException {
    final CountDownLatch stopped = new CountDownLatch(1);
    final Promise<Void> stopPromise = Promise.promise();
    // Either outcome releases the waiting thread.
    stopPromise.future().onComplete(outcome -> stopped.countDown());
    stopAsync(stopPromise);
    // synchronous function should wait until the server is stopped
    stopped.await();
}
/**
 * Stop Vertx instance and release any resources held by it.
 * Pass promise to {@link #stop(Promise)}.
 *
 * @param stopCallback handler that receives the result of closing the Vert.x instance
 */
@Override
public void stopAsync(Promise<Void> stopCallback) {
    vertx.close(ar -> stopCallback.handle(ar));
}
/**
 * Marks the server as stopped and completes the callback straight away.
 * From this point {@link #isStopped()} returns {@code true}.
 *
 * @param stopCallback promise completed once the shutdown flag has been set
 */
@Override
public void stop(Promise<Void> stopCallback) {
    shutdown.set(true);
    stopCallback.complete();
    log.info("Deeplearning4j UI server stopped.");
}
/**
 * @return {@code true} once the shutdown flag has been set by {@code stop(Promise)}
 */
@Override
public boolean isStopped() {
    final boolean stopped = shutdown.get();
    return stopped;
}
/**
 * @return current value of the static multi-session flag
 */
@Override
public boolean isMultiSession() {
    return VertxUIServer.multiSession.get();
}
/**
 * @return localhost URL built from the port reported by the HTTP server
 */
@Override
public String getAddress() {
    final int actualPort = server.actualPort();
    return "http://localhost:" + actualPort;
}
/**
 * @return port reported by the HTTP server
 */
@Override
public int getPort() {
    final int actualPort = server.actualPort();
    return actualPort;
}
/**
 * Attaches a stats storage to the UI: registers a queue-backed listener on it, records it as
 * attached and notifies every module. Attaching a storage that is already attached is a no-op.
 *
 * @param statsStorage storage to attach
 * @throws IllegalArgumentException if {@code statsStorage} is {@code null}
 */
@Override
public void attach(StatsStorage statsStorage) {
    if (statsStorage == null) {
        throw new IllegalArgumentException("StatsStorage cannot be null");
    }
    final boolean alreadyAttached = statsStorageInstances.contains(statsStorage);
    if (!alreadyAttached) {
        // The listener forwards the storage's events into the shared event queue.
        StatsStorageListener queueListener = new QueueStatsStorageListener(eventQueue);
        listeners.add(new Pair<>(statsStorage, queueListener));
        statsStorage.registerStatsStorageListener(queueListener);
        statsStorageInstances.add(statsStorage);
        for (UIModule module : uiModules) {
            module.onAttach(statsStorage);
        }
        log.info("StatsStorage instance attached to UI: {}", statsStorage);
    }
}
/**
 * Detaches a stats storage from the UI: deregisters the listener(s) registered for it, forgets the
 * storage, notifies every module and removes the I18N instances of the storage's sessions.
 * Detaching a storage that is not attached is a no-op.
 *
 * @param statsStorage storage to detach
 * @throws IllegalArgumentException if {@code statsStorage} is {@code null}
 */
@Override
public void detach(StatsStorage statsStorage) {
    if (statsStorage == null) {
        throw new IllegalArgumentException("StatsStorage cannot be null");
    }
    if (!statsStorageInstances.contains(statsStorage)) {
        return; //No op
    }
    boolean found = false;
    // Removing while iterating relies on 'listeners' being a CopyOnWriteArrayList (snapshot iteration).
    for (Pair<StatsStorage, StatsStorageListener> p : listeners) {
        if (p.getFirst() == statsStorage) { //Same object, not equality
            statsStorage.deregisterStatsStorageListener(p.getSecond());
            listeners.remove(p);
            found = true;
        }
    }
    statsStorageInstances.remove(statsStorage);
    for (UIModule uiModule : uiModules) {
        uiModule.onDetach(statsStorage);
    }
    for (String sessionId : statsStorage.listSessionIDs()) {
        I18NProvider.removeInstance(sessionId);
    }
    if (found) {
        log.info("StatsStorage instance detached from UI: {}", statsStorage);
    }
}
/**
 * @param statsStorage storage to look for
 * @return {@code true} if the storage is in the list of attached storages
 */
@Override
public boolean isAttached(StatsStorage statsStorage) {
    final boolean attached = statsStorageInstances.contains(statsStorage);
    return attached;
}
/**
 * @return a new list holding the currently attached storages; changing it does not affect the server
 */
@Override
public List<StatsStorage> getStatsStorageInstances() {
    return new ArrayList<>(statsStorageInstances);
}
/**
 * Enables the remote listener backed by a fresh in-memory storage that is also attached to the UI.
 * Creates the receiver module first if it does not exist yet; does nothing if the module is
 * already enabled.
 */
@Override
public void enableRemoteListener() {
    if (remoteReceiverModule == null) {
        remoteReceiverModule = new RemoteReceiverModule();
    }
    final boolean alreadyEnabled = remoteReceiverModule.isEnabled();
    if (!alreadyEnabled) {
        enableRemoteListener(new InMemoryStatsStorage(), true);
    }
}
/**
 * Enables the remote listener and routes received stats to the given router.
 * Creates the receiver module first if it does not exist yet, as the no-argument overload does.
 *
 * @param statsStorage destination for remotely received stats
 * @param attach       if {@code true} and {@code statsStorage} is a {@link StatsStorage},
 *                     it is also attached to the UI
 */
@Override
public void enableRemoteListener(StatsStorageRouter statsStorage, boolean attach) {
    if (remoteReceiverModule == null) {
        remoteReceiverModule = new RemoteReceiverModule();
    }
    remoteReceiverModule.setEnabled(true);
    remoteReceiverModule.setStatsStorage(statsStorage);
    if (attach && statsStorage instanceof StatsStorage) {
        attach((StatsStorage) statsStorage);
    }
}
/**
 * Disables the remote listener. Does nothing when the receiver module has not been created yet,
 * since there is nothing to disable in that case.
 */
@Override
public void disableRemoteListener() {
    if (remoteReceiverModule != null) {
        remoteReceiverModule.setEnabled(false);
    }
}
/**
 * @return {@code true} if the receiver module exists and reports itself enabled;
 *         {@code false} when it has not been created yet
 */
@Override
public boolean isRemoteListenerEnabled() {
    return remoteReceiverModule != null && remoteReceiverModule.isEnabled();
}
/**
 * Body of the event routing thread: takes batches of events from {@code eventQueue} and hands each
 * module the events whose type ID it asked for and whose storage is currently attached.
 * The queue is polled with a timeout rather than blocked on indefinitely, so the loop notices the
 * shutdown flag even when no further events arrive.
 */
private class StatsEventRouterRunnable implements Runnable {
    @Override
    public void run() {
        try {
            runHelper();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            if (!shutdown.get()) {
                log.error("Unexpected exception from Event routing runnable", e);
            }
        } catch (Exception e) {
            log.error("Unexpected exception from Event routing runnable", e);
        }
    }

    private void runHelper() throws Exception {
        log.trace("VertxUIServer.StatsEventRouterRunnable started");
        //Idea: collect all event stats, and route them to the appropriate modules
        while (!shutdown.get()) {
            // Bounded wait: returns null on timeout so the shutdown flag is re-checked.
            StatsStorageEvent first =
                    eventQueue.poll(uiProcessingDelay, java.util.concurrent.TimeUnit.MILLISECONDS);
            if (first == null) {
                continue;
            }
            List<StatsStorageEvent> events = new ArrayList<>();
            events.add(first);
            eventQueue.drainTo(events); //Non-blocking
            for (UIModule m : uiModules) {
                List<String> callbackTypes = m.getCallbackTypeIDs();
                List<StatsStorageEvent> out = new ArrayList<>();
                for (StatsStorageEvent e : events) {
                    if (callbackTypes.contains(e.getTypeID())
                            && statsStorageInstances.contains(e.getStatsStorage())) {
                        out.add(e);
                    }
                }
                // Called even when 'out' is empty, as before.
                m.reportStorageEvents(out);
            }
            try {
                // Pause between batches so events accumulate and are delivered together.
                Thread.sleep(uiProcessingDelay);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                if (!shutdown.get()) {
                    throw new RuntimeException("Unexpected interrupted exception", e);
                }
            }
        }
    }
}
//==================================================================================================================
// CLI Launcher
/**
 * Command-line options parsed by JCommander in {@code main}.
 * Every option is declared with {@code arity = 1}, i.e. each flag takes one value.
 */
@Data
private static class CLIParams {
// -r / --enableRemote: when true, main enables the remote listener with a file-backed stats storage
@Parameter(names = {"-r", "--enableRemote"}, description = "Whether to enable remote or not", arity = 1)
private boolean cliEnableRemote;
// -p / --uiPort: HTTP port for the UI; defaults to DEFAULT_UI_PORT
@Parameter(names = {"-p", "--uiPort"}, description = "Custom HTTP port for UI", arity = 1)
private int cliPort = DEFAULT_UI_PORT;
// -f / --customStatsFile: NOTE(review): parsed but not read by main in this file - confirm whether it is still used
@Parameter(names = {"-f", "--customStatsFile"}, description = "Path to create custom stats file (remote only)", arity = 1)
private String cliCustomStatsFile;
// -m / --multiSession: passed to UIServer.getInstance(...) as the multi-session flag
@Parameter(names = {"-m", "--multiSession"}, description = "Whether to enable multiple separate browser sessions or not", arity = 1)
private boolean cliMultiSession;
}
/**
 * CLI entry point: parses the options, starts the UI server and, if requested, enables the remote
 * listener with a file-backed stats storage in a temporary file.
 * Declared {@code static} so the Java launcher can invoke it without constructing a server first;
 * constructing one would publish an unstarted instance through the static {@code instance} field.
 *
 * @param args command-line arguments, see {@code CLIParams}
 */
public static void main(String[] args) {
    CLIParams params = new CLIParams();
    new JCommander(params).parse(args);
    instancePort = params.getCliPort();
    UIServer.getInstance(params.isCliMultiSession(), null);
    if (params.isCliEnableRemote()) {
        try {
            File tempStatsFile = ND4JFileUtils.createTempFile("dl4j", "UIstats");
            // Only the unique path is kept; presumably FileStatsStorage creates the file itself - TODO confirm.
            tempStatsFile.delete();
            tempStatsFile.deleteOnExit();
            // Target the server published by the deployment.
            // NOTE(review): assumes UIServer.getInstance(...) above has populated 'instance' - confirm.
            instance.enableRemoteListener(new FileStatsStorage(tempStatsFile), true);
        } catch (Exception e) {
            log.error("Failed to create temporary file for stats storage", e);
            System.exit(1);
        }
    }
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy