/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.models;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.simiacryptus.mindseye.lang.Tensor;
import org.apache.commons.lang3.ArrayUtils;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.hdf5;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static org.bytedeco.javacpp.hdf5.*;
/**
 * Class for reading arrays and JSON strings from HDF5 archive files. Originally part of deeplearning4j.
*
 * @author dave@skymind.io
*/
public class Hdf5Archive {
private static final Logger log = LoggerFactory.getLogger(Hdf5Archive.class);
static {
try {
/* This is necessary for the call to the BytePointer constructor below. */
Loader.load(hdf5.class);
} catch (Exception e) {
log.warn("Error loading HDF5 native library", e);
}
}
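/*
 * Example usage (an illustrative sketch only; "weights.h5" and the group and
 * dataset names below are hypothetical placeholders, not files or keys that
 * ship with this library):
 *
 *   Hdf5Archive archive = new Hdf5Archive("weights.h5");
 *   CharSequence config = archive.readAttributeAsJson("model_config");
 *   Tensor kernel = archive.readDataSet("kernel:0", "model_weights", "block1_conv1");
 *   archive.print();
 */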
@Nonnull
private final H5File file;
@Nonnull
private final File filename;
/**
 * Instantiates a new Hdf5Archive.
*
* @param filename the archive filename
*/
public Hdf5Archive(@Nonnull String filename) {
this(new File(filename));
}
/**
 * Instantiates a new Hdf5Archive.
*
* @param filename the filename
*/
public Hdf5Archive(@Nonnull File filename) {
this.filename = filename;
try {
this.file = new H5File(filename.getCanonicalPath(), H5F_ACC_RDONLY());
} catch (@Nonnull final RuntimeException e) {
throw e;
} catch (IOException e) {
throw new RuntimeException(e);
}
}
private static void print(@Nonnull Hdf5Archive archive, @Nonnull Logger log) {
printTree(archive, "", false, log);
}
private static void printTree(@Nonnull Hdf5Archive hdf5, CharSequence prefix, boolean printData, @Nonnull Logger log, @Nonnull String... path) {
for (CharSequence datasetName : hdf5.getDataSets(path)) {
@Nullable Tensor tensor = hdf5.readDataSet(datasetName.toString(), path);
log.info(String.format("%sDataset %s: %s", prefix, datasetName, Arrays.toString(tensor.getDimensions())));
if (printData) log.info(String.format("%s%s", prefix, tensor.prettyPrint().replaceAll("\n", "\n" + prefix)));
tensor.freeRef();
}
hdf5.getAttributes(path).forEach((k, v) -> {
log.info(String.format("%sAttribute: %s => %s", prefix, k, v));
});
for (String t : hdf5.getGroups(path).stream().map(CharSequence::toString).sorted(new Comparator<String>() {
@Override
public int compare(@Nonnull String o1, @Nonnull String o2) {
@Nonnull String prefix = "layer_";
@Nonnull Pattern digit = Pattern.compile("^\\d+$");
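// Compare numerically when both names are integers, so e.g. "layer_2" precedes
// "layer_10" instead of following it lexicographically; "layer_" prefixes are
// stripped and the remainders compared recursively.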
if (digit.matcher(o1).matches() && digit.matcher(o2).matches())
return Integer.compare(Integer.parseInt(o1), Integer.parseInt(o2));
if (o1.startsWith(prefix) && o2.startsWith(prefix))
return compare(o1.substring(prefix.length()), o2.substring(prefix.length()));
else return o1.compareTo(o2);
}
}).collect(Collectors.toList())) {
log.info(prefix + t);
printTree(hdf5, prefix + "\t", printData, log, concat(path, t));
}
}
@Nonnull
private static String[] concat(@Nonnull CharSequence[] s, String t) {
@Nonnull String[] strings = new String[s.length + 1];
System.arraycopy(s, 0, strings, 0, s.length);
strings[s.length] = t;
return strings;
}
@Override
public String toString() {
return String.format("Hdf5Archive{%s}", file);
}
@Nonnull
private Group[] openGroups(@Nonnull CharSequence... groups) {
@Nonnull Group[] groupArray = new Group[groups.length];
groupArray[0] = this.file.openGroup(groups[0].toString());
for (int i = 1; i < groups.length; i++) {
groupArray[i] = groupArray[i - 1].openGroup(groups[i].toString());
}
return groupArray;
}
private void closeGroups(@Nonnull Group[] groupArray) {
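// Deallocate child groups before their parents, in reverse order of opening.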
for (int i = groupArray.length - 1; i >= 0; i--) {
groupArray[i].deallocate();
}
}
/**
 * Read a data set as a Tensor from the given group path.
 *
 * @param datasetName Name of the data set
 * @param groups      Array of zero or more ancestor groups from root to parent.
 * @return the data set contents as a Tensor
*/
@Nullable
public Tensor readDataSet(CharSequence datasetName, @Nonnull CharSequence... groups) {
if (groups.length == 0) {
return readDataSet(this.file, datasetName);
}
@Nonnull Group[] groupArray = openGroups(groups);
@Nullable Tensor a = readDataSet(groupArray[groupArray.length - 1], datasetName);
closeGroups(groupArray);
return a;
}
/**
* Read JSON-formatted string attribute from group path.
*
* @param attributeName Name of attribute
* @param groups Array of zero or more ancestor groups from root to parent.
 * @return the attribute value as a JSON-formatted string
*/
@Nullable
public CharSequence readAttributeAsJson(String attributeName, @Nonnull String... groups) {
if (groups.length == 0) {
return readAttributeAsJson(this.file.openAttribute(attributeName));
}
@Nonnull Group[] groupArray = openGroups(groups);
@Nullable String s = readAttributeAsJson(groupArray[groupArray.length - 1].openAttribute(attributeName));
closeGroups(groupArray);
return s;
}
/**
* Read string attribute from group path.
*
* @param attributeName Name of attribute
* @param groups Array of zero or more ancestor groups from root to parent.
 * @return the attribute value as a string
*/
@Nullable
public CharSequence readAttributeAsString(String attributeName, @Nonnull String... groups) {
if (groups.length == 0) {
return readAttributeAsString(this.file.openAttribute(attributeName));
}
@Nonnull Group[] groupArray = openGroups(groups);
@Nullable String s = readAttributeAsString(groupArray[groupArray.length - 1].openAttribute(attributeName));
closeGroups(groupArray);
return s;
}
/**
* Check whether group path contains string attribute.
*
* @param attributeName Name of attribute
* @param groups Array of zero or more ancestor groups from root to parent.
* @return Boolean indicating whether attribute exists in group path.
*/
public boolean hasAttribute(String attributeName, @Nonnull String... groups) {
if (groups.length == 0) {
return this.file.attrExists(attributeName);
}
@Nonnull Group[] groupArray = openGroups(groups);
boolean b = groupArray[groupArray.length - 1].attrExists(attributeName);
closeGroups(groupArray);
return b;
}
/**
 * Gets all attributes at the given group path.
 *
 * @param groups Array of zero or more ancestor groups from root to parent.
 * @return map of attribute names to values
*/
@Nonnull
public Map<CharSequence, Object> getAttributes(@Nonnull String... groups) {
if (groups.length == 0) {
return getAttributes(this.file);
}
@Nonnull Group[] groupArray = openGroups(groups);
Group group = groupArray[groupArray.length - 1];
@Nonnull Map<CharSequence, Object> attributes = getAttributes(group);
closeGroups(groupArray);
return attributes;
}
/**
 * Gets all attributes of the given group.
 *
 * @param group the group to read attributes from
 * @return map of attribute names to values
*/
@Nonnull
public Map<CharSequence, Object> getAttributes(@Nonnull Group group) {
int numAttrs = group.getNumAttrs();
@Nonnull TreeMap<CharSequence, Object> attributes = new TreeMap<>();
for (int i = 0; i < numAttrs; i++) {
Attribute attribute = group.openAttribute(i);
CharSequence name = attribute.getName().getString();
int typeId = attribute.getTypeClass();
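// Type class 0 is H5T_INTEGER; read such attributes as 64-bit values.
// Any other type class is read as a string.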
if (typeId == 0) {
attributes.put(name, getI64(attribute));
} else {
log.debug("{} type = {}", name, typeId);
attributes.put(name, getString(attribute));
}
attribute.deallocate();
}
return attributes;
}
private long getI64(@Nonnull Attribute attribute) {
return getI64(attribute, attribute.getIntType(), new byte[8]);
}
@Nonnull
private CharSequence getString(@Nonnull Attribute attribute) {
return getString(attribute, attribute.getVarLenType(), new byte[1024]);
}
private long getI64(@Nonnull Attribute attribute, DataType dataType, @Nonnull byte[] buffer) {
@Nonnull BytePointer pointer = new BytePointer(buffer);
attribute.read(dataType, pointer);
pointer.get(buffer);
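// The attribute bytes arrive in native byte order (assumed little-endian here);
// reversing them lets the big-endian ByteBuffer read below decode the value.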
ArrayUtils.reverse(buffer);
return ByteBuffer.wrap(buffer).asLongBuffer().get();
}
@Nonnull
private CharSequence getString(@Nonnull Attribute attribute, DataType dataType, @Nonnull byte[] buffer) {
@Nonnull BytePointer pointer = new BytePointer(buffer);
attribute.read(dataType, pointer);
pointer.get(buffer);
@Nonnull String str = new String(buffer);
if (str.indexOf('\0') >= 0) {
return str.substring(0, str.indexOf('\0'));
} else {
return str;
}
}
/**
* Get list of data sets from group path.
*
* @param groups Array of zero or more ancestor groups from root to parent.
 * @return list of data set names
*/
@Nonnull
public List<CharSequence> getDataSets(@Nonnull String... groups) {
if (groups.length == 0) {
return getObjects(this.file, H5O_TYPE_DATASET);
}
@Nonnull Group[] groupArray = openGroups(groups);
@Nonnull List<CharSequence> ls = getObjects(groupArray[groupArray.length - 1], H5O_TYPE_DATASET);
closeGroups(groupArray);
return ls;
}
/**
* Get list of groups from group path.
*
* @param groups Array of zero or more ancestor groups from root to parent.
 * @return list of group names
*/
@Nonnull
public List<CharSequence> getGroups(@Nonnull String... groups) {
if (groups.length == 0) {
return getObjects(this.file, H5O_TYPE_GROUP);
}
@Nonnull Group[] groupArray = openGroups(groups);
@Nonnull List<CharSequence> ls = getObjects(groupArray[groupArray.length - 1], H5O_TYPE_GROUP);
closeGroups(groupArray);
return ls;
}
/**
 * Read a data set as a Tensor from an HDF5 group.
 *
 * @param fileGroup   HDF5 file or group
 * @param datasetName Name of the data set
 * @return the data set contents as a Tensor
*/
@Nullable
private Tensor readDataSet(@Nonnull Group fileGroup, CharSequence datasetName) {
DataSet dataset = fileGroup.openDataSet(datasetName.toString());
DataSpace space = dataset.getSpace();
int nbDims = space.getSimpleExtentNdims();
@Nonnull long[] dims = new long[nbDims];
space.getSimpleExtentDims(dims);
@Nullable float[] dataBuffer = null;
@Nullable FloatPointer fp = null;
int j = 0;
@Nonnull DataType dataType = new DataType(PredType.NATIVE_FLOAT());
@Nullable Tensor data = null;
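// HDF5 stores array data in row-major (C) order; each case below copies the
// flat float buffer into a Tensor of matching rank, visiting indices in the
// same row-major order.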
switch (nbDims) {
case 4: /* 2D Convolution weights */
dataBuffer = new float[(int) (dims[0] * dims[1] * dims[2] * dims[3])];
fp = new FloatPointer(dataBuffer);
dataset.read(fp, dataType);
fp.get(dataBuffer);
data = new Tensor((int) dims[0], (int) dims[1], (int) dims[2], (int) dims[3]);
j = 0;
for (int i1 = 0; i1 < dims[0]; i1++)
for (int i2 = 0; i2 < dims[1]; i2++)
for (int i3 = 0; i3 < dims[2]; i3++)
for (int i4 = 0; i4 < dims[3]; i4++)
data.set(i1, i2, i3, i4, (double) dataBuffer[j++]);
break;
case 3:
dataBuffer = new float[(int) (dims[0] * dims[1] * dims[2])];
fp = new FloatPointer(dataBuffer);
dataset.read(fp, dataType);
fp.get(dataBuffer);
data = new Tensor((int) dims[0], (int) dims[1], (int) dims[2]);
j = 0;
for (int i1 = 0; i1 < dims[0]; i1++)
for (int i2 = 0; i2 < dims[1]; i2++)
for (int i3 = 0; i3 < dims[2]; i3++)
data.set(i1, i2, i3, dataBuffer[j++]);
break;
case 2: /* Dense and Recurrent weights */
dataBuffer = new float[(int) (dims[0] * dims[1])];
fp = new FloatPointer(dataBuffer);
dataset.read(fp, dataType);
fp.get(dataBuffer);
data = new Tensor((int) dims[0], (int) dims[1]);
j = 0;
for (int i1 = 0; i1 < dims[0]; i1++)
for (int i2 = 0; i2 < dims[1]; i2++)
data.set(i1, i2, dataBuffer[j++]);
break;
case 1: /* Bias */
dataBuffer = new float[(int) dims[0]];
fp = new FloatPointer(dataBuffer);
dataset.read(fp, dataType);
fp.get(dataBuffer);
data = new Tensor((int) dims[0]);
j = 0;
for (int i1 = 0; i1 < dims[0]; i1++)
data.set(i1, dataBuffer[j++]);
break;
default:
throw new RuntimeException("Cannot import weights apply rank " + nbDims);
}
space.deallocate();
dataset.deallocate();
return data;
}
/**
 * Get list of objects of a given type from a file group.
 *
 * @param fileGroup HDF5 file or group
 * @param objType   Type of object as integer
 * @return list of object names of the given type
*/
@Nonnull
private List<CharSequence> getObjects(@Nonnull Group fileGroup, int objType) {
@Nonnull List<CharSequence> groups = new ArrayList<>();
for (int i = 0; i < fileGroup.getNumObjs(); i++) {
BytePointer objPtr = fileGroup.getObjnameByIdx(i);
if (fileGroup.childObjType(objPtr) == objType) {
groups.add(fileGroup.getObjnameByIdx(i).getString());
}
}
return groups;
}
/**
* Read JSON-formatted string attribute.
*
* @param attribute HDF5 attribute to read as JSON formatted string.
 * @return the attribute value as a JSON-formatted string
*/
@Nullable
private String readAttributeAsJson(@Nonnull Attribute attribute) {
VarLenType vl = attribute.getVarLenType();
int bufferSizeMult = 1;
@Nullable String s = null;
/* TODO: find a less hacky way to do this.
* Reading variable length strings (from attributes) is a giant
* pain. There does not appear to be any way to determine the
* length of the string in advance, so we use a hack: choose a
* buffer size and read the config. If Jackson fails to parse
* it, then we must not have read the entire config. Increase
* buffer and repeat.
*/
while (true) {
@Nonnull byte[] attrBuffer = new byte[bufferSizeMult * 2000];
@Nonnull BytePointer attrPointer = new BytePointer(attrBuffer);
attribute.read(vl, attrPointer);
attrPointer.get(attrBuffer);
s = new String(attrBuffer);
@Nonnull ObjectMapper mapper = new ObjectMapper();
mapper.enable(DeserializationFeature.FAIL_ON_READING_DUP_TREE_KEY);
try {
mapper.readTree(s);
break;
} catch (IOException e) {
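// Jackson failed to parse: the buffer was probably too short to hold
// the full JSON. Grow it and retry.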
}
bufferSizeMult++;
if (bufferSizeMult > 100) {
throw new RuntimeException("Could not read abnormally long HDF5 attribute");
}
}
return s;
}
/**
* Read attribute as string.
*
* @param attribute HDF5 attribute to read as string.
 * @return the attribute value as a string
*/
@Nullable
private String readAttributeAsString(@Nonnull Attribute attribute) {
VarLenType vl = attribute.getVarLenType();
int bufferSizeMult = 1;
@Nullable String s = null;
/* TODO: find a less hacky way to do this.
* Reading variable length strings (from attributes) is a giant
* pain. There does not appear to be any way to determine the
* length of the string in advance, so we use a hack: choose a
* buffer size and read the config, increase buffer and repeat
 * until the buffer ends with \u0000.
*/
while (true) {
@Nonnull byte[] attrBuffer = new byte[bufferSizeMult * 2000];
@Nonnull BytePointer attrPointer = new BytePointer(attrBuffer);
attribute.read(vl, attrPointer);
attrPointer.get(attrBuffer);
s = new String(attrBuffer);
if (s.endsWith("\u0000")) {
s = s.replace("\u0000", "");
break;
}
bufferSizeMult++;
if (bufferSizeMult > 100) {
throw new RuntimeException("Could not read abnormally long HDF5 attribute");
}
}
return s;
}
/**
 * Read fixed-length string attribute from the archive root.
 *
 * @param attributeName Name of attribute
 * @param bufferSize    buffer size to read
 * @return the attribute value as a string
*/
@Nonnull
public CharSequence readAttributeAsFixedLengthString(String attributeName, int bufferSize) {
return readAttributeAsFixedLengthString(this.file.openAttribute(attributeName), bufferSize);
}
/**
 * Read attribute of fixed buffer size as string.
 *
 * @param attribute  HDF5 attribute to read as string.
 * @param bufferSize buffer size to read
 * @return the attribute value as a string
*/
@Nonnull
private CharSequence readAttributeAsFixedLengthString(@Nonnull Attribute attribute, int bufferSize) {
VarLenType vl = attribute.getVarLenType();
@Nonnull byte[] attrBuffer = new byte[bufferSize];
@Nonnull BytePointer attrPointer = new BytePointer(attrBuffer);
attribute.read(vl, attrPointer);
attrPointer.get(attrBuffer);
@Nonnull String s = new String(attrBuffer);
return s;
}
/**
* Print.
*/
public void print() {
print(log);
}
/**
* Print.
*
* @param log the log
*/
public void print(@Nonnull Logger log) {
print(this, log);
}
/**
* Gets filename.
*
* @return the filename
*/
@Nonnull
public File getFilename() {
return filename;
}
}