com.github.chen0040.data.utils.Scaler Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of java-data-frame Show documentation
Show all versions of java-data-frame Show documentation
Some common patterns of data frame in Java
The newest version!
package com.github.chen0040.data.utils;
import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.data.frame.InputDataColumn;
import com.github.chen0040.data.frame.OutputDataColumn;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* Created by xschen on 5/5/2017.
*/
public class Scaler {
private Map means = new HashMap<>();
private Map sds = new HashMap<>();
public Scaler makeCopy() {
Scaler clone = new Scaler();
clone.copy(this);
return clone;
}
public void copy(Scaler that) {
means.clear();
sds.clear();
means.putAll(that.means);
sds.putAll(that.sds);
}
public void fit(DataFrame frame) {
means.clear();
sds.clear();
List inputColumns = frame.getInputColumns().stream().map(InputDataColumn::getColumnName).collect(Collectors.toList());
List outputColumns = frame.getOutputColumns().stream().map(OutputDataColumn::getColumnName).collect(Collectors.toList());
for(String c : inputColumns){
double[] values = new double[frame.rowCount()];
for(int i=0; i < frame.rowCount(); ++i){
double value = frame.row(i).getCell(c);
values[i] = value;
}
double mean = Mean.apply(values);
means.put(c, mean);
double sd = StdDev.apply(values, mean);
sds.put(c, sd);
}
for(String c : outputColumns){
double[] values = new double[frame.rowCount()];
for(int i=0; i < frame.rowCount(); ++i){
double value = frame.row(i).getTargetCell(c);
values[i] = value;
}
double mean = Mean.apply(values);
means.put(c, mean);
double sd = StdDev.apply(values, mean);
sds.put(c, sd);
}
}
public double transform(String columnName, double value) {
double mean = means.getOrDefault(columnName, 0.0);
double sd = sds.getOrDefault(columnName, 0.0);
if(sd != 0){
return (value - mean) / sd;
} else {
return value;
}
}
public double inverseTransform(String columnName, double value) {
double mean = means.getOrDefault(columnName, 0.0);
double sd = sds.getOrDefault(columnName, 0.0);
if(sd != 0){
return value * sd + mean;
} else {
return value;
}
}
public DataRow transform(DataRow row) {
DataRow scaled = row.makeCopy();
List inputColumns = scaled.getColumnNames();
for(String c : inputColumns){
scaled.setCell(c, transform(c, scaled.getCell(c)));
}
List outputColumns = scaled.getTargetColumnNames();
for(String c : outputColumns) {
scaled.setTargetCell(c, transform(c, scaled.getTargetCell(c)));
}
return scaled;
}
public DataRow inverseTransform(DataRow row) {
DataRow scaled = row.makeCopy();
List inputColumns = scaled.getColumnNames();
for(String c : inputColumns){
scaled.setCell(c, inverseTransform(c, scaled.getCell(c)));
}
List outputColumns = scaled.getTargetColumnNames();
for(String c : outputColumns) {
scaled.setTargetCell(c, inverseTransform(c, scaled.getTargetCell(c)));
}
return scaled;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy