Please wait. This may take a few minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
org.nd4j.linalg.dataset.MultiDataSet Maven / Gradle / Ivy
package org.nd4j.linalg.dataset;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.io.*;
import java.util.*;
/**
 * Implementation of {@link org.nd4j.linalg.dataset.api.MultiDataSet}.
 * <p>
 * Holds an array of feature arrays, an array of label arrays and, optionally, one mask array per feature/label
 * array. Individual mask entries may be {@code null} (see {@link #hasMaskArrays()}, {@link #asList()} and
 * {@link #save(OutputStream)}), so the helper methods of this class tolerate {@code null} entries.
 *
 * @author Alex Black
 */
public class MultiDataSet implements org.nd4j.linalg.dataset.api.MultiDataSet {

    /** Written in place of a {@code null} mask entry when saving, and mapped back to {@code null} when loading. */
    private static final INDArray EMPTY_MASK_ARRAY_PLACEHOLDER = Nd4j.create(new float[] {-1});

    private INDArray[] features;
    private INDArray[] labels;
    private INDArray[] featuresMaskArrays;
    private INDArray[] labelsMaskArrays;

    // NOTE(review): raw List kept on purpose - the element type is not visible in this file; confirm against the
    // org.nd4j.linalg.dataset.api.MultiDataSet interface before adding type arguments.
    private List exampleMetaData;

    /** Create a new (empty) MultiDataSet object (all fields are null) */
    public MultiDataSet() {
    }

    /** MultiDataSet constructor with single features/labels input, no mask arrays */
    public MultiDataSet(INDArray features, INDArray labels) {
        this(features, labels, null, null);
    }

    /** MultiDataSet constructor with single features/labels input, single mask arrays */
    public MultiDataSet(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) {
        this((features != null ? new INDArray[] {features} : null),
                        (labels != null ? new INDArray[] {labels} : null),
                        (featuresMask != null ? new INDArray[] {featuresMask} : null),
                        (labelsMask != null ? new INDArray[] {labelsMask} : null));
    }

    /** MultiDataSet constructor with no mask arrays */
    public MultiDataSet(INDArray[] features, INDArray[] labels) {
        this(features, labels, null, null);
    }

    /**
     * Full constructor. The given arrays are stored as-is (no defensive copy).
     *
     * @param features           The features (inputs) to the algorithm/neural network
     * @param labels             The labels (outputs) to the algorithm/neural network
     * @param featuresMaskArrays The mask arrays for the features. May be null. Typically used with variable-length time series models, etc
     * @param labelsMaskArrays   The mask arrays for the labels. May be null. Typically used with variable-length time series models, etc
     * @throws IllegalArgumentException if a mask array set is given whose length differs from that of the
     *                                  corresponding features/labels array set
     */
    public MultiDataSet(INDArray[] features, INDArray[] labels, INDArray[] featuresMaskArrays,
                    INDArray[] labelsMaskArrays) {
        if (features != null && featuresMaskArrays != null && features.length != featuresMaskArrays.length) {
            throw new IllegalArgumentException("Invalid features / features mask arrays combination: "
                            + "features and features mask arrays must not be different lengths");
        }
        if (labels != null && labelsMaskArrays != null && labels.length != labelsMaskArrays.length) {
            throw new IllegalArgumentException("Invalid labels / labels mask arrays combination: "
                            + "labels and labels mask arrays must not be different lengths");
        }

        this.features = features;
        this.labels = labels;
        this.featuresMaskArrays = featuresMaskArrays;
        this.labelsMaskArrays = labelsMaskArrays;

        // NOTE(review): presumably flushes any operations still queued in the executioner - confirm
        Nd4j.getExecutioner().commit();
    }

    /** @return the example meta data previously set via {@link #setExampleMetaData(List)}, or null if none was set */
    @Override
    public List getExampleMetaData() {
        return exampleMetaData;
    }

    /**
     * @param metaDataType not inspected by this implementation; the stored list is returned unchanged
     * @return the example meta data, or null if none was set
     */
    @Override
    public List getExampleMetaData(Class metaDataType) {
        return exampleMetaData;
    }

    /** Stores the given list reference (no copy is made). */
    @Override
    public void setExampleMetaData(List exampleMetaData) {
        this.exampleMetaData = exampleMetaData;
    }

    /** @return number of feature arrays, or 0 if the features are null */
    @Override
    public int numFeatureArrays() {
        return (features != null ? features.length : 0);
    }

    /** @return number of label arrays, or 0 if the labels are null */
    @Override
    public int numLabelsArrays() {
        return (labels != null ? labels.length : 0);
    }

    @Override
    public INDArray[] getFeatures() {
        return features;
    }

    @Override
    public INDArray getFeatures(int index) {
        return features[index];
    }

    @Override
    public void setFeatures(INDArray[] features) {
        this.features = features;
    }

    @Override
    public void setFeatures(int idx, INDArray features) {
        this.features[idx] = features;
    }

    @Override
    public INDArray[] getLabels() {
        return labels;
    }

    @Override
    public INDArray getLabels(int index) {
        return labels[index];
    }

    @Override
    public void setLabels(INDArray[] labels) {
        this.labels = labels;
    }

    @Override
    public void setLabels(int idx, INDArray labels) {
        this.labels[idx] = labels;
    }

    /** @return true if at least one features mask or labels mask entry is non-null */
    @Override
    public boolean hasMaskArrays() {
        return containsNonNull(featuresMaskArrays) || containsNonNull(labelsMaskArrays);
    }

    /** Returns true if {@code arrays} is non-null and has at least one non-null entry. */
    private static boolean containsNonNull(INDArray[] arrays) {
        if (arrays == null) {
            return false;
        }
        for (INDArray array : arrays) {
            if (array != null) {
                return true;
            }
        }
        return false;
    }

    @Override
    public INDArray[] getFeaturesMaskArrays() {
        return featuresMaskArrays;
    }

    /** @return the features mask at {@code index}, or null if there are no features mask arrays at all */
    @Override
    public INDArray getFeaturesMaskArray(int index) {
        return (featuresMaskArrays != null ? featuresMaskArrays[index] : null);
    }

    @Override
    public void setFeaturesMaskArrays(INDArray[] maskArrays) {
        this.featuresMaskArrays = maskArrays;
    }

    @Override
    public void setFeaturesMaskArray(int idx, INDArray maskArray) {
        this.featuresMaskArrays[idx] = maskArray;
    }

    @Override
    public INDArray[] getLabelsMaskArrays() {
        return labelsMaskArrays;
    }

    /** @return the labels mask at {@code index}, or null if there are no labels mask arrays at all */
    @Override
    public INDArray getLabelsMaskArray(int index) {
        return (labelsMaskArrays != null ? labelsMaskArrays[index] : null);
    }

    @Override
    public void setLabelsMaskArray(INDArray[] labelsMaskArrays) {
        this.labelsMaskArrays = labelsMaskArrays;
    }

    @Override
    public void setLabelsMaskArray(int idx, INDArray labelsMaskArray) {
        this.labelsMaskArrays[idx] = labelsMaskArray;
    }

    /**
     * Writes this MultiDataSet to the stream: four ints (number of features, labels, features mask and labels mask
     * arrays) followed by the arrays themselves in that order. The stream is closed when this method returns.
     *
     * @param to stream to write to (closed by this method)
     * @throws IOException if writing fails
     */
    @Override
    public void save(OutputStream to) throws IOException {
        int numFArr = (features == null ? 0 : features.length);
        int numLArr = (labels == null ? 0 : labels.length);
        int numFMArr = (featuresMaskArrays == null ? 0 : featuresMaskArrays.length);
        int numLMArr = (labelsMaskArrays == null ? 0 : labelsMaskArrays.length);

        try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(to))) {
            dos.writeInt(numFArr);
            dos.writeInt(numLArr);
            dos.writeInt(numFMArr);
            dos.writeInt(numLMArr);

            saveINDArrays(features, dos, false);
            saveINDArrays(labels, dos, false);
            saveINDArrays(featuresMaskArrays, dos, true);
            saveINDArrays(labelsMaskArrays, dos, true);
        }
    }

    /** Writes each array in turn; for mask arrays a null entry is written as the placeholder array. */
    private void saveINDArrays(INDArray[] arrays, DataOutputStream dos, boolean isMask) throws IOException {
        if (arrays == null) {
            return;
        }
        for (INDArray array : arrays) {
            INDArray toWrite = (isMask && array == null) ? EMPTY_MASK_ARRAY_PLACEHOLDER : array;
            Nd4j.write(toWrite, dos);
        }
    }

    /**
     * Saves this MultiDataSet to the given file; see {@link #save(OutputStream)} for the format.
     *
     * @param to file to write to
     * @throws IOException if the file cannot be opened or written
     */
    @Override
    public void save(File to) throws IOException {
        // save(OutputStream) already closes the stream; try-with-resources additionally guarantees the file
        // handle is released even if an overriding save(OutputStream) does not close it.
        try (OutputStream out = new FileOutputStream(to)) {
            save(out);
        }
    }

    /**
     * Replaces the contents of this MultiDataSet with data read from the stream, which must be in the format written
     * by {@link #save(OutputStream)}. The stream is closed when this method returns.
     *
     * @param from stream to read from (closed by this method)
     * @throws IOException if reading fails
     */
    @Override
    public void load(InputStream from) throws IOException {
        try (DataInputStream dis = new DataInputStream(from)) {
            int numFArr = dis.readInt();
            int numLArr = dis.readInt();
            int numFMArr = dis.readInt();
            int numLMArr = dis.readInt();

            features = loadINDArrays(numFArr, dis, false);
            labels = loadINDArrays(numLArr, dis, false);
            featuresMaskArrays = loadINDArrays(numFMArr, dis, true);
            labelsMaskArrays = loadINDArrays(numLMArr, dis, true);
        }
    }

    /**
     * Reads {@code numArrays} arrays from the stream. Returns null when {@code numArrays} is not positive. For mask
     * arrays, an array equal to the placeholder is turned back into a null entry.
     */
    private INDArray[] loadINDArrays(int numArrays, DataInputStream dis, boolean isMask) throws IOException {
        if (numArrays <= 0) {
            return null;
        }
        INDArray[] result = new INDArray[numArrays];
        for (int i = 0; i < numArrays; i++) {
            INDArray arr = Nd4j.read(dis);
            result[i] = isMask && arr.equals(EMPTY_MASK_ARRAY_PLACEHOLDER) ? null : arr;
        }
        return result;
    }

    /**
     * Loads this MultiDataSet from the given file; see {@link #load(InputStream)}.
     *
     * @param from file to read from
     * @throws IOException if the file cannot be opened or read
     */
    @Override
    public void load(File from) throws IOException {
        // load(InputStream) already closes the stream; try-with-resources guarantees it in all cases.
        try (InputStream in = new FileInputStream(from)) {
            load(in);
        }
    }

    /**
     * Splits this MultiDataSet into one MultiDataSet per example. The number of examples is taken from dimension 0
     * of the first features array. Null mask entries stay null in every per-example MultiDataSet.
     *
     * @return list with one single-example MultiDataSet per example
     */
    @Override
    public List asList() {
        int nExamples = features[0].size(0);

        List<org.nd4j.linalg.dataset.api.MultiDataSet> list = new ArrayList<>();

        for (int i = 0; i < nExamples; i++) {
            INDArray[] thisFeatures = new INDArray[features.length];
            INDArray[] thisLabels = new INDArray[labels.length];
            INDArray[] thisFeaturesMaskArray =
                            (featuresMaskArrays != null ? new INDArray[featuresMaskArrays.length] : null);
            INDArray[] thisLabelsMaskArray = (labelsMaskArrays != null ? new INDArray[labelsMaskArrays.length] : null);

            for (int j = 0; j < features.length; j++) {
                thisFeatures[j] = getSubsetForExample(features[j], i);
            }
            for (int j = 0; j < labels.length; j++) {
                thisLabels[j] = getSubsetForExample(labels[j], i);
            }
            if (thisFeaturesMaskArray != null) {
                for (int j = 0; j < thisFeaturesMaskArray.length; j++) {
                    if (featuresMaskArrays[j] != null) {
                        thisFeaturesMaskArray[j] = getSubsetForExample(featuresMaskArrays[j], i);
                    }
                }
            }
            if (thisLabelsMaskArray != null) {
                for (int j = 0; j < thisLabelsMaskArray.length; j++) {
                    if (labelsMaskArrays[j] != null) {
                        thisLabelsMaskArray[j] = getSubsetForExample(labelsMaskArrays[j], i);
                    }
                }
            }

            list.add(new MultiDataSet(thisFeatures, thisLabels, thisFeaturesMaskArray, thisLabelsMaskArray));
        }

        return list;
    }

    /**
     * Extracts the slice for example {@code idx} (along dimension 0) from a rank 2, 3 or 4 array.
     *
     * @throws IllegalStateException for any other rank
     */
    private static INDArray getSubsetForExample(INDArray array, int idx) {
        //Note the interval use here: normally .point(idx) would be used, but this collapses the point dimension
        // when used on arrays with rank of 3 or greater
        //So (point,all,all) on a 3d input returns a 2d output. Whereas, we want a 3d [1,x,y] output here
        switch (array.rank()) {
            case 2:
                return array.get(NDArrayIndex.point(idx), NDArrayIndex.all());
            case 3:
                return array.get(NDArrayIndex.interval(idx, idx, true), NDArrayIndex.all(), NDArrayIndex.all());
            case 4:
                return array.get(NDArrayIndex.interval(idx, idx, true), NDArrayIndex.all(), NDArrayIndex.all(),
                                NDArrayIndex.all());
            default:
                throw new IllegalStateException("Cannot get subset for rank " + array.rank() + " array");
        }
    }

    /**
     * Clone the dataset: every non-null array is duplicated via {@code INDArray.dup()}. Null array sets and null
     * entries (e.g. absent masks) are preserved as null instead of causing a NullPointerException.
     *
     * @return a clone of the dataset
     */
    @Override
    public MultiDataSet copy() {
        MultiDataSet ret = new MultiDataSet(copy(getFeatures()), copy(getLabels()));
        if (labelsMaskArrays != null) {
            ret.setLabelsMaskArray(copy(labelsMaskArrays));
        }
        if (featuresMaskArrays != null) {
            ret.setFeaturesMaskArrays(copy(featuresMaskArrays));
        }
        return ret;
    }

    /** Null-safe element-wise duplicate: returns null for a null input and keeps null entries as null. */
    private static INDArray[] copy(INDArray[] arrays) {
        if (arrays == null) {
            return null;
        }
        INDArray[] result = new INDArray[arrays.length];
        for (int i = 0; i < arrays.length; i++) {
            result[i] = (arrays[i] == null ? null : arrays[i].dup());
        }
        return result;
    }

    /** Merge a collection of MultiDataSet objects into a single MultiDataSet.
     * Merging is done by concatenating along dimension 0 (example number in batch)
     * Merging operation may introduce mask arrays (when necessary) for time series data that has different lengths;
     * if mask arrays already exist, these will be merged also.
     * <p>
     * If the collection holds exactly one element, no merging happens: that element is returned directly when it is
     * an instance of this class, otherwise its arrays are wrapped in a new instance.
     *
     * @param toMerge Collection of MultiDataSet objects to merge
     * @return a single MultiDataSet object, containing the merged arrays of all MultiDataSets in the collection
     * @throws IllegalStateException if the MultiDataSets do not all have the same number of feature arrays and the
     *                               same number of label arrays as the first one
     */
    public static MultiDataSet merge(Collection<? extends org.nd4j.linalg.dataset.api.MultiDataSet> toMerge) {
        if (toMerge.size() == 1) {
            org.nd4j.linalg.dataset.api.MultiDataSet mds = toMerge.iterator().next();
            if (mds instanceof MultiDataSet) {
                return (MultiDataSet) mds;
            }
            return new MultiDataSet(mds.getFeatures(), mds.getLabels(), mds.getFeaturesMaskArrays(),
                            mds.getLabelsMaskArrays());
        }

        // Need indexed access to element 0; reuse the collection if it already is a List
        List<? extends org.nd4j.linalg.dataset.api.MultiDataSet> list;
        if (toMerge instanceof List) {
            list = (List<? extends org.nd4j.linalg.dataset.api.MultiDataSet>) toMerge;
        } else {
            list = new ArrayList<org.nd4j.linalg.dataset.api.MultiDataSet>(toMerge);
        }

        int nInArrays = list.get(0).numFeatureArrays();
        int nOutArrays = list.get(0).numLabelsArrays();

        // First index: position in toMerge; second index: array number within that MultiDataSet
        INDArray[][] features = new INDArray[list.size()][0];
        INDArray[][] labels = new INDArray[list.size()][0];
        INDArray[][] featuresMasks = new INDArray[list.size()][0];
        INDArray[][] labelsMasks = new INDArray[list.size()][0];

        int i = 0;
        for (org.nd4j.linalg.dataset.api.MultiDataSet mds : list) {
            features[i] = mds.getFeatures();
            labels[i] = mds.getLabels();
            featuresMasks[i] = mds.getFeaturesMaskArrays();
            labelsMasks[i] = mds.getLabelsMaskArrays();

            if (features[i] == null || features[i].length != nInArrays) {
                throw new IllegalStateException(
                                "Cannot merge MultiDataSets with different number of input arrays: toMerge[0] has "
                                                + nInArrays + " input arrays; toMerge[" + i + "] has "
                                                + (features[i] != null ? features[i].length : null) + " arrays");
            }
            if (labels[i] == null || labels[i].length != nOutArrays) {
                throw new IllegalStateException(
                                "Cannot merge MultiDataSets with different number of output arrays: toMerge[0] has "
                                                + nOutArrays + " output arrays; toMerge[" + i + "] has "
                                                + (labels[i] != null ? labels[i].length : null) + " arrays");
            }

            i++;
        }

        //Now, merge:
        INDArray[] mergedFeatures = new INDArray[nInArrays];
        INDArray[] mergedLabels = new INDArray[nOutArrays];
        INDArray[] mergedFeaturesMasks = new INDArray[nInArrays];
        INDArray[] mergedLabelsMasks = new INDArray[nOutArrays];

        boolean needFeaturesMasks = false;
        for (i = 0; i < nInArrays; i++) {
            Pair<INDArray, INDArray> pair = DataSetUtil.mergeFeatures(features, featuresMasks, i);
            mergedFeatures[i] = pair.getFirst();
            mergedFeaturesMasks[i] = pair.getSecond();
            if (mergedFeaturesMasks[i] != null) {
                needFeaturesMasks = true;
            }
        }
        // Drop the mask array set entirely if no merged features mask was produced
        if (!needFeaturesMasks) {
            mergedFeaturesMasks = null;
        }

        boolean needLabelsMasks = false;
        for (i = 0; i < nOutArrays; i++) {
            Pair<INDArray, INDArray> pair = DataSetUtil.mergeLabels(labels, labelsMasks, i);
            mergedLabels[i] = pair.getFirst();
            mergedLabelsMasks[i] = pair.getSecond();
            if (mergedLabelsMasks[i] != null) {
                needLabelsMasks = true;
            }
        }
        if (!needLabelsMasks) {
            mergedLabelsMasks = null;
        }

        return new MultiDataSet(mergedFeatures, mergedLabels, mergedFeaturesMasks, mergedLabelsMasks);
    }

    /**
     * Renders every (features, labels, masks) entry, one section per array index. Returns an empty string when the
     * number of feature arrays differs from the number of label arrays.
     */
    @Override
    public String toString() {
        int totalEntries = numFeatureArrays();
        if (totalEntries != numLabelsArrays()) {
            return "";
        }

        StringBuilder builder = new StringBuilder();
        for (int i = 0; i < totalEntries; i++) {
            builder.append("\n=========== ENTRY ").append(i).append(" =================\n");

            // ';' in the array's string form is turned into a line break
            builder.append("\n=== INPUT ===\n").append(getFeatures(i).toString().replace(";", "\n"))
                            .append("\n=== OUTPUT ===\n").append(getLabels(i).toString().replace(";", "\n"));

            INDArray featuresMask = getFeaturesMaskArray(i);
            if (featuresMask != null) {
                builder.append("\n=== INPUT MASK ===\n").append(featuresMask.toString().replace(";", "\n"));
            }
            INDArray labelsMask = getLabelsMaskArray(i);
            if (labelsMask != null) {
                builder.append("\n=== OUTPUT MASK ===\n").append(labelsMask.toString().replace(";", "\n"));
            }
        }
        return builder.toString();
    }

    /**
     * Two MultiDataSets are equal when their features, labels, features masks and labels masks are pairwise equal
     * (a null array set only equals another null array set). Example meta data is not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MultiDataSet)) {
            return false;
        }

        MultiDataSet m = (MultiDataSet) o;
        return bothNullOrEqual(features, m.features)
                        && bothNullOrEqual(labels, m.labels)
                        && bothNullOrEqual(featuresMaskArrays, m.featuresMaskArrays)
                        && bothNullOrEqual(labelsMaskArrays, m.labelsMaskArrays);
    }

    /** True if both are null, or both are non-null with the same length and element-wise equal entries. */
    private boolean bothNullOrEqual(INDArray[] first, INDArray[] second) {
        if (first == null || second == null) {
            return first == second; //Equal only if both are null
        }
        if (first.length != second.length) {
            return false;
        }
        for (int i = 0; i < first.length; i++) {
            if (!Objects.equals(first[i], second[i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Hash over all feature, label and mask arrays. Null entries contribute 0, so data sets with absent masks can
     * be hashed (consistent with the null-tolerant {@link #equals(Object)}).
     */
    @Override
    public int hashCode() {
        int result = 0;
        result = accumulateHash(result, features);
        result = accumulateHash(result, labels);
        result = accumulateHash(result, featuresMaskArrays);
        result = accumulateHash(result, labelsMaskArrays);
        return result;
    }

    /** Folds the hash codes of {@code arrays} into {@code seed}; a null array set leaves the seed unchanged. */
    private static int accumulateHash(int seed, INDArray[] arrays) {
        int result = seed;
        if (arrays != null) {
            for (INDArray array : arrays) {
                result = result * 31 + Objects.hashCode(array);
            }
        }
        return result;
    }

    /**
     * This method returns memory used by this DataSet
     *
     * @return sum, over all non-null feature, label and mask arrays, of the array length multiplied by
     *         {@code Nd4j.sizeOfDataType()}; 0 if there are no arrays
     */
    @Override
    public long getMemoryFootprint() {
        return footprintOf(features) + footprintOf(featuresMaskArrays) + footprintOf(labelsMaskArrays)
                        + footprintOf(labels);
    }

    /** Footprint of one array set; null sets and null entries count as 0. */
    private static long footprintOf(INDArray[] arrays) {
        long reqMem = 0;
        if (arrays != null) {
            for (INDArray array : arrays) {
                if (array != null) {
                    reqMem += array.lengthLong() * Nd4j.sizeOfDataType();
                }
            }
        }
        return reqMem;
    }

    /**
     * This method migrates this DataSet into current Workspace (if any). Does nothing when the memory manager
     * reports no current workspace. Null entries (e.g. absent masks) are skipped.
     */
    @Override
    public void migrate() {
        if (Nd4j.getMemoryManager().getCurrentWorkspace() != null) {
            migrateAll(features);
            migrateAll(labels);
            migrateAll(featuresMaskArrays);
            migrateAll(labelsMaskArrays);
        }
    }

    /** Replaces every non-null entry with the result of {@code INDArray.migrate()}. */
    private static void migrateAll(INDArray[] arrays) {
        if (arrays == null) {
            return;
        }
        for (int e = 0; e < arrays.length; e++) {
            if (arrays[e] != null) {
                arrays[e] = arrays[e].migrate();
            }
        }
    }

    /**
     * Replaces every non-null feature, label and mask array with the result of calling {@code INDArray.detach()}
     * on it. Null entries (e.g. absent masks) are skipped.
     */
    @Override
    public void detach() {
        detachAll(features);
        detachAll(labels);
        detachAll(featuresMaskArrays);
        detachAll(labelsMaskArrays);
    }

    /** Replaces every non-null entry with the result of {@code INDArray.detach()}. */
    private static void detachAll(INDArray[] arrays) {
        if (arrays == null) {
            return;
        }
        for (int e = 0; e < arrays.length; e++) {
            if (arrays[e] != null) {
                arrays[e] = arrays[e].detach();
            }
        }
    }
}