org.nd4j.linalg.dataset.api.preprocessor.AbstractDataSetNormalizer Maven / Gradle / Ivy
The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.dataset.api.preprocessor;
import lombok.AccessLevel;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
import lombok.Setter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
/**
 * Base class for {@link DataNormalization} implementations that collect statistics of type {@code S}
 * from {@link DataSet}s and delegate the actual normalization / reversion to a
 * {@link NormalizerStrategy}.
 *
 * <p>Feature statistics are always computed by {@code fit(...)}. Label statistics are computed only
 * when label fitting has been enabled via {@link #fitLabel(boolean)}.
 *
 * @param <S> the type of statistics collected by this normalizer
 */
@EqualsAndHashCode(callSuper = false)
public abstract class AbstractDataSetNormalizer<S extends NormalizerStats> extends AbstractNormalizer
        implements DataNormalization {
    /** Strategy that applies and reverts normalization using the collected statistics. */
    protected NormalizerStrategy<S> strategy;
    @Setter(AccessLevel.PROTECTED)
    private S featureStats;
    @Setter(AccessLevel.PROTECTED)
    private S labelStats;
    /** Whether labels are fitted/normalized in addition to features; disabled by default. */
    private boolean fitLabels = false;

    /**
     * @param strategy the strategy used to apply and revert normalization
     */
    protected AbstractDataSetNormalizer(NormalizerStrategy<S> strategy) {
        this.strategy = strategy;
    }

    /**
     * Flag to specify whether the labels/outputs in the dataset should also be normalized.
     * The default value is false. Label statistics are only computed by {@code fit(...)}, so
     * enabling this after fitting requires calling {@code fit(...)} again.
     *
     * @param fitLabels true to also fit and normalize labels
     */
    @Override
    public void fitLabel(boolean fitLabels) {
        this.fitLabels = fitLabels;
    }

    /**
     * Whether normalization for the labels is also enabled. Most commonly used for regression, not
     * classification.
     *
     * @return true if labels will be normalized in addition to features
     */
    @Override
    public boolean isFitLabel() {
        return this.fitLabels;
    }

    /**
     * Fit a dataset (only compute based on the statistics from this dataset).
     *
     * @param dataSet the dataset to compute on
     */
    @Override
    @SuppressWarnings("unchecked") // builder is raw; it builds the S type for this normalizer
    public void fit(DataSet dataSet) {
        featureStats = (S) newBuilder().addFeatures(dataSet).build();
        if (isFitLabel()) {
            labelStats = (S) newBuilder().addLabels(dataSet).build();
        }
    }

    /** @return the fitted feature statistics, or null if {@code fit(...)} has not been called */
    protected S getFeatureStats() {
        return featureStats;
    }

    /** @return the fitted label statistics, or null if labels have not been fitted */
    protected S getLabelStats() {
        return labelStats;
    }

    @Override
    protected boolean isFit() {
        return featureStats != null;
    }

    /**
     * Fit the normalizer on all data produced by the given iterator. The iterator is reset
     * before and after the pass, so callers can iterate it again afterwards.
     *
     * @param iterator for the data to iterate over
     */
    @Override
    @SuppressWarnings("unchecked") // builder is raw; it builds the S type for this normalizer
    public void fit(DataSetIterator iterator) {
        // Read the flag once (consistently with fit(DataSet)) so the whole pass agrees on it.
        final boolean withLabels = isFitLabel();
        S.Builder featureNormBuilder = newBuilder();
        S.Builder labelNormBuilder = withLabels ? newBuilder() : null;

        iterator.reset();
        while (iterator.hasNext()) {
            DataSet next = iterator.next();
            featureNormBuilder.addFeatures(next);
            if (withLabels) {
                labelNormBuilder.addLabels(next);
            }
        }
        featureStats = (S) featureNormBuilder.build();
        if (withLabels) {
            labelStats = (S) labelNormBuilder.build();
        }
        iterator.reset();
    }

    /**
     * Creates the statistics builder used by {@code fit(...)}. A separate instance is requested for
     * the feature statistics and for the label statistics.
     */
    protected abstract S.Builder newBuilder();

    /**
     * Pre process a dataset: normalize its features and, if {@link #isFitLabel()} is true, its
     * labels. The dataset's own arrays are handed to the strategy, and its mask arrays are
     * passed along with them.
     *
     * @param toPreProcess the data set to pre process
     */
    @Override
    public void preProcess(@NonNull DataSet toPreProcess) {
        transform(toPreProcess.getFeatures(), toPreProcess.getFeaturesMaskArray());
        transformLabel(toPreProcess.getLabels(), toPreProcess.getLabelsMaskArray());
    }

    /**
     * Transform the given dataset; equivalent to {@link #preProcess(DataSet)}.
     *
     * @param toPreProcess the dataset to transform
     */
    @Override
    public void transform(DataSet toPreProcess) {
        preProcess(toPreProcess);
    }

    /**
     * Transform the given features array (no mask).
     *
     * @param features the features to normalize
     */
    @Override
    public void transform(INDArray features) {
        transform(features, null);
    }

    /**
     * Normalize the given features using the fitted feature statistics.
     *
     * @param features     the features to normalize
     * @param featuresMask optional mask array (may be null)
     * @throws ND4JIllegalStateException if feature statistics have not been calculated yet
     */
    @Override
    public void transform(INDArray features, INDArray featuresMask) {
        strategy.preProcess(features, featuresMask, requireFeatureStats());
    }

    /**
     * Transform the labels. If {@link #isFitLabel()} == false, this is a no-op.
     *
     * @param label the labels to normalize
     */
    @Override
    public void transformLabel(INDArray label) {
        transformLabel(label, null);
    }

    /**
     * Normalize the given labels using the fitted label statistics. If {@link #isFitLabel()} ==
     * false, this is a no-op.
     *
     * @param label      the labels to normalize
     * @param labelsMask optional mask array (may be null)
     * @throws ND4JIllegalStateException if label fitting is enabled but label statistics have not
     *                                   been calculated yet
     */
    @Override
    public void transformLabel(INDArray label, INDArray labelsMask) {
        if (isFitLabel()) {
            strategy.preProcess(label, labelsMask, requireLabelStats());
        }
    }

    /**
     * Undo the feature normalization (no mask).
     *
     * @param features the normalized features to revert
     */
    @Override
    public void revertFeatures(INDArray features) {
        revertFeatures(features, null);
    }

    /**
     * Undo the feature normalization using the fitted feature statistics.
     *
     * @param features     the normalized features to revert
     * @param featuresMask optional mask array (may be null)
     * @throws ND4JIllegalStateException if feature statistics have not been calculated yet
     */
    @Override
    public void revertFeatures(INDArray features, INDArray featuresMask) {
        strategy.revert(features, featuresMask, requireFeatureStats());
    }

    /**
     * Undo (revert) the normalization applied by this DataNormalization instance to the specified labels array.
     * If labels normalization is disabled (i.e., {@link #isFitLabel()} == false) then this is a no-op.
     * Can also be used to undo normalization for network output arrays, in the case of regression.
     *
     * @param labels Labels array to revert the normalization on
     */
    @Override
    public void revertLabels(INDArray labels) {
        revertLabels(labels, null);
    }

    /**
     * Undo the label normalization using the fitted label statistics. No-op if
     * {@link #isFitLabel()} == false.
     *
     * @param labels     the normalized labels to revert
     * @param labelsMask optional mask array (may be null)
     * @throws ND4JIllegalStateException if label fitting is enabled but label statistics have not
     *                                   been calculated yet
     */
    @Override
    public void revertLabels(INDArray labels, INDArray labelsMask) {
        if (isFitLabel()) {
            strategy.revert(labels, labelsMask, requireLabelStats());
        }
    }

    /**
     * Revert the data to what it was before transform.
     *
     * @param data the dataset to revert back
     */
    @Override
    public void revert(DataSet data) {
        revertFeatures(data.getFeatures(), data.getFeaturesMaskArray());
        revertLabels(data.getLabels(), data.getLabelsMaskArray());
    }

    /** Returns the feature statistics, failing fast with a clear message if they are missing. */
    private S requireFeatureStats() {
        S stats = getFeatureStats();
        if (stats == null) {
            throw new ND4JIllegalStateException("Features statistics were not yet calculated. Make sure to run fit() first.");
        }
        return stats;
    }

    /** Returns the label statistics, failing fast with a clear message if they are missing. */
    private S requireLabelStats() {
        S stats = getLabelStats();
        if (stats == null) {
            throw new ND4JIllegalStateException("Labels statistics were not yet calculated. "
                    + "Make sure to enable fitLabel(true) and run fit() first.");
        }
        return stats;
    }
}