com.clust4j.algo.preprocess.RobustScaler Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of clust4j Show documentation
Show all versions of clust4j Show documentation
A suite of classification and clustering algorithm implementations for Java, including a number of partitional, hierarchical and density-based algorithms such as DBSCAN, k-Means, k-Medoids, MeanShift, Affinity Propagation, HDBSCAN and more.
The newest version!
/*******************************************************************************
* Copyright 2015, 2016 Taylor G Smith
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package com.clust4j.algo.preprocess;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import com.clust4j.except.ModelNotFitException;
import com.clust4j.utils.MatUtils;
import com.clust4j.utils.VecUtils;
/**
 * Scales each column of a matrix using statistics that are robust to
 * outliers: every column is centered on its median (delegated to a
 * {@link MedianCenterer}) and then divided by its interquartile range
 * (75th percentile minus 25th percentile). Columns whose interquartile
 * range is exactly zero are left unscaled (a divisor of 1.0 is used).
 */
public class RobustScaler extends Transformer {
    private static final long serialVersionUID = 9139185680482876266L;

    /** Median centerer fit on the training data; {@code null} until {@link #fit(RealMatrix)} completes. */
    volatile private MedianCenterer centerer;

    /** Per-column divisor (interquartile range, or 1.0 where the range is zero); {@code null} until fit. */
    volatile double[] scale;

    /**
     * Copy constructor used by {@link #copy()}.
     *
     * @param rs the instance to copy; may be unfit, in which case the copy is unfit too
     */
    private RobustScaler(RobustScaler rs) {
        this.centerer = null == rs.centerer ? null : rs.centerer.copy();
        // Guard the unfit case explicitly, mirroring the centerer handling above.
        this.scale = null == rs.scale ? null : VecUtils.copy(rs.scale);
    }

    /** Creates an unfit scaler; call {@link #fit(RealMatrix)} before transforming. */
    public RobustScaler() {
    }

    /**
     * Verifies that the model has been fit.
     *
     * @throws ModelNotFitException if {@link #fit(RealMatrix)} has not completed
     */
    @Override
    protected void checkFit() {
        if (null == centerer) {
            throw new ModelNotFitException("model not yet fit");
        }
    }

    /**
     * Reverses {@link #transform(RealMatrix)}: multiplies each column by its
     * scale and adds the column median back.
     *
     * @param X the scaled matrix
     * @return a new matrix in the original (unscaled, uncentered) space
     * @throws ModelNotFitException if the model has not been fit
     * @throws DimensionMismatchException if the column count of {@code X}
     *         differs from the number of fitted medians
     */
    @Override
    public RealMatrix inverseTransform(RealMatrix X) {
        checkFit();

        // This effectively copies, so no need to do a copy later
        double[][] data = X.getData();
        final int m = data.length;
        final int n = data[0].length;

        if (n != this.centerer.medians.length) {
            throw new DimensionMismatchException(n, this.centerer.medians.length);
        }

        for (int j = 0; j < n; j++) {
            for (int i = 0; i < m; i++) {
                // Undo the scaling, then add the median back in the same pass
                // rather than delegating to the centerer's inverseTransform,
                // which would cost a second O(M*N) sweep.
                data[i][j] *= scale[j];
                data[i][j] += centerer.medians[j];
            }
        }

        return new Array2DRowRealMatrix(data, false);
    }

    /**
     * @return a copy of this scaler built through the private copy constructor
     */
    @Override
    public RobustScaler copy() {
        return new RobustScaler(this);
    }

    /**
     * Fits the scaler: computes the per-column medians (via a new
     * {@link MedianCenterer}) and the per-column interquartile range.
     *
     * @param X the training matrix
     * @return this instance, now fit
     */
    @Override
    public RobustScaler fit(RealMatrix X) {
        synchronized (fitLock) {
            final MedianCenterer fitted = new MedianCenterer().fit(X);

            final int n = X.getColumnDimension();
            // Row j of the transpose holds column j of X.
            final double[][] columns = X.transpose().getData();

            final double[] q25 = new double[n];
            final double[] q75 = new double[n];

            for (int j = 0; j < n; j++) {
                // Fresh quantile engine per column
                final DescriptiveStatistics stats = new DescriptiveStatistics();
                for (double value : columns[j]) {
                    stats.addValue(value);
                }

                q25[j] = stats.getPercentile(25);
                q75[j] = stats.getPercentile(75);
            }

            // Scale is the interquartile range: 75th minus 25th percentile.
            final double[] iqr = VecUtils.subtract(q75, q25);

            // A constant column yields a zero range; use 1.0 so the division
            // in transform leaves such columns unchanged instead of
            // producing NaN/Infinity.
            for (int j = 0; j < iqr.length; j++) {
                if (iqr[j] == 0) {
                    iqr[j] = 1.0;
                }
            }

            // Publish scale before centerer: checkFit() keys off centerer, so
            // a reader that passes checkFit() is guaranteed to see this scale.
            this.scale = iqr;
            this.centerer = fitted;

            return this;
        }
    }

    /**
     * Centers and scales the given matrix.
     *
     * @param data the matrix to transform
     * @return a new matrix holding the transformed values
     * @throws ModelNotFitException if the model has not been fit
     */
    @Override
    public RealMatrix transform(RealMatrix data) {
        return new Array2DRowRealMatrix(transform(data.getData()), false);
    }

    /**
     * Centers each column on its fitted median and divides it by its
     * fitted scale.
     *
     * @param data the rows to transform; checked for uniform row length
     * @return the array produced by the centerer with the scaling applied
     *         (presumably a copy of {@code data} -- verify against
     *         {@code MedianCenterer.transform})
     * @throws ModelNotFitException if the model has not been fit
     */
    @Override
    public double[][] transform(double[][] data) {
        checkFit();
        MatUtils.checkDimsForUniformity(data);

        final int m = data.length;
        final int n = data[0].length;

        // Dim mismatch will happen on the median side
        double[][] centered = centerer.transform(data);

        for (int j = 0; j < n; j++) {
            for (int i = 0; i < m; i++) {
                centered[i][j] /= scale[j];
            }
        }

        return centered;
    }
}