All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.ui.stats.impl;

import lombok.Data;
import org.agrona.DirectBuffer;
import org.agrona.MutableDirectBuffer;
import org.agrona.concurrent.UnsafeBuffer;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.ui.stats.api.StatsInitializationReport;
import org.deeplearning4j.ui.stats.sbe.*;
import org.deeplearning4j.ui.storage.AgronaPersistable;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;

/**
 * An implementation of {@link StatsInitializationReport} using Simple Binary Encoding (SBE)
 *
 * @author Alex Black
 */
@Data
public class SbeStatsInitializationReport implements StatsInitializationReport, AgronaPersistable {

    private String sessionID;
    private String typeID;
    private String workerID;
    private long timeStamp;

    private boolean hasSoftwareInfo;
    private boolean hasHardwareInfo;
    private boolean hasModelInfo;

    private String swArch;
    private String swOsName;
    private String swJvmName;
    private String swJvmVersion;
    private String swJvmSpecVersion;
    private String swNd4jBackendClass;
    private String swNd4jDataTypeName;
    private String swHostName;
    private String swJvmUID;
    private Map swEnvironmentInfo;

    private int hwJvmAvailableProcessors;
    private int hwNumDevices;
    private long hwJvmMaxMemory;
    private long hwOffHeapMaxMemory;
    private long[] hwDeviceTotalMemory;
    private String[] hwDeviceDescription;
    private String hwHardwareUID;

    private String modelClassName;
    private String modelConfigJson;
    private String[] modelParamNames;
    private int modelNumLayers;
    private long modelNumParams;

    @Override
    public void reportIDs(String sessionID, String typeID, String workerID, long timeStamp) {
        this.sessionID = sessionID;
        this.typeID = typeID;
        this.workerID = workerID;
        this.timeStamp = timeStamp;
    }

    @Override
    public void reportSoftwareInfo(String arch, String osName, String jvmName, String jvmVersion, String jvmSpecVersion,
                    String nd4jBackendClass, String nd4jDataTypeName, String hostname, String jvmUid,
                    Map swEnvironmentInfo) {
        this.swArch = arch;
        this.swOsName = osName;
        this.swJvmName = jvmName;
        this.swJvmVersion = jvmVersion;
        this.swJvmSpecVersion = jvmSpecVersion;
        this.swNd4jBackendClass = nd4jBackendClass;
        this.swNd4jDataTypeName = nd4jDataTypeName;
        this.swHostName = hostname;
        this.swJvmUID = jvmUid;
        this.swEnvironmentInfo = swEnvironmentInfo;
        hasSoftwareInfo = true;
    }

    @Override
    public void reportHardwareInfo(int jvmAvailableProcessors, int numDevices, long jvmMaxMemory, long offHeapMaxMemory,
                    long[] deviceTotalMemory, String[] deviceDescription, String hardwareUID) {
        this.hwJvmAvailableProcessors = jvmAvailableProcessors;
        this.hwNumDevices = numDevices;
        this.hwJvmMaxMemory = jvmMaxMemory;
        this.hwOffHeapMaxMemory = offHeapMaxMemory;
        this.hwDeviceTotalMemory = deviceTotalMemory;
        this.hwDeviceDescription = deviceDescription;
        this.hwHardwareUID = hardwareUID;
        hasHardwareInfo = true;
    }

    @Override
    public void reportModelInfo(String modelClassName, String modelConfigJson, String[] modelParamNames, int numLayers,
                    long numParams) {
        this.modelClassName = modelClassName;
        this.modelConfigJson = modelConfigJson;
        this.modelParamNames = modelParamNames;
        this.modelNumLayers = numLayers;
        this.modelNumParams = numParams;
        hasModelInfo = true;
    }

    @Override
    public boolean hasSoftwareInfo() {
        return hasSoftwareInfo;
    }

    @Override
    public boolean hasHardwareInfo() {
        return hasHardwareInfo;
    }

    @Override
    public boolean hasModelInfo() {
        return hasModelInfo;
    }



    private void clearHwFields() {
        hwDeviceTotalMemory = null;
        hwDeviceDescription = null;
        hwHardwareUID = null;
    }

    private void clearSwFields() {
        swArch = null;
        swOsName = null;
        swJvmName = null;
        swJvmVersion = null;
        swJvmSpecVersion = null;
        swNd4jBackendClass = null;
        swNd4jDataTypeName = null;
        swHostName = null;
        swJvmUID = null;
    }

    private void clearModelFields() {
        modelClassName = null;
        modelConfigJson = null;
        modelParamNames = null;
    }

    @Override
    public String getSessionID() {
        return sessionID;
    }

    @Override
    public String getTypeID() {
        return typeID;
    }

    @Override
    public String getWorkerID() {
        return workerID;
    }

    @Override
    public long getTimeStamp() {
        return timeStamp;
    }

    @Override
    public int encodingLengthBytes() {
        //TODO reuse the byte[]s here, to avoid converting them twice...

        //First: need to determine how large a buffer to use.
        //Buffer is composed of:
        //(a) Header: 8 bytes (4x uint16 = 8 bytes)
        //(b) Fixed length entries length (sie.BlockLength())
        //(c) Group 1: Hardware devices (GPUs) max memory: 4 bytes header + nEntries * 8 (int64) + nEntries * variable length Strings (header + content)  = 4 + 8*n + content
        //(d) Group 2: Software device info: 4 bytes header + 2x variable length Strings for each
        //(d) Group 3: Parameter names: 4 bytes header + nEntries * variable length strings (header + content) = 4 + content
        //(e) Variable length fields: 15 String length fields. Size: 4 bytes header, plus content. 60 bytes header
        //Fixed length + repeating groups + variable length...
        StaticInfoEncoder sie = new StaticInfoEncoder();
        int bufferSize = 8 + sie.sbeBlockLength() + 4 + 4 + 60; //header + fixed values + group headers + variable length headers

        //For variable length field lengths: easist way is simply to convert to UTF-8
        //Of course, it is possible to calculate it first - but we might as well convert (1 pass), rather than count then convert (2 passes)
        byte[] bSessionId = SbeUtil.toBytes(true, sessionID);
        byte[] bTypeId = SbeUtil.toBytes(true, typeID);
        byte[] bWorkerId = SbeUtil.toBytes(true, workerID);

        byte[] bswArch = SbeUtil.toBytes(hasSoftwareInfo, swArch);
        byte[] bswOsName = SbeUtil.toBytes(hasSoftwareInfo, swOsName);
        byte[] bswJvmName = SbeUtil.toBytes(hasSoftwareInfo, swJvmName);
        byte[] bswJvmVersion = SbeUtil.toBytes(hasSoftwareInfo, swJvmVersion);
        byte[] bswJvmSpecVersion = SbeUtil.toBytes(hasSoftwareInfo, swJvmSpecVersion);
        byte[] bswNd4jBackendClass = SbeUtil.toBytes(hasSoftwareInfo, swNd4jBackendClass);
        byte[] bswNd4jDataTypeName = SbeUtil.toBytes(hasSoftwareInfo, swNd4jDataTypeName);
        byte[] bswHostname = SbeUtil.toBytes(hasSoftwareInfo, swHostName);
        byte[] bswJvmUID = SbeUtil.toBytes(hasSoftwareInfo, swJvmUID);
        byte[] bHwHardwareUID = SbeUtil.toBytes(hasHardwareInfo, hwHardwareUID);
        byte[] bmodelConfigClass = SbeUtil.toBytes(hasModelInfo, modelClassName);
        byte[] bmodelConfigJson = SbeUtil.toBytes(hasModelInfo, modelConfigJson);

        byte[][] bhwDeviceDescription = SbeUtil.toBytes(hasHardwareInfo, hwDeviceDescription);
        byte[][][] bswEnvInfo = SbeUtil.toBytes(swEnvironmentInfo);
        byte[][] bModelParamNames = SbeUtil.toBytes(hasModelInfo, modelParamNames);



        bufferSize += bSessionId.length + bTypeId.length + bWorkerId.length;

        bufferSize += 4; //swEnvironmentInfo group header (always present)
        if (hasSoftwareInfo) {
            bufferSize += SbeUtil.length(bswArch);
            bufferSize += SbeUtil.length(bswOsName);
            bufferSize += SbeUtil.length(bswJvmName);
            bufferSize += SbeUtil.length(bswJvmVersion);
            bufferSize += SbeUtil.length(bswJvmSpecVersion);
            bufferSize += SbeUtil.length(bswNd4jBackendClass);
            bufferSize += SbeUtil.length(bswNd4jDataTypeName);
            bufferSize += SbeUtil.length(bswHostname);
            bufferSize += SbeUtil.length(bswJvmUID);
            //For each entry: 2 variable-length headers (2x4 bytes each) + content
            int envCount = (bswEnvInfo != null ? bswEnvInfo.length : 0);
            bufferSize += envCount * 8;
            bufferSize += SbeUtil.length(bswEnvInfo);
        }
        int nHWDeviceStats = hwNumDevices;
        if (!hasHardwareInfo)
            nHWDeviceStats = 0;
        if (hasHardwareInfo) {
            //Device info group:
            bufferSize += hwNumDevices * 8; //fixed content in group: int64 -> 8 bytes. Encode an entry, even if hwDeviceTotalMemory is null
            bufferSize += hwNumDevices * 4; //uint32: 4 bytes per entry for var length header...; as above
            bufferSize += SbeUtil.length(bhwDeviceDescription);
            bufferSize += SbeUtil.length(bHwHardwareUID);
        }
        if (hasModelInfo) {
            bufferSize += SbeUtil.length(bmodelConfigClass);
            bufferSize += SbeUtil.length(bmodelConfigJson);
            bufferSize += SbeUtil.length(bModelParamNames);
            bufferSize += (bModelParamNames == null ? 0 : bModelParamNames.length * 4); //uint32: 4 bytes per entry for var length header...
        }

        return bufferSize;
    }

    @Override
    public byte[] encode() {
        byte[] bytes = new byte[encodingLengthBytes()];
        MutableDirectBuffer buffer = new UnsafeBuffer(bytes);
        encode(buffer);
        return bytes;
    }

    @Override
    public void encode(ByteBuffer buffer) {
        encode(new UnsafeBuffer(buffer));
    }

    @Override
    public void encode(MutableDirectBuffer buffer) {

        MessageHeaderEncoder enc = new MessageHeaderEncoder();
        StaticInfoEncoder sie = new StaticInfoEncoder();

        byte[] bSessionId = SbeUtil.toBytes(true, sessionID);
        byte[] bTypeId = SbeUtil.toBytes(true, typeID);
        byte[] bWorkerId = SbeUtil.toBytes(true, workerID);

        byte[] bswArch = SbeUtil.toBytes(hasSoftwareInfo, swArch);
        byte[] bswOsName = SbeUtil.toBytes(hasSoftwareInfo, swOsName);
        byte[] bswJvmName = SbeUtil.toBytes(hasSoftwareInfo, swJvmName);
        byte[] bswJvmVersion = SbeUtil.toBytes(hasSoftwareInfo, swJvmVersion);
        byte[] bswJvmSpecVersion = SbeUtil.toBytes(hasSoftwareInfo, swJvmSpecVersion);
        byte[] bswNd4jBackendClass = SbeUtil.toBytes(hasSoftwareInfo, swNd4jBackendClass);
        byte[] bswNd4jDataTypeName = SbeUtil.toBytes(hasSoftwareInfo, swNd4jDataTypeName);
        byte[] bswHostname = SbeUtil.toBytes(hasSoftwareInfo, swHostName);
        byte[] bswJvmUID = SbeUtil.toBytes(hasSoftwareInfo, swJvmUID);
        byte[] bHwHardwareUID = SbeUtil.toBytes(hasHardwareInfo, hwHardwareUID);
        byte[] bmodelConfigClass = SbeUtil.toBytes(hasModelInfo, modelClassName);
        byte[] bmodelConfigJson = SbeUtil.toBytes(hasModelInfo, modelConfigJson);

        byte[][] bhwDeviceDescription = SbeUtil.toBytes(hasHardwareInfo, hwDeviceDescription);
        byte[][][] bswEnvInfo = SbeUtil.toBytes(swEnvironmentInfo);
        byte[][] bModelParamNames = SbeUtil.toBytes(hasModelInfo, modelParamNames);

        enc.wrap(buffer, 0).blockLength(sie.sbeBlockLength()).templateId(sie.sbeTemplateId())
                        .schemaId(sie.sbeSchemaId()).version(sie.sbeSchemaVersion());

        int offset = enc.encodedLength(); //Expect 8 bytes...

        //Fixed length fields: always encoded, whether present or not.
        sie.wrap(buffer, offset).time(timeStamp).fieldsPresent().softwareInfo(hasSoftwareInfo)
                        .hardwareInfo(hasHardwareInfo).modelInfo(hasModelInfo);
        sie.hwJvmProcessors(hwJvmAvailableProcessors).hwNumDevices((short) hwNumDevices).hwJvmMaxMemory(hwJvmMaxMemory)
                        .hwOffheapMaxMemory(hwOffHeapMaxMemory).modelNumLayers(modelNumLayers)
                        .modelNumParams(modelNumParams);
        //Device info group...
        StaticInfoEncoder.HwDeviceInfoGroupEncoder hwdEnc = sie.hwDeviceInfoGroupCount(hwNumDevices);
        int nHWDeviceStats = (hasHardwareInfo ? hwNumDevices : 0);
        for (int i = 0; i < nHWDeviceStats; i++) {
            long maxMem = hwDeviceTotalMemory == null || hwDeviceTotalMemory.length <= i ? 0 : hwDeviceTotalMemory[i];
            byte[] descr = bhwDeviceDescription == null || bhwDeviceDescription.length <= i ? SbeUtil.EMPTY_BYTES
                            : bhwDeviceDescription[i];
            if (descr == null)
                descr = SbeUtil.EMPTY_BYTES;
            hwdEnc.next().deviceMemoryMax(maxMem).putDeviceDescription(descr, 0, descr.length);
        }

        //Environment info group
        int numEnvValues = (hasSoftwareInfo && swEnvironmentInfo != null ? swEnvironmentInfo.size() : 0);
        StaticInfoEncoder.SwEnvironmentInfoEncoder swEnv = sie.swEnvironmentInfoCount(numEnvValues);
        if (numEnvValues > 0) {
            byte[][][] mapAsBytes = SbeUtil.toBytes(swEnvironmentInfo);
            for (byte[][] entryBytes : mapAsBytes) {
                swEnv.next().putEnvKey(entryBytes[0], 0, entryBytes[0].length).putEnvValue(entryBytes[1], 0,
                                entryBytes[1].length);
            }
        }

        int nParamNames = modelParamNames == null ? 0 : modelParamNames.length;
        StaticInfoEncoder.ModelParamNamesEncoder mpnEnc = sie.modelParamNamesCount(nParamNames);
        for (int i = 0; i < nParamNames; i++) {
            mpnEnc.next().putModelParamNames(bModelParamNames[i], 0, bModelParamNames[i].length);
        }

        //In the case of !hasSoftwareInfo: these will all be empty byte arrays... still need to encode them (for 0 length) however
        sie.putSessionID(bSessionId, 0, bSessionId.length).putTypeID(bTypeId, 0, bTypeId.length)
                        .putWorkerID(bWorkerId, 0, bWorkerId.length).putSwArch(bswArch, 0, bswArch.length)
                        .putSwOsName(bswOsName, 0, bswOsName.length).putSwJvmName(bswJvmName, 0, bswJvmName.length)
                        .putSwJvmVersion(bswJvmVersion, 0, bswJvmVersion.length)
                        .putSwJvmSpecVersion(bswJvmSpecVersion, 0, bswJvmSpecVersion.length)
                        .putSwNd4jBackendClass(bswNd4jBackendClass, 0, bswNd4jBackendClass.length)
                        .putSwNd4jDataTypeName(bswNd4jDataTypeName, 0, bswNd4jDataTypeName.length)
                        .putSwHostName(bswHostname, 0, bswHostname.length).putSwJvmUID(bswJvmUID, 0, bswJvmUID.length)
                        .putHwHardwareUID(bHwHardwareUID, 0, bHwHardwareUID.length);
        //Similar: !hasModelInfo -> empty byte[]
        sie.putModelConfigClassName(bmodelConfigClass, 0, bmodelConfigClass.length).putModelConfigJson(bmodelConfigJson,
                        0, bmodelConfigJson.length);
    }

    @Override
    public void encode(OutputStream outputStream) throws IOException {
        //TODO there may be more efficient way of doing this
        outputStream.write(encode());
    }

    @Override
    public void decode(byte[] decode) {
        MutableDirectBuffer buffer = new UnsafeBuffer(decode);
        decode(buffer);
    }

    @Override
    public void decode(ByteBuffer buffer) {
        decode(new UnsafeBuffer(buffer));
    }

    @Override
    public void decode(DirectBuffer buffer) {
        //TODO we could do this much more efficiently, with buffer re-use, etc.
        MessageHeaderDecoder dec = new MessageHeaderDecoder();
        StaticInfoDecoder sid = new StaticInfoDecoder();
        dec.wrap(buffer, 0);

        final int blockLength = dec.blockLength();
        final int version = dec.version();

        final int headerLength = dec.encodedLength();
        //TODO: in general, we should check the header, version, schema etc. But we don't have any other versions yet.

        sid.wrap(buffer, headerLength, blockLength, version);
        timeStamp = sid.time();
        InitFieldsPresentDecoder fields = sid.fieldsPresent();
        hasSoftwareInfo = fields.softwareInfo();
        hasHardwareInfo = fields.hardwareInfo();
        hasModelInfo = fields.modelInfo();

        //These fields: always present, even if !hasHardwareInfo
        hwJvmAvailableProcessors = sid.hwJvmProcessors();
        hwNumDevices = sid.hwNumDevices();
        hwJvmMaxMemory = sid.hwJvmMaxMemory();
        hwOffHeapMaxMemory = sid.hwOffheapMaxMemory();
        modelNumLayers = sid.modelNumLayers();
        modelNumParams = sid.modelNumParams();

        //Hardware device info group
        StaticInfoDecoder.HwDeviceInfoGroupDecoder hwDeviceInfoGroupDecoder = sid.hwDeviceInfoGroup();
        int count = hwDeviceInfoGroupDecoder.count();
        if (count > 0) {
            hwDeviceTotalMemory = new long[count];
            hwDeviceDescription = new String[count];
        }
        int i = 0;
        for (StaticInfoDecoder.HwDeviceInfoGroupDecoder hw : hwDeviceInfoGroupDecoder) {
            hwDeviceTotalMemory[i] = hw.deviceMemoryMax();
            hwDeviceDescription[i++] = hw.deviceDescription();
        }

        //Environment info group
        i = 0;
        StaticInfoDecoder.SwEnvironmentInfoDecoder swEnvDecoder = sid.swEnvironmentInfo();
        if (swEnvDecoder.count() > 0) {
            swEnvironmentInfo = new HashMap<>();
        }
        for (StaticInfoDecoder.SwEnvironmentInfoDecoder env : swEnvDecoder) {
            String key = env.envKey();
            String value = env.envValue();
            swEnvironmentInfo.put(key, value);
        }

        i = 0;
        StaticInfoDecoder.ModelParamNamesDecoder mpdec = sid.modelParamNames();
        int mpnCount = mpdec.count();
        modelParamNames = new String[mpnCount];
        for (StaticInfoDecoder.ModelParamNamesDecoder mp : mpdec) {
            modelParamNames[i++] = mp.modelParamNames();
        }
        //Variable length data. Even if it is missing: still needs to be read, to advance buffer
        //Again, the exact order of these calls matters here
        sessionID = sid.sessionID();
        typeID = sid.typeID();
        workerID = sid.workerID();
        swArch = sid.swArch();
        swOsName = sid.swOsName();
        swJvmName = sid.swJvmName();
        swJvmVersion = sid.swJvmVersion();
        swJvmSpecVersion = sid.swJvmSpecVersion();
        swNd4jBackendClass = sid.swNd4jBackendClass();
        swNd4jDataTypeName = sid.swNd4jDataTypeName();
        swHostName = sid.swHostName();
        swJvmUID = sid.swJvmUID();
        if (!hasSoftwareInfo)
            clearSwFields();
        hwHardwareUID = sid.hwHardwareUID();
        if (!hasHardwareInfo)
            clearHwFields();
        modelClassName = sid.modelConfigClassName();
        modelConfigJson = sid.modelConfigJson();
        if (!hasModelInfo)
            clearModelFields();
    }

    @Override
    public void decode(InputStream inputStream) throws IOException {
        byte[] bytes = IOUtils.toByteArray(inputStream);
        decode(bytes);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy