// org.deeplearning4j.ui.storage.BaseCollectionStatsStorage (Maven / Gradle / Ivy)
package org.deeplearning4j.ui.storage;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.deeplearning4j.api.storage.*;
import org.jetbrains.annotations.NotNull;
import java.io.*;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
* An implementation of the {@link StatsStorage} interface, backed by MapDB
*
* @author Alex Black
*/
public abstract class BaseCollectionStatsStorage implements StatsStorage {
protected Set sessionIDs;
protected Map storageMetaData;
protected Map staticInfo;
protected Map> updates = new ConcurrentHashMap<>();
protected List listeners = new ArrayList<>();
protected BaseCollectionStatsStorage() {
}
protected abstract Map getUpdateMap(String sessionID, String typeID, String workerID,
boolean createIfRequired);
//Return any relevant storage events
//We want to return these so they can be logged later. Can't be logged immediately, as this may case a race
//condition with whatever is receiving the events: i.e., might get the event before the contents are actually
//available in the DB
protected List checkStorageEvents(Persistable p) {
if (listeners.size() == 0)
return null;
int count = 0;
StatsStorageEvent newSID = null;
StatsStorageEvent newTID = null;
StatsStorageEvent newWID = null;
//Is this a new session ID?
if (!sessionIDs.contains(p.getSessionID())) {
newSID = new StatsStorageEvent(this, StatsStorageListener.EventType.NewSessionID, p.getSessionID(),
p.getTypeID(), p.getWorkerID(), p.getTimeStamp());
count++;
}
//Check for new type and worker IDs
//TODO probably more efficient way to do this
boolean foundTypeId = false;
boolean foundWorkerId = false;
String typeId = p.getTypeID();
String wid = p.getWorkerID();
for (SessionTypeId ts : storageMetaData.keySet()) {
if (typeId.equals(ts.getTypeID())) {
foundTypeId = true;
break;
}
}
for (SessionTypeWorkerId stw : staticInfo.keySet()) {
if (!foundTypeId && typeId.equals(stw.getTypeID())) {
foundTypeId = true;
}
if (!foundWorkerId && wid.equals(stw.getWorkerID())) {
foundWorkerId = true;
}
if (foundTypeId && foundWorkerId)
break;
}
if (!foundTypeId || !foundWorkerId) {
for (SessionTypeWorkerId stw : updates.keySet()) {
if (!foundTypeId && typeId.equals(stw.getTypeID())) {
foundTypeId = true;
}
if (!foundWorkerId && wid.equals(stw.getWorkerID())) {
foundWorkerId = true;
}
if (foundTypeId && foundWorkerId)
break;
}
}
if (!foundTypeId) {
//New type ID
newTID = new StatsStorageEvent(this, StatsStorageListener.EventType.NewTypeID, p.getSessionID(),
p.getTypeID(), p.getWorkerID(), p.getTimeStamp());
count++;
}
if (!foundWorkerId) {
//New worker ID
newWID = new StatsStorageEvent(this, StatsStorageListener.EventType.NewWorkerID, p.getSessionID(),
p.getTypeID(), p.getWorkerID(), p.getTimeStamp());
count++;
}
if (count == 0)
return null;
List sses = new ArrayList<>(count);
if (newSID != null)
sses.add(newSID);
if (newTID != null)
sses.add(newTID);
if (newWID != null)
sses.add(newWID);
return sses;
}
protected void notifyListeners(List sses) {
if (sses == null || sses.size() == 0 || listeners.size() == 0)
return;
for (StatsStorageListener l : listeners) {
for (StatsStorageEvent e : sses) {
l.notify(e);
}
}
}
@Override
public List listSessionIDs() {
return new ArrayList<>(sessionIDs);
}
@Override
public boolean sessionExists(String sessionID) {
return sessionIDs.contains(sessionID);
}
@Override
public Persistable getStaticInfo(String sessionID, String typeID, String workerID) {
SessionTypeWorkerId id = new SessionTypeWorkerId(sessionID, typeID, workerID);
return staticInfo.get(id);
}
@Override
public List getAllStaticInfos(String sessionID, String typeID) {
List out = new ArrayList<>();
for (SessionTypeWorkerId key : staticInfo.keySet()) {
if (sessionID.equals(key.getSessionID()) && typeID.equals(key.getTypeID())) {
out.add(staticInfo.get(key));
}
}
return out;
}
@Override
public List listTypeIDsForSession(String sessionID) {
Set typeIDs = new HashSet<>();
for (SessionTypeId st : storageMetaData.keySet()) {
if (!sessionID.equals(st.getSessionID()))
continue;
typeIDs.add(st.getTypeID());
}
for (SessionTypeWorkerId stw : staticInfo.keySet()) {
if (!sessionID.equals(stw.getSessionID()))
continue;
typeIDs.add(stw.getTypeID());
}
for (SessionTypeWorkerId stw : updates.keySet()) {
if (!sessionID.equals(stw.getSessionID()))
continue;
typeIDs.add(stw.getTypeID());
}
return new ArrayList<>(typeIDs);
}
@Override
public List listWorkerIDsForSession(String sessionID) {
List out = new ArrayList<>();
for (SessionTypeWorkerId ids : staticInfo.keySet()) {
if (sessionID.equals(ids.getSessionID())) {
out.add(ids.getWorkerID());
}
}
return out;
}
@Override
public List listWorkerIDsForSessionAndType(String sessionID, String typeID) {
List out = new ArrayList<>();
for (SessionTypeWorkerId ids : staticInfo.keySet()) {
if (sessionID.equals(ids.getSessionID()) && typeID.equals(ids.getTypeID())) {
out.add(ids.getWorkerID());
}
}
return out;
}
@Override
public int getNumUpdateRecordsFor(String sessionID) {
int count = 0;
for (SessionTypeWorkerId id : updates.keySet()) {
if (sessionID.equals(id.getSessionID())) {
Map map = updates.get(id);
if (map != null)
count += map.size();
}
}
return count;
}
@Override
public int getNumUpdateRecordsFor(String sessionID, String typeID, String workerID) {
SessionTypeWorkerId id = new SessionTypeWorkerId(sessionID, typeID, workerID);
Map map = updates.get(id);
if (map != null)
return map.size();
return 0;
}
@Override
public Persistable getLatestUpdate(String sessionID, String typeID, String workerID) {
SessionTypeWorkerId id = new SessionTypeWorkerId(sessionID, typeID, workerID);
Map map = updates.get(id);
if (map == null || map.isEmpty())
return null;
long maxTime = Long.MIN_VALUE;
for (Long l : map.keySet()) {
maxTime = Math.max(maxTime, l);
}
return map.get(maxTime);
}
@Override
public Persistable getUpdate(String sessionID, String typeID, String workerID, long timestamp) {
SessionTypeWorkerId id = new SessionTypeWorkerId(sessionID, typeID, workerID);
Map map = updates.get(id);
if (map == null)
return null;
return map.get(timestamp);
}
@Override
public List getLatestUpdateAllWorkers(String sessionID, String typeID) {
List list = new ArrayList<>();
for (SessionTypeWorkerId id : updates.keySet()) {
if (sessionID.equals(id.getSessionID()) && typeID.equals(id.getTypeID())) {
Persistable p = getLatestUpdate(sessionID, typeID, id.workerID);
if (p != null) {
list.add(p);
}
}
}
return list;
}
@Override
public List getAllUpdatesAfter(String sessionID, String typeID, String workerID, long timestamp) {
List list = new ArrayList<>();
Map map = getUpdateMap(sessionID, typeID, workerID, false);
if (map == null)
return list;
for (Long time : map.keySet()) {
if (time > timestamp) {
list.add(map.get(time));
}
}
Collections.sort(list, new Comparator() {
@Override
public int compare(Persistable o1, Persistable o2) {
return Long.compare(o1.getTimeStamp(), o2.getTimeStamp());
}
});
return list;
}
@Override
public List getAllUpdatesAfter(String sessionID, String typeID, long timestamp) {
List list = new ArrayList<>();
for (SessionTypeWorkerId stw : staticInfo.keySet()) {
if (stw.getSessionID().equals(sessionID) && stw.getTypeID().equals(typeID)) {
Map u = updates.get(stw);
if (u == null)
continue;
for (long l : u.keySet()) {
if (l > timestamp) {
list.add(u.get(l));
}
}
}
}
//Sort by time stamp
Collections.sort(list, new Comparator() {
@Override
public int compare(Persistable o1, Persistable o2) {
return Long.compare(o1.getTimeStamp(), o2.getTimeStamp());
}
});
return list;
}
@Override
public StorageMetaData getStorageMetaData(String sessionID, String typeID) {
return this.storageMetaData.get(new SessionTypeId(sessionID, typeID));
}
// ----- Store new info -----
@Override
public abstract void putStaticInfo(Persistable staticInfo);
@Override
public void putStaticInfo(Collection extends Persistable> staticInfo) {
for (Persistable p : staticInfo) {
putStaticInfo(p);
}
}
@Override
public abstract void putUpdate(Persistable update);
@Override
public void putUpdate(Collection extends Persistable> updates) {
for (Persistable p : updates) {
putUpdate(p);
}
}
@Override
public abstract void putStorageMetaData(StorageMetaData storageMetaData);
@Override
public void putStorageMetaData(Collection extends StorageMetaData> storageMetaData) {
for (StorageMetaData m : storageMetaData) {
putStorageMetaData(m);
}
}
// ----- Listeners -----
@Override
public void registerStatsStorageListener(StatsStorageListener listener) {
if (!this.listeners.contains(listener)) {
this.listeners.add(listener);
}
}
@Override
public void deregisterStatsStorageListener(StatsStorageListener listener) {
this.listeners.remove(listener);
}
@Override
public void removeAllListeners() {
this.listeners.clear();
}
@Override
public List getListeners() {
return new ArrayList<>(listeners);
}
@Data
public static class SessionTypeWorkerId implements Serializable, Comparable {
private final String sessionID;
private final String typeID;
private final String workerID;
public SessionTypeWorkerId(String sessionID, String typeID, String workerID) {
this.sessionID = sessionID;
this.typeID = typeID;
this.workerID = workerID;
}
@Override
public int compareTo(SessionTypeWorkerId o) {
int c = sessionID.compareTo(o.sessionID);
if (c != 0)
return c;
c = typeID.compareTo(o.typeID);
if (c != 0)
return c;
return workerID.compareTo(workerID);
}
@Override
public String toString() {
return "(" + sessionID + "," + typeID + "," + workerID + ")";
}
}
@AllArgsConstructor
@Data
public static class SessionTypeId implements Serializable, Comparable {
private final String sessionID;
private final String typeID;
@Override
public int compareTo(SessionTypeId o) {
int c = sessionID.compareTo(o.sessionID);
if (c != 0)
return c;
return typeID.compareTo(o.typeID);
}
@Override
public String toString() {
return "(" + sessionID + "," + typeID + ")";
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy