All downloads are free. Search and download functionality uses the official Maven repository.

eu.fbk.utils.math.Scaler Maven / Gradle / Ivy

package eu.fbk.utils.math;

import org.apache.log4j.Logger;

import java.util.Collection;
import java.util.List;
import java.util.Objects;

/**
 * Simple z-score scaler for machine learning purposes.
 *
 * <p>{@link #fit(List)} learns the per-feature mean and population standard deviation of a
 * training matrix; {@link #transform(double[])} then standardizes a feature vector in place as
 * {@code (x - mean) / stddev}. Features whose standard deviation is zero (constant across the
 * training set) are only centered, because dividing by zero would yield NaN or Infinity.
 *
 * <p>Instances are mutable and not synchronized; {@link #fit(List)} must be called before any
 * {@code transform} call.
 *
 * @author Yaroslav Nechaev ([email protected])
 */
public class Scaler {
    private static final Logger logger = Logger.getLogger(Scaler.class.getName());

    /** Maximum number of per-feature statistic lines written to the debug log by {@link #fit}. */
    public static final int MAX_DEBUG_OUTPUT = 50;

    /** Per-feature population standard deviation; {@code null} until {@link #fit} is called. */
    private double[] stddev = null;
    /** Per-feature mean; {@code null} until {@link #fit} is called. */
    private double[] mean = null;

    /**
     * Learns the per-feature mean and population standard deviation of {@code features},
     * replacing the statistics of any previous call.
     *
     * @param features training feature vectors (one per row); must be non-empty and all of the
     *     same length
     * @throws IllegalArgumentException if {@code features} is empty or its vectors differ in length
     */
    public void fit(List<double[]> features) {
        if (features.isEmpty()) {
            throw new IllegalArgumentException("Feature matrix can't be empty");
        }
        final int dimensions = features.get(0).length;
        final int count = features.size();

        // Accumulate into locals so a rejected input leaves the previously fitted state intact.
        // Java zero-initializes new arrays, so no explicit reset loop is needed.
        double[] newMean = new double[dimensions];
        double[] newStddev = new double[dimensions];

        // Calculating mean (also validates that every row has the expected width).
        int row = 0;
        for (double[] featureVector : features) {
            if (featureVector.length != dimensions) {
                throw new IllegalArgumentException(String.format(
                        "Feature vector %d has length %d, expected %d",
                        row, featureVector.length, dimensions));
            }
            for (int i = 0; i < dimensions; i++) {
                newMean[i] += featureVector[i];
            }
            row++;
        }
        for (int i = 0; i < dimensions; i++) {
            newMean[i] /= count;
        }

        // Calculating population standard deviation (divides by N, not N - 1).
        for (double[] featureVector : features) {
            for (int i = 0; i < dimensions; i++) {
                double delta = featureVector[i] - newMean[i];
                newStddev[i] += delta * delta;
            }
        }
        for (int i = 0; i < dimensions; i++) {
            newStddev[i] = Math.sqrt(newStddev[i] / count);
        }

        mean = newMean;
        stddev = newStddev;

        if (logger.isDebugEnabled()) {
            logger.debug("Fit training set with idx/mean/stddev: ");
            int outputLength = Math.min(mean.length, MAX_DEBUG_OUTPUT);
            for (int i = 0; i < outputLength; i++) {
                logger.debug(String.format("  %d\t%.2f\t%.2f", i, mean[i], stddev[i]));
            }
            if (mean.length > MAX_DEBUG_OUTPUT) {
                logger.debug("  ...");
            }
        }
    }

    /**
     * Standardizes {@code features} in place using the statistics learned by {@link #fit}.
     *
     * <p>Each component becomes {@code (x - mean) / stddev}. Components whose training standard
     * deviation is zero are only centered ({@code x - mean}).
     *
     * @param features feature vector to rewrite; must have the same length as the fitted vectors
     * @throws NullPointerException if {@link #fit} has not been called yet
     * @throws IllegalArgumentException if {@code features} differs in length from the fitted
     *     vectors
     */
    public void transform(double[] features) {
        Objects.requireNonNull(stddev, "Scaler must be fit before transform");
        Objects.requireNonNull(mean, "Scaler must be fit before transform");
        if (features.length != mean.length) {
            throw new IllegalArgumentException(String.format(
                    "Feature vector has length %d, expected %d", features.length, mean.length));
        }

        for (int i = 0; i < features.length; i++) {
            // Center exactly once, then scale; constant features are left merely centered
            // to avoid producing NaN/Infinity.
            double centered = features[i] - mean[i];
            features[i] = stddev[i] != 0 ? centered / stddev[i] : centered;
        }
    }

    /**
     * Standardizes every vector of {@code features} in place; see {@link #transform(double[])}.
     *
     * @param features feature vectors to rewrite
     * @throws NullPointerException if {@link #fit} has not been called yet
     * @throws IllegalArgumentException if any vector differs in length from the fitted vectors
     */
    public void transform(Collection<double[]> features) {
        features.forEach(this::transform);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy