// org.deeplearning4j.spark.datavec.DataVecDataSetFunction (Maven / Gradle / Ivy)
// The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.spark.datavec;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.function.Function;
import org.datavec.api.io.WritableConverter;
import org.datavec.api.io.converters.WritableConverterException;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;
import java.io.Serializable;
import java.util.List;
@Slf4j
public class DataVecDataSetFunction implements Function, DataSet>, Serializable {
private final int labelIndex;
private final int labelIndexTo;
private final int numPossibleLabels;
private final boolean regression;
private final DataSetPreProcessor preProcessor;
private final WritableConverter converter;
protected int batchSize = -1;
public DataVecDataSetFunction(int labelIndex, int numPossibleLabels, boolean regression) {
this(labelIndex, numPossibleLabels, regression, null, null);
}
/**
* @param labelIndex Index of the label column
* @param numPossibleLabels Number of classes for classification (not used if regression = true)
* @param regression False for classification, true for regression
* @param preProcessor DataSetPreprocessor (may be null)
* @param converter WritableConverter (may be null)
*/
public DataVecDataSetFunction(int labelIndex, int numPossibleLabels, boolean regression,
DataSetPreProcessor preProcessor, WritableConverter converter) {
this(labelIndex, labelIndex, numPossibleLabels, regression, preProcessor, converter);
}
/**
* Main constructor, including for multi-label regression
*
* @param labelIndexFrom Index of the first target
* @param labelIndexTo Index of the last target, inclusive (for classification or single-output regression: same as labelIndexFrom)
* @param numPossibleLabels Unused for regression, or number of classes for classification
* @param regression If true: regression. false: classification
*/
public DataVecDataSetFunction(int labelIndexFrom, int labelIndexTo, int numPossibleLabels, boolean regression,
DataSetPreProcessor preProcessor, WritableConverter converter) {
this.labelIndex = labelIndexFrom;
this.labelIndexTo = labelIndexTo;
this.numPossibleLabels = numPossibleLabels;
this.regression = regression;
this.preProcessor = preProcessor;
this.converter = converter;
}
@Override
public DataSet call(List currList) throws Exception {
//allow people to specify label index as -1 and infer the last possible label
int labelIndex = this.labelIndex;
if (numPossibleLabels >= 1 && labelIndex < 0) {
labelIndex = currList.size() - 1;
}
INDArray label = null;
INDArray featureVector = null;
int featureCount = 0;
int labelCount = 0;
//no labels
if (currList.size() == 2 && currList.get(1) instanceof NDArrayWritable
&& currList.get(0) instanceof NDArrayWritable && currList.get(0) == currList.get(1)) {
NDArrayWritable writable = (NDArrayWritable) currList.get(0);
DataSet ds = new DataSet(writable.get(), writable.get());
if (preProcessor != null)
preProcessor.preProcess(ds);
return ds;
}
if (currList.size() == 2 && currList.get(0) instanceof NDArrayWritable) {
if (!regression)
label = FeatureUtil.toOutcomeVector((int) Double.parseDouble(currList.get(1).toString()),
numPossibleLabels);
else
label = Nd4j.scalar(Double.parseDouble(currList.get(1).toString())).reshape(1,1);
NDArrayWritable ndArrayWritable = (NDArrayWritable) currList.get(0);
featureVector = ndArrayWritable.get();
DataSet ds = new DataSet(featureVector, label);
if (preProcessor != null)
preProcessor.preProcess(ds);
return ds;
}
for (int j = 0; j < currList.size(); j++) {
Writable current = currList.get(j);
//ndarray writable is an insane slow down here
if (!(current instanceof NDArrayWritable) && current.toString().isEmpty())
continue;
if (labelIndex >= 0 && j >= labelIndex && j <= labelIndexTo) {
//single label case (classification, single label regression etc)
if (converter != null) {
try {
current = converter.convert(current);
} catch (WritableConverterException e) {
log.error("",e);
}
}
if (regression) {
//single and multi-label regression
if (label == null) {
label = Nd4j.zeros(1, labelIndexTo - labelIndex + 1);
}
label.putScalar(0, labelCount++, current.toDouble());
} else {
if (numPossibleLabels < 1)
throw new IllegalStateException(
"Number of possible labels invalid, must be >= 1 for classification");
int curr = current.toInt();
if (curr >= numPossibleLabels)
throw new IllegalStateException(
"Invalid index: got index " + curr + " but numPossibleLabels is "
+ numPossibleLabels + " (must be 0 <= idx < numPossibleLabels");
label = FeatureUtil.toOutcomeVector(curr, numPossibleLabels);
}
} else {
try {
double value = current.toDouble();
if (featureVector == null) {
if (regression && labelIndex >= 0) {
//Handle the possibly multi-label regression case here:
int nLabels = labelIndexTo - labelIndex + 1;
featureVector = Nd4j.create(1, currList.size() - nLabels);
} else {
//Classification case, and also no-labels case
featureVector = Nd4j.create(1, labelIndex >= 0 ? currList.size() - 1 : currList.size());
}
}
featureVector.putScalar(featureCount++, value);
} catch (UnsupportedOperationException e) {
// This isn't a scalar, so check if we got an array already
if (current instanceof NDArrayWritable) {
Preconditions.checkState(featureVector == null, "Already got an array");
featureVector = ((NDArrayWritable) current).get();
} else {
throw e;
}
}
}
}
DataSet ds = new DataSet(featureVector, (labelIndex >= 0 ? label : featureVector));
if (preProcessor != null)
preProcessor.preProcess(ds);
return ds;
}
}