All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.ui.storage.BaseCollectionStatsStorage Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.ui.storage;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.deeplearning4j.api.storage.*;
import org.jetbrains.annotations.NotNull;

import java.io.*;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * An implementation of the {@link StatsStorage} interface, backed by MapDB
 *
 * @author Alex Black
 */
public abstract class BaseCollectionStatsStorage implements StatsStorage {

    protected Set sessionIDs;
    protected Map storageMetaData;
    protected Map staticInfo;

    protected Map> updates = new ConcurrentHashMap<>();

    protected List listeners = new ArrayList<>();

    protected BaseCollectionStatsStorage() {

    }

    protected abstract Map getUpdateMap(String sessionID, String typeID, String workerID,
                    boolean createIfRequired);

    //Return any relevant storage events
    //We want to return these so they can be logged later. Can't be logged immediately, as this may case a race
    //condition with whatever is receiving the events: i.e., might get the event before the contents are actually
    //available in the DB
    protected List checkStorageEvents(Persistable p) {
        if (listeners.size() == 0)
            return null;

        int count = 0;
        StatsStorageEvent newSID = null;
        StatsStorageEvent newTID = null;
        StatsStorageEvent newWID = null;

        //Is this a new session ID?
        if (!sessionIDs.contains(p.getSessionID())) {
            newSID = new StatsStorageEvent(this, StatsStorageListener.EventType.NewSessionID, p.getSessionID(),
                            p.getTypeID(), p.getWorkerID(), p.getTimeStamp());
            count++;
        }

        //Check for new type and worker IDs
        //TODO probably more efficient way to do this
        boolean foundTypeId = false;
        boolean foundWorkerId = false;
        String typeId = p.getTypeID();
        String wid = p.getWorkerID();
        for (SessionTypeId ts : storageMetaData.keySet()) {
            if (typeId.equals(ts.getTypeID())) {
                foundTypeId = true;
                break;
            }
        }
        for (SessionTypeWorkerId stw : staticInfo.keySet()) {
            if (!foundTypeId && typeId.equals(stw.getTypeID())) {
                foundTypeId = true;
            }
            if (!foundWorkerId && wid.equals(stw.getWorkerID())) {
                foundWorkerId = true;
            }
            if (foundTypeId && foundWorkerId)
                break;
        }
        if (!foundTypeId || !foundWorkerId) {
            for (SessionTypeWorkerId stw : updates.keySet()) {
                if (!foundTypeId && typeId.equals(stw.getTypeID())) {
                    foundTypeId = true;
                }
                if (!foundWorkerId && wid.equals(stw.getWorkerID())) {
                    foundWorkerId = true;
                }
                if (foundTypeId && foundWorkerId)
                    break;
            }
        }
        if (!foundTypeId) {
            //New type ID
            newTID = new StatsStorageEvent(this, StatsStorageListener.EventType.NewTypeID, p.getSessionID(),
                            p.getTypeID(), p.getWorkerID(), p.getTimeStamp());
            count++;
        }
        if (!foundWorkerId) {
            //New worker ID
            newWID = new StatsStorageEvent(this, StatsStorageListener.EventType.NewWorkerID, p.getSessionID(),
                            p.getTypeID(), p.getWorkerID(), p.getTimeStamp());
            count++;
        }
        if (count == 0)
            return null;
        List sses = new ArrayList<>(count);
        if (newSID != null)
            sses.add(newSID);
        if (newTID != null)
            sses.add(newTID);
        if (newWID != null)
            sses.add(newWID);
        return sses;
    }

    protected void notifyListeners(List sses) {
        if (sses == null || sses.size() == 0 || listeners.size() == 0)
            return;
        for (StatsStorageListener l : listeners) {
            for (StatsStorageEvent e : sses) {
                l.notify(e);
            }
        }
    }

    @Override
    public List listSessionIDs() {
        return new ArrayList<>(sessionIDs);
    }

    @Override
    public boolean sessionExists(String sessionID) {
        return sessionIDs.contains(sessionID);
    }

    @Override
    public Persistable getStaticInfo(String sessionID, String typeID, String workerID) {
        SessionTypeWorkerId id = new SessionTypeWorkerId(sessionID, typeID, workerID);
        return staticInfo.get(id);
    }

    @Override
    public List getAllStaticInfos(String sessionID, String typeID) {
        List out = new ArrayList<>();
        for (SessionTypeWorkerId key : staticInfo.keySet()) {
            if (sessionID.equals(key.getSessionID()) && typeID.equals(key.getTypeID())) {
                out.add(staticInfo.get(key));
            }
        }
        return out;
    }

    @Override
    public List listTypeIDsForSession(String sessionID) {
        Set typeIDs = new HashSet<>();
        for (SessionTypeId st : storageMetaData.keySet()) {
            if (!sessionID.equals(st.getSessionID()))
                continue;
            typeIDs.add(st.getTypeID());
        }

        for (SessionTypeWorkerId stw : staticInfo.keySet()) {
            if (!sessionID.equals(stw.getSessionID()))
                continue;
            typeIDs.add(stw.getTypeID());
        }
        for (SessionTypeWorkerId stw : updates.keySet()) {
            if (!sessionID.equals(stw.getSessionID()))
                continue;
            typeIDs.add(stw.getTypeID());
        }

        return new ArrayList<>(typeIDs);
    }

    @Override
    public List listWorkerIDsForSession(String sessionID) {
        List out = new ArrayList<>();
        for (SessionTypeWorkerId ids : staticInfo.keySet()) {
            if (sessionID.equals(ids.getSessionID())) {
                out.add(ids.getWorkerID());
            }
        }
        return out;
    }

    @Override
    public List listWorkerIDsForSessionAndType(String sessionID, String typeID) {
        List out = new ArrayList<>();
        for (SessionTypeWorkerId ids : staticInfo.keySet()) {
            if (sessionID.equals(ids.getSessionID()) && typeID.equals(ids.getTypeID())) {
                out.add(ids.getWorkerID());
            }
        }
        return out;
    }

    @Override
    public int getNumUpdateRecordsFor(String sessionID) {
        int count = 0;
        for (SessionTypeWorkerId id : updates.keySet()) {
            if (sessionID.equals(id.getSessionID())) {
                Map map = updates.get(id);
                if (map != null)
                    count += map.size();
            }
        }
        return count;
    }

    @Override
    public int getNumUpdateRecordsFor(String sessionID, String typeID, String workerID) {
        SessionTypeWorkerId id = new SessionTypeWorkerId(sessionID, typeID, workerID);
        Map map = updates.get(id);
        if (map != null)
            return map.size();
        return 0;
    }

    @Override
    public Persistable getLatestUpdate(String sessionID, String typeID, String workerID) {
        SessionTypeWorkerId id = new SessionTypeWorkerId(sessionID, typeID, workerID);
        Map map = updates.get(id);
        if (map == null || map.isEmpty())
            return null;
        long maxTime = Long.MIN_VALUE;
        for (Long l : map.keySet()) {
            maxTime = Math.max(maxTime, l);
        }
        return map.get(maxTime);
    }

    @Override
    public Persistable getUpdate(String sessionID, String typeID, String workerID, long timestamp) {
        SessionTypeWorkerId id = new SessionTypeWorkerId(sessionID, typeID, workerID);
        Map map = updates.get(id);
        if (map == null)
            return null;

        return map.get(timestamp);
    }

    @Override
    public List getLatestUpdateAllWorkers(String sessionID, String typeID) {
        List list = new ArrayList<>();

        for (SessionTypeWorkerId id : updates.keySet()) {
            if (sessionID.equals(id.getSessionID()) && typeID.equals(id.getTypeID())) {
                Persistable p = getLatestUpdate(sessionID, typeID, id.workerID);
                if (p != null) {
                    list.add(p);
                }
            }
        }

        return list;
    }

    @Override
    public List getAllUpdatesAfter(String sessionID, String typeID, String workerID, long timestamp) {
        List list = new ArrayList<>();

        Map map = getUpdateMap(sessionID, typeID, workerID, false);
        if (map == null)
            return list;

        for (Long time : map.keySet()) {
            if (time > timestamp) {
                list.add(map.get(time));
            }
        }

        Collections.sort(list, new Comparator() {
            @Override
            public int compare(Persistable o1, Persistable o2) {
                return Long.compare(o1.getTimeStamp(), o2.getTimeStamp());
            }
        });

        return list;
    }

    @Override
    public List getAllUpdatesAfter(String sessionID, String typeID, long timestamp) {
        List list = new ArrayList<>();

        for (SessionTypeWorkerId stw : staticInfo.keySet()) {
            if (stw.getSessionID().equals(sessionID) && stw.getTypeID().equals(typeID)) {
                Map u = updates.get(stw);
                if (u == null)
                    continue;
                for (long l : u.keySet()) {
                    if (l > timestamp) {
                        list.add(u.get(l));
                    }
                }
            }
        }

        //Sort by time stamp
        Collections.sort(list, new Comparator() {
            @Override
            public int compare(Persistable o1, Persistable o2) {
                return Long.compare(o1.getTimeStamp(), o2.getTimeStamp());
            }
        });

        return list;
    }

    @Override
    public StorageMetaData getStorageMetaData(String sessionID, String typeID) {
        return this.storageMetaData.get(new SessionTypeId(sessionID, typeID));
    }

    // ----- Store new info -----

    @Override
    public abstract void putStaticInfo(Persistable staticInfo);

    @Override
    public void putStaticInfo(Collection staticInfo) {
        for (Persistable p : staticInfo) {
            putStaticInfo(p);
        }
    }

    @Override
    public abstract void putUpdate(Persistable update);

    @Override
    public void putUpdate(Collection updates) {
        for (Persistable p : updates) {
            putUpdate(p);
        }
    }

    @Override
    public abstract void putStorageMetaData(StorageMetaData storageMetaData);

    @Override
    public void putStorageMetaData(Collection storageMetaData) {
        for (StorageMetaData m : storageMetaData) {
            putStorageMetaData(m);
        }
    }


    // ----- Listeners -----

    @Override
    public void registerStatsStorageListener(StatsStorageListener listener) {
        if (!this.listeners.contains(listener)) {
            this.listeners.add(listener);
        }
    }

    @Override
    public void deregisterStatsStorageListener(StatsStorageListener listener) {
        this.listeners.remove(listener);
    }

    @Override
    public void removeAllListeners() {
        this.listeners.clear();
    }

    @Override
    public List getListeners() {
        return new ArrayList<>(listeners);
    }

    @Data
    public static class SessionTypeWorkerId implements Serializable, Comparable {
        private final String sessionID;
        private final String typeID;
        private final String workerID;

        public SessionTypeWorkerId(String sessionID, String typeID, String workerID) {
            this.sessionID = sessionID;
            this.typeID = typeID;
            this.workerID = workerID;
        }

        @Override
        public int compareTo(SessionTypeWorkerId o) {
            int c = sessionID.compareTo(o.sessionID);
            if (c != 0)
                return c;
            c = typeID.compareTo(o.typeID);
            if (c != 0)
                return c;
            return workerID.compareTo(workerID);
        }

        @Override
        public String toString() {
            return "(" + sessionID + "," + typeID + "," + workerID + ")";
        }
    }

    @AllArgsConstructor
    @Data
    public static class SessionTypeId implements Serializable, Comparable {
        private final String sessionID;
        private final String typeID;

        @Override
        public int compareTo(SessionTypeId o) {
            int c = sessionID.compareTo(o.sessionID);
            if (c != 0)
                return c;
            return typeID.compareTo(o.typeID);
        }

        @Override
        public String toString() {
            return "(" + sessionID + "," + typeID + ")";
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy