org.datavec.api.timeseries.util.TimeSeriesWritableUtils Maven / Gradle / Ivy
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.timeseries.util;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.common.primitives.Pair;
import java.util.Iterator;
import java.util.List;
public class TimeSeriesWritableUtils {
/**
* Unchecked exception, thrown to signify that a zero-length sequence data set was encountered.
*/
public static class ZeroLengthSequenceException extends RuntimeException {
public ZeroLengthSequenceException() {
this("");
}
public ZeroLengthSequenceException(String type) {
super(String.format("Encountered zero-length %ssequence", type.equals("") ? "" : type + " "));
}
}
@AllArgsConstructor
@Builder
@Getter
public static class RecordDetails {
private int minValues;
private int maxTSLength;
}
/**
* Get the {@link RecordDetails}
* detailing the length of the time series
* @param record the input time series
* to get the details for
* @return the record details for the record
*/
public static RecordDetails getDetails(List>> record) {
int maxTimeSeriesLength = 0;
for(List> step : record) {
maxTimeSeriesLength = Math.max(maxTimeSeriesLength,step.size());
}
return RecordDetails.builder()
.minValues(record.size())
.maxTSLength(maxTimeSeriesLength).build();
}
/**
* Convert the writables
* to a sequence (3d) data set,
* and also return the
* mask array (if necessary)
* @param timeSeriesRecord the input time series
*
*/
public static Pair convertWritablesSequence(List>> timeSeriesRecord) {
return convertWritablesSequence(timeSeriesRecord,getDetails(timeSeriesRecord));
}
/**
* Convert the writables
* to a sequence (3d) data set,
* and also return the
* mask array (if necessary)
*/
public static Pair convertWritablesSequence(List>> list,
RecordDetails details) {
INDArray arr;
if (list.get(0).size() == 0) {
throw new ZeroLengthSequenceException("Zero length sequence encountered");
}
List firstStep = list.get(0).get(0);
int size = 0;
//Need to account for NDArrayWritables etc in list:
for (Writable w : firstStep) {
if (w instanceof NDArrayWritable) {
size += ((NDArrayWritable) w).get().size(1);
} else {
size++;
}
}
arr = Nd4j.create(new int[] {details.getMinValues(), size, details.getMaxTSLength()}, 'f');
boolean needMaskArray = false;
for (List> c : list) {
if (c.size() < details.getMaxTSLength()) {
needMaskArray = true;
break;
}
}
INDArray maskArray;
if (needMaskArray) {
maskArray = Nd4j.ones(details.getMinValues(), details.getMaxTSLength());
} else {
maskArray = null;
}
for (int i = 0; i < details.getMinValues(); i++) {
List> sequence = list.get(i);
int t = 0;
int k;
for (List timeStep : sequence) {
k = t++;
//Convert entire reader contents, without modification
Iterator iter = timeStep.iterator();
int j = 0;
while (iter.hasNext()) {
Writable w = iter.next();
if (w instanceof NDArrayWritable) {
INDArray row = ((NDArrayWritable) w).get();
arr.put(new INDArrayIndex[] {NDArrayIndex.point(i),
NDArrayIndex.interval(j, j + row.length()), NDArrayIndex.point(k)}, row);
j += row.length();
} else {
arr.putScalar(i, j, k, w.toDouble());
j++;
}
}
}
//For any remaining time steps: set mask array to 0 (just padding)
if (needMaskArray) {
//Masking array entries at end (for align start)
int lastStep = sequence.size();
for (int t2 = lastStep; t2 < details.getMaxTSLength(); t2++) {
maskArray.putScalar(i, t2, 0.0);
}
}
}
return new Pair<>(arr, maskArray);
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy