All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.spark.datavec.DataVecSequencePairDataSetFunction Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.spark.datavec;

import org.apache.spark.api.java.function.Function;
import org.datavec.api.io.WritableConverter;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.FeatureUtil;
import scala.Tuple2;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public class DataVecSequencePairDataSetFunction
                implements Function>, List>>, DataSet>, Serializable {
    /**Alignment mode for dealing with input/labels of differing lengths (for example, one-to-many and many-to-one type situations).
     * For example, might have 10 time steps total but only one label at end for sequence classification.
* EQUAL_LENGTH: Default. Assume that label and input time series are of equal length
* ALIGN_START: Align the label/input time series at the first time step, and zero pad either the labels or * the input at the end (pad whichever is shorter)
* ALIGN_END: Align the label/input at the last time step, zero padding either the input or the labels as required
*/ public enum AlignmentMode { EQUAL_LENGTH, ALIGN_START, ALIGN_END } private final boolean regression; private final int numPossibleLabels; private final AlignmentMode alignmentMode; private final DataSetPreProcessor preProcessor; private final WritableConverter converter; /** Constructor for equal length and no conversion of labels (i.e., regression or already in one-hot representation). * No data set proprocessor or writable converter */ public DataVecSequencePairDataSetFunction() { this(-1, true); } /**Constructor for equal length, no data set preprocessor or writable converter * @see #DataVecSequencePairDataSetFunction(int, boolean, AlignmentMode, DataSetPreProcessor, WritableConverter) */ public DataVecSequencePairDataSetFunction(int numPossibleLabels, boolean regression) { this(numPossibleLabels, regression, AlignmentMode.EQUAL_LENGTH); } /**Constructor for data with a specified alignment mode, no data set preprocessor or writable converter * @see #DataVecSequencePairDataSetFunction(int, boolean, AlignmentMode, DataSetPreProcessor, WritableConverter) */ public DataVecSequencePairDataSetFunction(int numPossibleLabels, boolean regression, AlignmentMode alignmentMode) { this(numPossibleLabels, regression, alignmentMode, null, null); } /** * @param numPossibleLabels Number of classes for classification (not used if regression = true) * @param regression False for classification, true for regression * @param alignmentMode Alignment mode for data. 
See {@link DataVecSequencePairDataSetFunction.AlignmentMode} * @param preProcessor DataSetPreprocessor (may be null) * @param converter WritableConverter (may be null) */ public DataVecSequencePairDataSetFunction(int numPossibleLabels, boolean regression, AlignmentMode alignmentMode, DataSetPreProcessor preProcessor, WritableConverter converter) { this.numPossibleLabels = numPossibleLabels; this.regression = regression; this.alignmentMode = alignmentMode; this.preProcessor = preProcessor; this.converter = converter; } @Override public DataSet call(Tuple2>, List>> input) throws Exception { List> featuresSeq = input._1(); List> labelsSeq = input._2(); int featuresLength = featuresSeq.size(); int labelsLength = labelsSeq.size(); Iterator> fIter = featuresSeq.iterator(); Iterator> lIter = labelsSeq.iterator(); INDArray inputArr = null; INDArray outputArr = null; int[] idx = new int[3]; int i = 0; while (fIter.hasNext()) { List step = fIter.next(); if (i == 0) { int[] inShape = new int[] {1, step.size(), featuresLength}; inputArr = Nd4j.create(inShape); } Iterator timeStepIter = step.iterator(); int f = 0; idx[1] = 0; while (timeStepIter.hasNext()) { Writable current = timeStepIter.next(); if (converter != null) current = converter.convert(current); try { inputArr.putScalar(idx, current.toDouble()); } catch (UnsupportedOperationException e) { // This isn't a scalar, so check if we got an array already if (current instanceof NDArrayWritable) { inputArr.get(NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[2])) .putRow(0, ((NDArrayWritable) current).get()); } else { throw e; } } idx[1] = ++f; } idx[2] = ++i; } idx = new int[3]; i = 0; while (lIter.hasNext()) { List step = lIter.next(); if (i == 0) { int[] outShape = new int[] {1, (regression ? 
step.size() : numPossibleLabels), labelsLength}; outputArr = Nd4j.create(outShape); } Iterator timeStepIter = step.iterator(); int f = 0; idx[1] = 0; if (regression) { //Load all values without modification while (timeStepIter.hasNext()) { Writable current = timeStepIter.next(); if (converter != null) current = converter.convert(current); outputArr.putScalar(idx, current.toDouble()); idx[1] = ++f; } } else { //Expect a single value (index) -> convert to one-hot vector Writable value = timeStepIter.next(); int labelClassIdx = value.toInt(); INDArray line = FeatureUtil.toOutcomeVector(labelClassIdx, numPossibleLabels); outputArr.tensorAlongDimension(i, 1).assign(line); //1d from [1,nOut,timeSeriesLength] -> tensor i along dimension 1 is at time i } idx[2] = ++i; } DataSet ds; if (alignmentMode == AlignmentMode.EQUAL_LENGTH || featuresLength == labelsLength) { ds = new DataSet(inputArr, outputArr); } else if (alignmentMode == AlignmentMode.ALIGN_END) { if (featuresLength > labelsLength) { //Input longer, pad output INDArray newOutput = Nd4j.create(1, outputArr.size(1), featuresLength); newOutput.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(featuresLength - labelsLength, featuresLength)).assign(outputArr); //Need an output mask array, but not an input mask array INDArray outputMask = Nd4j.create(1, featuresLength); for (int j = featuresLength - labelsLength; j < featuresLength; j++) outputMask.putScalar(j, 1.0); ds = new DataSet(inputArr, newOutput, Nd4j.ones(outputMask.shape()), outputMask); } else { //Output longer, pad input INDArray newInput = Nd4j.create(1, inputArr.size(1), labelsLength); newInput.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(labelsLength - featuresLength, labelsLength)).assign(inputArr); //Need an input mask array, but not an output mask array INDArray inputMask = Nd4j.create(1, labelsLength); for (int j = labelsLength - featuresLength; j < labelsLength; j++) inputMask.putScalar(j, 1.0); ds = new 
DataSet(newInput, outputArr, inputMask, Nd4j.ones(inputMask.shape())); } } else if (alignmentMode == AlignmentMode.ALIGN_START) { if (featuresLength > labelsLength) { //Input longer, pad output INDArray newOutput = Nd4j.create(1, outputArr.size(1), featuresLength); newOutput.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0, labelsLength)) .assign(outputArr); //Need an output mask array, but not an input mask array INDArray outputMask = Nd4j.create(1, featuresLength); for (int j = 0; j < labelsLength; j++) outputMask.putScalar(j, 1.0); ds = new DataSet(inputArr, newOutput, Nd4j.ones(outputMask.shape()), outputMask); } else { //Output longer, pad input INDArray newInput = Nd4j.create(1, inputArr.size(1), labelsLength); newInput.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0, featuresLength)) .assign(inputArr); //Need an input mask array, but not an output mask array INDArray inputMask = Nd4j.create(1, labelsLength); for (int j = 0; j < featuresLength; j++) inputMask.putScalar(j, 1.0); ds = new DataSet(newInput, outputArr, inputMask, Nd4j.ones(inputMask.shape())); } } else { throw new UnsupportedOperationException("Invalid alignment mode: " + alignmentMode); } if (preProcessor != null) preProcessor.preProcess(ds); return ds; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy