// Source: org.deeplearning4j.ui.module.train.TrainModule (artifact-browser page header removed)
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.ui.module.train;
import freemarker.template.Configuration;
import freemarker.template.Template;
import freemarker.template.TemplateExceptionHandler;
import freemarker.template.Version;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.ext.web.RoutingContext;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.core.storage.Persistable;
import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.core.storage.StatsStorageEvent;
import org.deeplearning4j.core.storage.StatsStorageListener;
import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.ui.VertxUIServer;
import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.I18N;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.i18n.DefaultI18N;
import org.deeplearning4j.ui.i18n.I18NProvider;
import org.deeplearning4j.ui.i18n.I18NResource;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.stats.api.Histogram;
import org.deeplearning4j.ui.model.stats.api.StatsInitializationReport;
import org.deeplearning4j.ui.model.stats.api.StatsReport;
import org.deeplearning4j.ui.model.stats.api.StatsType;
import org.nd4j.common.function.Function;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import org.nd4j.common.resources.Resources;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import java.io.File;
import java.io.StringReader;
import java.io.StringWriter;
import java.nio.charset.StandardCharsets;
import java.text.DateFormat;
import java.text.DecimalFormat;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
@Slf4j
public class TrainModule implements UIModule {
public static final double NAN_REPLACEMENT_VALUE = 0.0; //UI front-end chokes on NaN in JSON
public static final int DEFAULT_MAX_CHART_POINTS = 512;
private static final DecimalFormat df2 = new DecimalFormat("#.00");
private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
private enum ModelType {
MLN, CG, Layer
}
private final int maxChartPoints; //Technically, the way it's set up: won't exceed 2*maxChartPoints
private Map knownSessionIDs = Collections.synchronizedMap(new HashMap<>());
private String currentSessionID;
private int currentWorkerIdx;
private Map workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID
private Map> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID
private Map lastUpdateForSession = new ConcurrentHashMap<>();
private final Configuration configuration;
/**
 * Creates the module: resolves the chart subsampling limit from the
 * {@code CHART_MAX_POINTS} system property and sets up the Freemarker configuration
 * used to render the .ftl templates.
 */
public TrainModule() {
    // Resolve max chart points: use the system property when it parses and is sane (>= 10),
    // otherwise fall back to the default.
    int points = DEFAULT_MAX_CHART_POINTS;
    String prop = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY);
    if (prop != null) {
        try {
            points = Integer.parseInt(prop);
        } catch (NumberFormatException e) {
            log.warn("Invalid system property: {} = {}", DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY, prop);
        }
    }
    this.maxChartPoints = (points >= 10) ? points : DEFAULT_MAX_CHART_POINTS;

    // Freemarker setup: UTF-8, US locale, rethrow template errors, load templates from the
    // directory that contains the overview template.
    configuration = new Configuration(new Version(2, 3, 23));
    configuration.setDefaultEncoding("UTF-8");
    configuration.setLocale(Locale.US);
    configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
    configuration.setClassForTemplateLoading(TrainModule.class, "");
    try {
        File dir = Resources.asFile("templates/TrainingOverview.html.ftl").getParentFile();
        configuration.setDirectoryForTemplateLoading(dir);
    } catch (Throwable t) {
        throw new RuntimeException(t);
    }
}
/**
 * @return the single stats-storage type ID this module consumes ({@link StatsListener#TYPE_ID});
 *         generic type restored (stripped in this copy of the source)
 */
@Override
public List<String> getCallbackTypeIDs() {
    return Collections.singletonList(StatsListener.TYPE_ID);
}
/**
 * Register the module's HTTP routes. In multi-session mode pages are addressed per session
 * ID and unknown sessions fall through to {@code sessionNotFound(...)}; in single-session
 * mode the classic {@code /train/*} routes are used. Generic types restored (stripped in
 * this copy of the source).
 *
 * @return routes for the overview/model/system pages, their JSON data endpoints, and the
 *         session/worker selection endpoints
 */
@Override
public List<Route> getRoutes() {
    List<Route> r = new ArrayList<>();
    r.add(new Route("/train/multisession", HttpMethod.GET,
            (path, rc) -> rc.response().end(VertxUIServer.getInstance().isMultiSession() ? "true" : "false")));
    if (VertxUIServer.getInstance().isMultiSession()) {
        r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
        // Bare session URL redirects (302) to that session's overview page
        r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> {
            rc.response()
                    .putHeader("location", path.get(0) + "/overview")
                    .setStatusCode(HttpResponseStatus.FOUND.code())
                    .end();
        }));
        r.add(new Route("/train/:sessionId/overview", HttpMethod.GET, (path, rc) -> {
            if (knownSessionIDs.containsKey(path.get(0))) {
                renderFtl("TrainingOverview.html.ftl", rc);
            } else {
                sessionNotFound(path.get(0), rc.request().path(), rc);
            }
        }));
        r.add(new Route("/train/:sessionId/overview/data", HttpMethod.GET, (path, rc) -> {
            if (knownSessionIDs.containsKey(path.get(0))) {
                getOverviewDataForSession(path.get(0), rc);
            } else {
                sessionNotFound(path.get(0), rc.request().path(), rc);
            }
        }));
        r.add(new Route("/train/:sessionId/model", HttpMethod.GET, (path, rc) -> {
            if (knownSessionIDs.containsKey(path.get(0))) {
                renderFtl("TrainingModel.html.ftl", rc);
            } else {
                sessionNotFound(path.get(0), rc.request().path(), rc);
            }
        }));
        r.add(new Route("/train/:sessionId/model/graph", HttpMethod.GET,
                (path, rc) -> this.getModelGraphForSession(path.get(0), rc)));
        r.add(new Route("/train/:sessionId/model/data/:layerId", HttpMethod.GET,
                (path, rc) -> this.getModelDataForSession(path.get(0), path.get(1), rc)));
        r.add(new Route("/train/:sessionId/system", HttpMethod.GET, (path, rc) -> {
            if (knownSessionIDs.containsKey(path.get(0))) {
                this.renderFtl("TrainingSystem.html.ftl", rc);
            } else {
                sessionNotFound(path.get(0), rc.request().path(), rc);
            }
        }));
        r.add(new Route("/train/:sessionId/info", HttpMethod.GET,
                (path, rc) -> this.sessionInfoForSession(path.get(0), rc)));
        r.add(new Route("/train/:sessionId/system/data", HttpMethod.GET,
                (path, rc) -> this.getSystemDataForSession(path.get(0), rc)));
    } else {
        r.add(new Route("/train", HttpMethod.GET, (path, rc) -> rc.reroute("/train/overview")));
        r.add(new Route("/train/sessions/current", HttpMethod.GET,
                (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID)));
        r.add(new Route("/train/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc)));
        r.add(new Route("/train/overview", HttpMethod.GET, (path, rc) -> this.renderFtl("TrainingOverview.html.ftl", rc)));
        r.add(new Route("/train/overview/data", HttpMethod.GET, (path, rc) -> this.getOverviewData(rc)));
        r.add(new Route("/train/model", HttpMethod.GET, (path, rc) -> this.renderFtl("TrainingModel.html.ftl", rc)));
        r.add(new Route("/train/model/graph", HttpMethod.GET, (path, rc) -> this.getModelGraph(rc)));
        r.add(new Route("/train/model/data/:layerId", HttpMethod.GET, (path, rc) -> this.getModelData(path.get(0), rc)));
        r.add(new Route("/train/system", HttpMethod.GET, (path, rc) -> this.renderFtl("TrainingSystem.html.ftl", rc)));
        r.add(new Route("/train/sessions/info", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc)));
        r.add(new Route("/train/system/data", HttpMethod.GET, (path, rc) -> this.getSystemData(rc)));
    }
    // common for single- and multi-session mode
    r.add(new Route("/train/sessions/lastUpdate/:sessionId", HttpMethod.GET,
            (path, rc) -> this.getLastUpdateForSession(path.get(0), rc)));
    r.add(new Route("/train/workers/setByIdx/:to", HttpMethod.GET,
            (path, rc) -> this.setWorkerByIdx(path.get(0), rc)));
    return r;
}
/**
 * Render a single Freemarker .ftl file from the /templates/ directory and end the HTTP
 * response with the resulting HTML. The template model is the i18n message map for the
 * session's default language.
 * @param file File to render
 * @param rc Routing context
 */
private void renderFtl(String file, RoutingContext rc) {
// NOTE(review): query parameter is "sessionID" while the routes declare ":sessionId" —
// confirm which name the front-end actually sends.
String sessionId = rc.request().getParam("sessionID");
String langCode = DefaultI18N.getInstance(sessionId).getDefaultLanguage();
// NOTE(review): messages come from the global instance, not the per-session one used for
// the language lookup above — presumably intentional, but verify.
Map input = DefaultI18N.getInstance().getMessages(langCode);
String html;
try {
// Read the raw template text and process it with the shared Freemarker configuration
String content = FileUtils.readFileToString(Resources.asFile("templates/" + file), StandardCharsets.UTF_8);
Template template = new Template(FilenameUtils.getName(file), new StringReader(content), configuration);
StringWriter stringWriter = new StringWriter();
template.process(input, stringWriter);
html = stringWriter.toString();
} catch (Throwable t) {
// Template failures are fatal for this request: log and surface as a runtime error
log.error("", t);
throw new RuntimeException(t);
}
rc.response().end(html);
}
/**
* List training sessions. Returns a HTML list of training sessions
*/
private synchronized void listSessions(RoutingContext rc) {
StringBuilder sb = new StringBuilder("\n" +
"\n" +
"\n" +
" \n" +
" Training sessions - DL4J Training UI \n" +
" \n" +
"\n" +
" \n" +
" DL4J Training UI
\n" +
" UI server is in multi-session mode." +
" To visualize a training session, please select one from the following list.
\n" +
" List of attached training sessions
\n");
if (!knownSessionIDs.isEmpty()) {
sb.append(" ");
for (String sessionId : knownSessionIDs.keySet()) {
sb.append(" - ")
.append(sessionId).append("
\n");
}
sb.append("
");
} else {
sb.append("No training session attached.");
}
sb.append(" \n" +
"\n");
rc.response()
.putHeader("content-type", "text/html; charset=utf-8")
.end(sb.toString());
}
/**
 * Handle an unknown session: try to load its StatsStorage via the configured loader, and
 * on success reroute to the originally requested page; otherwise respond 404.
 *
 * @param sessionId  session ID to look up with provider
 * @param targetPath one of overview / model / system, or null
 * @param rc         routing context
 */
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
    // Generic type restored (stripped in this copy); the loader returns whether the
    // session could be attached.
    Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader();
    if (loader != null && loader.apply(sessionId)) {
        if (targetPath != null) {
            rc.reroute(targetPath);
        } else {
            rc.response().end();
        }
    } else {
        rc.response().setStatusCode(HttpResponseStatus.NOT_FOUND.code())
                .end("Unknown session ID: " + sessionId);
    }
}
/**
 * Process incoming storage events: register newly-seen sessions on PostStaticInfo events
 * and track the most recent update timestamp per session. Picks a default session if none
 * is currently selected. (Generic type restored; the original also checked TYPE_ID twice —
 * the redundant inner check has been removed.)
 */
@Override
public synchronized void reportStorageEvents(Collection<StatsStorageEvent> events) {
    for (StatsStorageEvent sse : events) {
        if (!StatsListener.TYPE_ID.equals(sse.getTypeID()))
            continue;
        if (sse.getEventType() == StatsStorageListener.EventType.PostStaticInfo
                && !knownSessionIDs.containsKey(sse.getSessionID())) {
            knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
            if (VertxUIServer.getInstance().isMultiSession()) {
                log.info("Adding training session {}/train/{} of StatsStorage instance {}",
                        VertxUIServer.getInstance().getAddress(), sse.getSessionID(), sse.getStatsStorage());
            }
        }
        // Record the newest update time for the session (read-only elsewhere -> thread safe)
        Long lastUpdate = lastUpdateForSession.get(sse.getSessionID());
        if (lastUpdate == null || sse.getTimestamp() > lastUpdate) {
            lastUpdateForSession.put(sse.getSessionID(), sse.getTimestamp());
        }
    }
    if (currentSessionID == null)
        getDefaultSession();
}
/**
 * Register all stats sessions found in a newly attached {@link StatsStorage}, recording
 * each session's most recent update time. Picks a default session if none is selected.
 */
@Override
public synchronized void onAttach(StatsStorage statsStorage) {
    for (String sessionID : statsStorage.listSessionIDs()) {
        for (String typeID : statsStorage.listTypeIDsForSession(sessionID)) {
            if (!StatsListener.TYPE_ID.equals(typeID))
                continue;
            knownSessionIDs.put(sessionID, statsStorage);
            if (VertxUIServer.getInstance().isMultiSession()) {
                log.info("Adding training session {}/train/{} of StatsStorage instance {}",
                        VertxUIServer.getInstance().getAddress(), sessionID, statsStorage);
            }
            List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID);
            for (Persistable update : latestUpdates) {
                long updateTime = update.getTimeStamp();
                // Fix: also record the first update time for a previously-unseen session.
                // The original only updated existing entries (containsKey && ...), so a
                // freshly attached session was never recorded here — inconsistent with
                // reportStorageEvents, which handles the absent-key case.
                Long prev = lastUpdateForSession.get(sessionID);
                if (prev == null || prev < updateTime) {
                    lastUpdateForSession.put(sessionID, updateTime);
                }
            }
        }
    }
    if (currentSessionID == null)
        getDefaultSession();
}
/**
 * Remove all sessions backed by a detached {@link StatsStorage} instance, clearing their
 * worker bookkeeping, then re-select a default session.
 */
@Override
public synchronized void onDetach(StatsStorage statsStorage) {
    Set<String> toRemove = new HashSet<>();
    for (String s : knownSessionIDs.keySet()) {
        // Identity comparison: we want sessions backed by exactly this storage instance
        if (knownSessionIDs.get(s) == statsStorage) {
            toRemove.add(s);
            workerIdxCount.remove(s);
            workerIdxToName.remove(s);
            currentSessionID = null; // force re-selection of a default session below
        }
    }
    for (String s : toRemove) {
        knownSessionIDs.remove(s);
        if (VertxUIServer.getInstance().isMultiSession()) {
            log.info("Removing training session {}/train/{} of StatsStorage instance {}.",
                    VertxUIServer.getInstance().getAddress(), s, statsStorage);
        }
        lastUpdateForSession.remove(s);
    }
    getDefaultSession();
}
/**
 * Select as current the session whose first static-info record has the most recent
 * timestamp; no-op when a session is already selected or no session has static info.
 * (Generic types restored — stripped in this copy of the source.)
 */
private synchronized void getDefaultSession() {
    if (currentSessionID != null)
        return;
    long mostRecentTime = Long.MIN_VALUE;
    String sessionID = null;
    for (Map.Entry<String, StatsStorage> entry : knownSessionIDs.entrySet()) {
        List<Persistable> staticInfos = entry.getValue().getAllStaticInfos(entry.getKey(), StatsListener.TYPE_ID);
        if (staticInfos == null || staticInfos.isEmpty())
            continue;
        Persistable p = staticInfos.get(0);
        long thisTime = p.getTimeStamp();
        if (thisTime > mostRecentTime) {
            mostRecentTime = thisTime;
            sessionID = entry.getKey();
        }
    }
    if (sessionID != null) {
        currentSessionID = sessionID;
    }
}
/**
 * Map a worker index (as used by the UI) to the worker ID string for a session, assigning
 * indexes to newly-seen workers in sorted order for determinism.
 *
 * @param sessionId session ID, may be null
 * @param workerIdx worker index requested by the UI
 * @return the worker ID, or null when the session/storage is unknown or the index is out
 *         of range
 */
private synchronized String getWorkerIdForIndex(String sessionId, int workerIdx) {
    if (sessionId == null)
        return null;
    Map<Integer, String> idxToId = workerIdxToName.computeIfAbsent(sessionId,
            k -> Collections.synchronizedMap(new HashMap<>()));
    String known = idxToId.get(workerIdx);
    if (known != null)
        return known;
    //Need to record new worker(s)... get (or create) the per-session index counter
    AtomicInteger counter = workerIdxCount.computeIfAbsent(sessionId, k -> new AtomicInteger(0));
    StatsStorage ss = knownSessionIDs.get(sessionId);
    if (ss == null) {
        return null;
    }
    List<String> allWorkerIds = new ArrayList<>(ss.listWorkerIDsForSessionAndType(sessionId, StatsListener.TYPE_ID));
    Collections.sort(allWorkerIds);
    //Ensure all workers have been assigned an index
    for (String s : allWorkerIds) {
        if (!idxToId.containsValue(s)) {
            idxToId.put(counter.getAndIncrement(), s);
        }
    }
    //May still return null if index is wrong/too high...
    return idxToId.get(workerIdx);
}
/**
 * Display, for each session: session ID, start time, number of workers, last update.
 * Returns info for each session as JSON. (Generic types restored — stripped in this copy.)
 */
private synchronized void sessionInfo(RoutingContext rc) {
    Map<String, Object> dataEachSession = new HashMap<>();
    for (Map.Entry<String, StatsStorage> entry : knownSessionIDs.entrySet()) {
        dataEachSession.put(entry.getKey(), sessionData(entry.getKey(), entry.getValue()));
    }
    rc.response()
            .putHeader("content-type", "application/json")
            .end(asJson(dataEachSession));
}
/**
 * Extract session data from {@link StatsStorage}: worker count, init/last-update
 * timestamps, and basic model info (type, layer count, parameter count).
 * (Generic types restored — stripped in this copy of the source.)
 *
 * @param sid session ID
 * @param ss  {@code StatsStorage} instance
 * @return session data map
 */
private static Map<String, Object> sessionData(String sid, StatsStorage ss) {
    Map<String, Object> dataThisSession = new HashMap<>();
    List<String> workerIDs = ss.listWorkerIDsForSessionAndType(sid, StatsListener.TYPE_ID);
    int workerCount = (workerIDs == null ? 0 : workerIDs.size());
    List<Persistable> staticInfo = ss.getAllStaticInfos(sid, StatsListener.TYPE_ID);
    long initTime = Long.MAX_VALUE;
    if (staticInfo != null) {
        for (Persistable p : staticInfo) {
            initTime = Math.min(p.getTimeStamp(), initTime);
        }
    }
    long lastUpdateTime = Long.MIN_VALUE;
    List<Persistable> lastUpdatesAllWorkers = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
    for (Persistable p : lastUpdatesAllWorkers) {
        lastUpdateTime = Math.max(lastUpdateTime, p.getTimeStamp());
    }
    dataThisSession.put("numWorkers", workerCount);
    // Empty string (not the sentinel long) when no records were seen
    dataThisSession.put("initTime", initTime == Long.MAX_VALUE ? "" : initTime);
    dataThisSession.put("lastUpdate", lastUpdateTime == Long.MIN_VALUE ? "" : lastUpdateTime);
    // add list of workers
    if (workerCount > 0) {
        dataThisSession.put("workers", workerIDs);
    }
    //Model info: type, # layers, # params...
    if (staticInfo != null && !staticInfo.isEmpty()) {
        StatsInitializationReport sr = (StatsInitializationReport) staticInfo.get(0);
        String modelClassName = sr.getModelClassName();
        // Shorten well-known model class names for display
        if (modelClassName.endsWith("MultiLayerNetwork")) {
            modelClassName = "MultiLayerNetwork";
        } else if (modelClassName.endsWith("ComputationGraph")) {
            modelClassName = "ComputationGraph";
        }
        dataThisSession.put("modelType", modelClassName);
        dataThisSession.put("numLayers", sr.getModelNumLayers());
        dataThisSession.put("numParams", sr.getModelNumParams());
    } else {
        dataThisSession.put("modelType", "");
        dataThisSession.put("numLayers", "");
        dataThisSession.put("numParams", "");
    }
    return dataThisSession;
}
/**
 * Display, for given session: session ID, start time, number of workers, last update.
 * Returns info for session as JSON; an empty JSON object when the session is unknown.
 *
 * @param sessionId session ID
 */
private synchronized void sessionInfoForSession(String sessionId, RoutingContext rc) {
    Map<String, Object> dataEachSession = new HashMap<>();
    StatsStorage ss = knownSessionIDs.get(sessionId);
    if (ss != null) {
        dataEachSession.put(sessionId, sessionData(sessionId, ss));
    }
    rc.response()
            .putHeader("content-type", "application/json")
            .end(asJson(dataEachSession));
}
/**
 * Switch the current session to {@code newSessionID}, resetting the worker index to 0;
 * responds 400 when the session is unknown.
 */
private synchronized void setSession(String newSessionID, RoutingContext rc) {
    if (!knownSessionIDs.containsKey(newSessionID)) {
        rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()).end();
        return;
    }
    currentSessionID = newSessionID;
    currentWorkerIdx = 0;
    rc.response().end();
}
/** Respond with the last update timestamp for the given session, or "-1" when unknown. */
private void getLastUpdateForSession(String sessionID, RoutingContext rc) {
    Long lastUpdate = lastUpdateForSession.get(sessionID);
    rc.response().end(lastUpdate == null ? "-1" : String.valueOf(lastUpdate));
}
/** Set the current worker index from the request path; unparseable values are logged and ignored. */
private void setWorkerByIdx(String newWorkerIdx, RoutingContext rc) {
    try {
        int idx = Integer.parseInt(newWorkerIdx);
        currentWorkerIdx = idx;
    } catch (NumberFormatException e) {
        // Best-effort: keep the previous worker index on bad input
        log.debug("Invalid call to setWorkerByIdx", e);
    }
    rc.response().end();
}
/** Replace NaN/infinite values with {@code NAN_REPLACEMENT_VALUE} so the JSON front-end can parse them. */
private static double fixNaN(double d) {
    if (Double.isFinite(d)) {
        return d;
    }
    return NAN_REPLACEMENT_VALUE;
}
/**
 * Normalize "legacy" iteration counts in place. Older Spark training could reset or repeat
 * iteration counts (e.g. 0,1,2,0,1,2 or 5,5,5,...); rewrite them as the monotonically
 * increasing sequence first, first+step, first+2*step, ... where step is the largest
 * observed backward jump between consecutive values (or 1 when all values are equal).
 * (Generic type restored — stripped in this copy; package-private, was private, to allow
 * unit testing.)
 *
 * @param iterationCounts iteration counts to clean; modified in place (no-op when empty)
 */
static void cleanLegacyIterationCounts(List<Integer> iterationCounts) {
    if (iterationCounts.isEmpty()) {
        return;
    }
    boolean allEqual = true;
    int maxStepSize = 1;
    int first = iterationCounts.get(0);
    int length = iterationCounts.size();
    int prevIterCount = first;
    for (int i = 1; i < length; i++) {
        int currIterCount = iterationCounts.get(i);
        if (allEqual && first != currIterCount) {
            allEqual = false;
        }
        // A positive (prev - curr) difference marks a reset point; track the largest jump
        maxStepSize = Math.max(maxStepSize, prevIterCount - currIterCount);
        prevIterCount = currIterCount;
    }
    if (allEqual) {
        maxStepSize = 1;
    }
    for (int i = 0; i < length; i++) {
        iterationCounts.set(i, first + i * maxStepSize);
    }
}
/**
 * Get last update time for given session ID, checking for null values
 *
 * @param sessionId session ID
 * @return last update time for session if found, or {@code -1L} when the session (or the
 *         tracking map) is unknown — note: never {@code null}, contrary to what an older
 *         version of this doc stated
 */
private Long getLastUpdateTime(String sessionId) {
if (lastUpdateForSession != null && sessionId != null && lastUpdateForSession.containsKey(sessionId)) {
return lastUpdateForSession.get(sessionId);
} else {
return -1L;
}
}
/**
 * Get the {@link I18N} instance: per-session when {@link VertxUIServer#isMultiSession()}
 * is {@code true}, otherwise the global instance.
 *
 * @param sessionId session ID
 * @return {@link I18N} instance
 */
private I18N getI18N(String sessionId) {
    if (VertxUIServer.getInstance().isMultiSession()) {
        return I18NProvider.getInstance(sessionId);
    }
    return I18NProvider.getInstance();
}
/** Overview data for the currently selected session (single-session mode entry point). */
private void getOverviewData(RoutingContext rc) {
getOverviewDataForSession(currentSessionID, rc);
}
/**
 * Build and return (as JSON) all data for the "overview" page of one session: score vs.
 * iteration chart, update:parameter mean-magnitude ratios, stdev charts (activations /
 * gradients / updates), the performance table and the model summary table. Stored updates
 * are subsampled down to roughly maxChartPoints before charting.
 *
 * NOTE(review): generic type parameters in this copy of the source were stripped by the
 * tool that produced it (e.g. "Map>" below was presumably Map&lt;String, List&lt;Double&gt;&gt;);
 * restore from upstream before compiling.
 *
 * @param sessionId session to report on; null (or unknown) produces a mostly-empty payload
 * @param rc routing context used to send the JSON response
 */
private synchronized void getOverviewDataForSession(String sessionId, RoutingContext rc) {
Long lastUpdateTime = getLastUpdateTime(sessionId);
I18N i18N = getI18N(sessionId);
//First pass (optimize later): query all data...
StatsStorage ss = (sessionId == null ? null : knownSessionIDs.get(sessionId));
String wid = getWorkerIdForIndex(sessionId, currentWorkerIdx);
boolean noData = (sessionId == null) || (ss == null) || (wid == null);
List scoresIterCount = new ArrayList<>();
List scores = new ArrayList<>();
Map result = new HashMap<>();
result.put("updateTimestamp", lastUpdateTime);
result.put("scores", scores);
result.put("scoresIter", scoresIterCount);
//Get scores info
long[] allTimes = (noData ? null : ss.getAllUpdateTimes(sessionId, StatsListener.TYPE_ID, wid));
List updates = null;
// Subsample by timestamp when there are more stored updates than chart points
if (allTimes != null && allTimes.length > maxChartPoints) {
int subsamplingFrequency = allTimes.length / maxChartPoints;
LongArrayList timesToQuery = new LongArrayList(maxChartPoints + 2);
int i = 0;
for (; i < allTimes.length; i += subsamplingFrequency) {
timesToQuery.add(allTimes[i]);
}
if ((i - subsamplingFrequency) != allTimes.length - 1) {
//Also add final point
timesToQuery.add(allTimes[allTimes.length - 1]);
}
updates = ss.getUpdates(sessionId, StatsListener.TYPE_ID, wid, timesToQuery.toLongArray());
} else if (allTimes != null) {
//Don't subsample
updates = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, wid, 0);
}
if (updates == null || updates.isEmpty()) {
noData = true;
}
//Collect update ratios for weights
//Collect standard deviations: activations, gradients, updates
Map> updateRatios = new HashMap<>(); //Mean magnitude (updates) / mean magnitude (parameters)
result.put("updateRatios", updateRatios);
Map> stdevActivations = new HashMap<>();
Map> stdevGradients = new HashMap<>();
Map> stdevUpdates = new HashMap<>();
result.put("stdevActivations", stdevActivations);
result.put("stdevGradients", stdevGradients);
result.put("stdevUpdates", stdevUpdates);
// Seed the per-parameter series maps from the first report, so every chart has a
// consistent set of keys even when later reports omit some entries
if (!noData) {
Persistable u = updates.get(0);
if (u instanceof StatsReport) {
StatsReport sp = (StatsReport) u;
Map map = sp.getMeanMagnitudes(StatsType.Parameters);
if (map != null) {
for (String s : map.keySet()) {
if (!s.toLowerCase().endsWith("w"))
continue; //TODO: more robust "weights only" approach...
updateRatios.put(s, new ArrayList<>());
}
}
Map stdGrad = sp.getStdev(StatsType.Gradients);
if (stdGrad != null) {
for (String s : stdGrad.keySet()) {
if (!s.toLowerCase().endsWith("w"))
continue; //TODO: more robust "weights only" approach...
stdevGradients.put(s, new ArrayList<>());
}
}
Map stdUpdate = sp.getStdev(StatsType.Updates);
if (stdUpdate != null) {
for (String s : stdUpdate.keySet()) {
if (!s.toLowerCase().endsWith("w"))
continue; //TODO: more robust "weights only" approach...
stdevUpdates.put(s, new ArrayList<>());
}
}
Map stdAct = sp.getStdev(StatsType.Activations);
if (stdAct != null) {
for (String s : stdAct.keySet()) {
stdevActivations.put(s, new ArrayList<>());
}
}
}
}
StatsReport last = null;
int lastIterCount = -1;
//Legacy issue - Spark training - iteration counts are used to be reset... which means: could go 0,1,2,0,1,2, etc...
//Or, it could equally go 4,8,4,8,... or 5,5,5,5 - depending on the collection and averaging frequencies
//Now, it should use the proper iteration counts
boolean needToHandleLegacyIterCounts = false;
if (!noData) {
double lastScore;
int totalUpdates = updates.size();
// Second-level subsample (by report count) in case the timestamp subsample above
// still returned more reports than maxChartPoints
int subsamplingFrequency = 1;
if (totalUpdates > maxChartPoints) {
subsamplingFrequency = totalUpdates / maxChartPoints;
}
int pCount = -1;
int lastUpdateIdx = updates.size() - 1;
for (Persistable u : updates) {
pCount++;
if (!(u instanceof StatsReport))
continue;
last = (StatsReport) u;
int iterCount = last.getIterationCount();
// A non-increasing iteration count marks legacy (resetting) counters
if (iterCount <= lastIterCount) {
needToHandleLegacyIterCounts = true;
}
lastIterCount = iterCount;
if (pCount > 0 && subsamplingFrequency > 1 && pCount % subsamplingFrequency != 0) {
//Skip this - subsample the data
if (pCount != lastUpdateIdx)
continue; //Always keep the most recent value
}
scoresIterCount.add(iterCount);
lastScore = last.getScore();
if (Double.isFinite(lastScore)) {
scores.add(lastScore);
} else {
scores.add(NAN_REPLACEMENT_VALUE);
}
//Update ratios: mean magnitudes(updates) / mean magnitudes (parameters)
Map updateMM = last.getMeanMagnitudes(StatsType.Updates);
Map paramMM = last.getMeanMagnitudes(StatsType.Parameters);
if (updateMM != null && paramMM != null && updateMM.size() > 0 && paramMM.size() > 0) {
for (String s : updateRatios.keySet()) {
List ratioHistory = updateRatios.get(s);
double currUpdate = updateMM.getOrDefault(s, 0.0);
double currParam = paramMM.getOrDefault(s, 0.0);
double ratio = currUpdate / currParam;
if (Double.isFinite(ratio)) {
ratioHistory.add(ratio);
} else {
ratioHistory.add(NAN_REPLACEMENT_VALUE);
}
}
}
//Standard deviations: gradients, updates, activations
Map stdGrad = last.getStdev(StatsType.Gradients);
Map stdUpd = last.getStdev(StatsType.Updates);
Map stdAct = last.getStdev(StatsType.Activations);
if (stdGrad != null) {
for (String s : stdevGradients.keySet()) {
double d = stdGrad.getOrDefault(s, 0.0);
stdevGradients.get(s).add(fixNaN(d));
}
}
if (stdUpd != null) {
for (String s : stdevUpdates.keySet()) {
double d = stdUpd.getOrDefault(s, 0.0);
stdevUpdates.get(s).add(fixNaN(d));
}
}
if (stdAct != null) {
for (String s : stdevActivations.keySet()) {
double d = stdAct.getOrDefault(s, 0.0);
stdevActivations.get(s).add(fixNaN(d));
}
}
}
}
if (needToHandleLegacyIterCounts) {
cleanLegacyIterationCounts(scoresIterCount);
}
//----- Performance Info -----
// Row 0 (start time) and row 1 (total runtime) are left blank here; presumably filled
// in client-side or elsewhere — TODO confirm
String[][] perfInfo = new String[][]{{i18N.getMessage("train.overview.perftable.startTime"), ""},
{i18N.getMessage("train.overview.perftable.totalRuntime"), ""},
{i18N.getMessage("train.overview.perftable.lastUpdate"), ""},
{i18N.getMessage("train.overview.perftable.totalParamUpdates"), ""},
{i18N.getMessage("train.overview.perftable.updatesPerSec"), ""},
{i18N.getMessage("train.overview.perftable.examplesPerSec"), ""}};
if (last != null) {
perfInfo[2][1] = String.valueOf(dateFormat.format(new Date(last.getTimeStamp())));
perfInfo[3][1] = String.valueOf(last.getTotalMinibatches());
perfInfo[4][1] = String.valueOf(df2.format(last.getMinibatchesPerSecond()));
perfInfo[5][1] = String.valueOf(df2.format(last.getExamplesPerSecond()));
}
result.put("perf", perfInfo);
// ----- Model Info -----
String[][] modelInfo = new String[][]{{i18N.getMessage("train.overview.modeltable.modeltype"), ""},
{i18N.getMessage("train.overview.modeltable.nLayers"), ""},
{i18N.getMessage("train.overview.modeltable.nParams"), ""}};
if (!noData) {
Persistable p = ss.getStaticInfo(sessionId, StatsListener.TYPE_ID, wid);
if (p != null) {
StatsInitializationReport initReport = (StatsInitializationReport) p;
int nLayers = initReport.getModelNumLayers();
long numParams = initReport.getModelNumParams();
String className = initReport.getModelClassName();
String modelType;
if (className.endsWith("MultiLayerNetwork")) {
modelType = "MultiLayerNetwork";
} else if (className.endsWith("ComputationGraph")) {
modelType = "ComputationGraph";
} else {
// Unknown model class: display the simple (unqualified) class name
modelType = className;
if (modelType.lastIndexOf('.') > 0) {
modelType = modelType.substring(modelType.lastIndexOf('.') + 1);
}
}
modelInfo[0][1] = modelType;
modelInfo[1][1] = String.valueOf(nLayers);
modelInfo[2][1] = String.valueOf(numParams);
}
}
result.put("model", modelInfo);
String json = asJson(result);
rc.response()
.putHeader("content-type", "application/json")
.end(json);
}
/** Model graph for the currently selected session (single-session mode entry point). */
private void getModelGraph(RoutingContext rc) {
getModelGraphForSession(currentSessionID, rc);
}
/**
 * Respond with the model's graph structure (as JSON) for the given session, or an empty
 * response when the session, its static info, or its parsed configuration is unavailable.
 * (Generic types restored; raw Collections.EMPTY_LIST replaced with the typed empty list.)
 */
private void getModelGraphForSession(String sessionId, RoutingContext rc) {
    boolean noData = (sessionId == null || !knownSessionIDs.containsKey(sessionId));
    StatsStorage ss = (noData ? null : knownSessionIDs.get(sessionId));
    List<Persistable> allStatic = (noData ? Collections.<Persistable>emptyList()
            : ss.getAllStaticInfos(sessionId, StatsListener.TYPE_ID));
    if (allStatic.isEmpty()) {
        rc.response().end();
        return;
    }
    TrainModuleUtils.GraphInfo gi = getGraphInfo(getConfig(sessionId));
    if (gi == null) {
        rc.response().end();
        return;
    }
    rc.response()
            .putHeader("content-type", "application/json")
            .end(asJson(gi));
}
/**
 * Build graph info from whichever configuration type is present in the triple —
 * MultiLayerConfiguration, ComputationGraphConfiguration, or a single-layer
 * NeuralNetConfiguration — or null when the triple is null/empty.
 * (Generic type restored — stripped in this copy of the source.)
 */
private TrainModuleUtils.GraphInfo getGraphInfo(
        Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> conf) {
    if (conf == null) {
        return null;
    }
    if (conf.getFirst() != null) {
        return TrainModuleUtils.buildGraphInfo(conf.getFirst());
    } else if (conf.getSecond() != null) {
        return TrainModuleUtils.buildGraphInfo(conf.getSecond());
    } else if (conf.getThird() != null) {
        return TrainModuleUtils.buildGraphInfo(conf.getThird());
    } else {
        return null;
    }
}
/**
 * Parse the session's model configuration JSON into exactly one slot of the returned
 * triple: (MultiLayerConfiguration, null, null), (null, ComputationGraphConfiguration,
 * null), or (null, null, NeuralNetConfiguration). Returns null when the session has no
 * static info or the single-layer JSON fails to parse.
 * (Generic types restored — stripped in this copy of the source.)
 */
private Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> getConfig(
        String sessionId) {
    boolean noData = (sessionId == null || !knownSessionIDs.containsKey(sessionId));
    StatsStorage ss = (noData ? null : knownSessionIDs.get(sessionId));
    List<Persistable> allStatic = (noData ? Collections.<Persistable>emptyList()
            : ss.getAllStaticInfos(sessionId, StatsListener.TYPE_ID));
    if (allStatic.isEmpty())
        return null;
    StatsInitializationReport p = (StatsInitializationReport) allStatic.get(0);
    String modelClass = p.getModelClassName();
    String config = p.getModelConfigJson();
    if (modelClass.endsWith("MultiLayerNetwork")) {
        return new Triple<>(MultiLayerConfiguration.fromJson(config), null, null);
    } else if (modelClass.endsWith("ComputationGraph")) {
        return new Triple<>(null, ComputationGraphConfiguration.fromJson(config), null);
    } else {
        try {
            NeuralNetConfiguration layer =
                    NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class);
            return new Triple<>(null, null, layer);
        } catch (Exception e) {
            log.error("", e);
        }
    }
    return null;
}
/** Per-layer model data for the currently selected session (single-session mode entry point). */
private void getModelData(String layerId, RoutingContext rc) {
getModelDataForSession(currentSessionID, layerId, rc);
}
/**
 * Build and return (as JSON) all per-layer data for the "model" page: static layer info,
 * mean-magnitude ratio charts, activation mean/stdev chart, learning rates, and
 * parameter/update histograms.
 *
 * NOTE(review): generic type parameters were stripped in this copy of the source
 * (e.g. "Map mmRatioMap", "Triple activationsData"); restore from upstream before compiling.
 *
 * @param sessionId session to report on
 * @param layerId layer index as a decimal string (parsed without validation — see TODO)
 * @param rc routing context used to send the JSON response
 */
private void getModelDataForSession(String sessionId, String layerId, RoutingContext rc) {
Long lastUpdateTime = getLastUpdateTime(sessionId);
int layerIdx = Integer.parseInt(layerId); //TODO validation
I18N i18N = getI18N(sessionId);
//Model info for layer
//First pass (optimize later): query all data...
StatsStorage ss = (sessionId == null ? null : knownSessionIDs.get(sessionId));
String wid = getWorkerIdForIndex(sessionId, currentWorkerIdx);
boolean noData = (sessionId == null) || (ss == null) || (wid == null);
Map result = new HashMap<>();
result.put("updateTimestamp", lastUpdateTime);
Triple conf = getConfig(sessionId);
// No parseable configuration -> return the (nearly empty) result early
if (conf == null) {
rc.response()
.putHeader("content-type", "application/json")
.end(asJson(result));
return;
}
TrainModuleUtils.GraphInfo gi = getGraphInfo(conf);
if (gi == null) {
rc.response()
.putHeader("content-type", "application/json")
.end(asJson(result));
return;
}
// Get static layer info
String[][] layerInfoTable = getLayerInfoTable(sessionId, layerIdx, gi, i18N, noData, ss, wid);
result.put("layerInfo", layerInfoTable);
//First: get all data, and subsample it if necessary, to avoid returning too many points...
long[] allTimes = (noData ? null : ss.getAllUpdateTimes(sessionId, StatsListener.TYPE_ID, wid));
List updates = null;
List iterationCounts = null;
boolean needToHandleLegacyIterCounts = false;
if (allTimes != null && allTimes.length > maxChartPoints) {
int subsamplingFrequency = allTimes.length / maxChartPoints;
LongArrayList timesToQuery = new LongArrayList(maxChartPoints + 2);
int i = 0;
for (; i < allTimes.length; i += subsamplingFrequency) {
timesToQuery.add(allTimes[i]);
}
if ((i - subsamplingFrequency) != allTimes.length - 1) {
//Also add final point
timesToQuery.add(allTimes[allTimes.length - 1]);
}
updates = ss.getUpdates(sessionId, StatsListener.TYPE_ID, wid, timesToQuery.toLongArray());
} else if (allTimes != null) {
//Don't subsample
updates = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, wid, 0);
}
// NOTE(review): when a worker ID cannot be resolved (noData true) but the config was
// parseable, 'updates' is still null here and the next line throws NPE — unlike
// getOverviewDataForSession, there is no null guard. Confirm and fix upstream.
iterationCounts = new ArrayList<>(updates.size());
int lastIterCount = -1;
for (Persistable p : updates) {
if (!(p instanceof StatsReport))
continue;
StatsReport sr = (StatsReport) p;
int iterCount = sr.getIterationCount();
// A non-increasing iteration count marks legacy (resetting) counters
if (iterCount <= lastIterCount) {
needToHandleLegacyIterCounts = true;
}
iterationCounts.add(iterCount);
}
//Legacy issue - Spark training - iteration counts are used to be reset... which means: could go 0,1,2,0,1,2, etc...
//Or, it could equally go 4,8,4,8,... or 5,5,5,5 - depending on the collection and averaging frequencies
//Now, it should use the proper iteration counts
if (needToHandleLegacyIterCounts) {
cleanLegacyIterationCounts(iterationCounts);
}
//Get mean magnitudes line chart
// Determine which model family the config describes (exactly one triple slot is non-null)
ModelType mt;
if (conf.getFirst() != null)
mt = ModelType.MLN;
else if (conf.getSecond() != null)
mt = ModelType.CG;
else
mt = ModelType.Layer;
MeanMagnitudes mm = getLayerMeanMagnitudes(layerIdx, gi, updates, iterationCounts, mt);
Map mmRatioMap = new HashMap<>();
mmRatioMap.put("layerParamNames", mm.getRatios().keySet());
mmRatioMap.put("iterCounts", mm.getIterations());
mmRatioMap.put("ratios", mm.getRatios());
mmRatioMap.put("paramMM", mm.getParamMM());
mmRatioMap.put("updateMM", mm.getUpdateMM());
result.put("meanMag", mmRatioMap);
//Get activations line chart for layer
Triple activationsData = getLayerActivations(layerIdx, gi, updates, iterationCounts);
Map activationMap = new HashMap<>();
activationMap.put("iterCount", activationsData.getFirst());
activationMap.put("mean", activationsData.getSecond());
activationMap.put("stdev", activationsData.getThird());
result.put("activations", activationMap);
//Get learning rate vs. time chart for layer
Map lrs = getLayerLearningRates(layerIdx, gi, updates, iterationCounts, mt);
result.put("learningRates", lrs);
//Parameters histogram data
Persistable lastUpdate = (updates != null && !updates.isEmpty() ? updates.get(updates.size() - 1) : null);
Map paramHistograms = getHistograms(layerIdx, gi, StatsType.Parameters, lastUpdate);
result.put("paramHist", paramHistograms);
//Updates histogram data
Map updateHistograms = getHistograms(layerIdx, gi, StatsType.Updates, lastUpdate);
result.put("updateHist", updateHistograms);
rc.response()
.putHeader("content-type", "application/json")
.end(asJson(result));
}
/** System/hardware data for the currently selected session (single-session mode entry point). */
private void getSystemData(RoutingContext rc) {
getSystemDataForSession(currentSessionID, rc);
}
private void getSystemDataForSession(String sessionId, RoutingContext rc) {
Long lastUpdate = getLastUpdateTime(sessionId);
I18N i18n = getI18N(sessionId);
//First: get the MOST RECENT update...
//Then get all updates from most recent - 5 minutes -> TODO make this configurable...
StatsStorage ss = (sessionId == null ? null : knownSessionIDs.get(sessionId));
boolean noData = (ss == null);
List allStatic = (noData ? Collections.EMPTY_LIST
: ss.getAllStaticInfos(sessionId, StatsListener.TYPE_ID));
List latestUpdates = (noData ? Collections.EMPTY_LIST
: ss.getLatestUpdateAllWorkers(sessionId, StatsListener.TYPE_ID));
long lastUpdateTime = -1;
if (latestUpdates == null || latestUpdates.isEmpty()) {
noData = true;
} else {
for (Persistable p : latestUpdates) {
lastUpdateTime = Math.max(lastUpdateTime, p.getTimeStamp());
}
}
long fromTime = lastUpdateTime - 5 * 60 * 1000; //TODO Make configurable
List lastNMinutes =
(noData ? null : ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, fromTime));
Map mem = getMemory(allStatic, lastNMinutes, i18n);
Pair
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy (artifact-browser page footer; not part of the original source)