org.nd4j.linalg.dataset.api.iterator.StandardScaler Maven / Gradle / Ivy
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.dataset.api.iterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
/**
 * Column-wise standardizer: learns a per-column mean and standard deviation of the feature
 * matrix (rows are examples) and rescales features to {@code (x - mean) / std}.
 *
 * <p>{@link Nd4j#EPS_THRESHOLD} is added to every learned standard deviation so that
 * {@link #transform(DataSet)} never divides by zero for constant columns.
 *
 * <p>Instances are mutable and not synchronized.
 *
 * @deprecated Retained for backward compatibility. NOTE(review): presumably superseded by
 *     {@code org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize} — confirm before
 *     pointing users there.
 */
@Deprecated
public class StandardScaler {
    private static final Logger logger = LoggerFactory.getLogger(StandardScaler.class);

    /** Per-column mean from the most recent fit or load; {@code null} before either. */
    private INDArray mean;
    /**
     * Per-column standard deviation (epsilon included) after a fit or load. While
     * {@link #fit(DataSetIterator)} is running it temporarily holds the per-column sum of squared
     * deviations (M2) of all examples seen so far.
     */
    private INDArray std;
    /** Number of examples consumed by the current/last {@link #fit(DataSetIterator)}. */
    private long runningTotal = 0;
    /** Number of examples in the most recently processed batch. */
    private long batchCount = 0;

    /**
     * Fits the per-column mean and standard deviation of a single data set.
     *
     * <p>The standard deviation is {@code features.std(0)}, ND4J's default bias-corrected
     * (sample) estimate, plus epsilon.
     *
     * @param dataSet the data whose feature matrix supplies the statistics
     */
    public void fit(DataSet dataSet) {
        INDArray features = dataSet.getFeatures();
        mean = features.mean(0);
        std = features.std(0);
        addEpsilonToStd();
    }

    /**
     * Fits the per-column mean and standard deviation over every batch of the iterator, then
     * resets the iterator.
     *
     * <p>Batches are merged with the pairwise update of Chan, Golub and LeVeque
     * (http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf, also
     * https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm):
     * with A = data seen so far, B = current batch, {@code delta = mean_B - mean_A} and
     * M2 = sum of squared deviations from the mean,
     * {@code M2 = M2_A + M2_B + delta^2 * nA * nB / (nA + nB)}.
     * The final variance is {@code M2 / (N - 1)} so the result matches the bias-corrected
     * {@code std(0)} used by {@link #fit(DataSet)}.
     *
     * @param iterator the data to iterate over; reset after fitting
     * @throws IllegalStateException if the iterator yields no examples
     */
    public void fit(DataSetIterator iterator) {
        // Start from a clean slate: otherwise a repeated fit (or a fit after fit(DataSet)/load)
        // would merge new batches into a std that already holds a final, square-rooted value.
        mean = null;
        std = null;
        runningTotal = 0;
        batchCount = 0;

        while (iterator.hasNext()) {
            INDArray features = iterator.next().getFeatures();
            batchCount = features.size(0);
            if (batchCount == 0) {
                continue; // an empty batch contributes nothing and would produce a NaN mean
            }
            INDArray meanB = features.mean(0);
            // Exact M2 of this batch: per-column sum of squared deviations from its own mean.
            // (Using the bias-corrected std here would overstate M2 and gives NaN for 1-row batches.)
            INDArray mTwoB = Transforms.pow(features.subRowVector(meanB), 2).sum(0);

            if (mean == null) {
                mean = meanB;
                std = mTwoB;
                runningTotal = batchCount;
            } else {
                long nA = runningTotal;
                runningTotal += batchCount;
                INDArray delta = meanB.sub(mean);
                // mean_AB = mean_A + delta * nB / (nA + nB)
                mean = mean.add(delta.mul((double) batchCount / runningTotal));
                // delta^2 * nA * nB / (nA + nB), computed in double to avoid float precision loss
                INDArray deltaSqScaled =
                        Transforms.pow(delta, 2).muli((double) nA * batchCount / runningTotal);
                std = std.add(mTwoB).addi(deltaSqScaled);
            }
        }

        if (mean == null) {
            iterator.reset();
            throw new IllegalStateException(
                    "Cannot fit StandardScaler: the iterator produced no examples");
        }

        // M2 / (N - 1); a single example yields variance 0 instead of 0/0 = NaN.
        std.divi(Math.max(runningTotal - 1, 1));
        std = Transforms.sqrt(std);
        addEpsilonToStd();
        iterator.reset();
    }

    /**
     * Adds {@link Nd4j#EPS_THRESHOLD} to {@link #std} in place, logging when any column's
     * standard deviation was effectively zero.
     */
    private void addEpsilonToStd() {
        // Compare numerically before adding epsilon: INDArray == is reference equality.
        if (std.minNumber().doubleValue() < Nd4j.EPS_THRESHOLD) {
            logger.info("API_INFO: Std deviation found to be zero. Transform will round upto epsilon to avoid nans.");
        }
        std.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
    }

    /** Throws if neither a fit nor a load has populated the statistics. */
    private void checkFitted() {
        if (mean == null || std == null) {
            throw new IllegalStateException("StandardScaler has not been fit or loaded");
        }
    }

    /**
     * Loads a previously saved mean and std, replacing the current statistics.
     *
     * @param mean the file holding the mean array (ND4J binary format)
     * @param std the file holding the std array (ND4J binary format)
     * @throws IOException if either file cannot be read
     */
    public void load(File mean, File std) throws IOException {
        this.mean = Nd4j.readBinary(mean);
        this.std = Nd4j.readBinary(std);
    }

    /**
     * Saves the current mean and std in ND4J binary format.
     *
     * @param mean the destination file for the mean
     * @param std the destination file for the std
     * @throws IOException if either file cannot be written
     * @throws IllegalStateException if the scaler has not been fit or loaded
     */
    public void save(File mean, File std) throws IOException {
        checkFitted();
        Nd4j.saveBinary(this.mean, mean);
        Nd4j.saveBinary(this.std, std);
    }

    /**
     * Standardizes the data set's features, replacing them with {@code (x - mean) / std}.
     *
     * @param dataSet the data set whose features are replaced
     * @throws IllegalStateException if the scaler has not been fit or loaded
     */
    public void transform(DataSet dataSet) {
        checkFitted();
        INDArray features = dataSet.getFeatures();
        dataSet.setFeatures(features.subRowVector(mean).divRowVector(std));
    }

    /** @return the learned per-column mean, or {@code null} before a fit or load */
    public INDArray getMean() {
        return mean;
    }

    /** @return the learned per-column std (epsilon included), or {@code null} before a fit or load */
    public INDArray getStd() {
        return std;
    }
}