All downloads are free. Search and download functionality uses the official Maven repository.

org.wikibrain.sr.normalize.LoessNormalizer Maven / Gradle / Ivy

There is a newer version: 0.9.1
Show newest version
package org.wikibrain.sr.normalize;

import com.typesafe.config.Config;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.interpolation.LoessInterpolator;
import org.apache.commons.math3.stat.ranking.NaNStrategy;
import org.apache.commons.math3.stat.ranking.NaturalRanking;
import org.apache.commons.math3.stat.ranking.TiesStrategy;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.utils.WbMathUtils;

import java.text.DecimalFormat;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Normalizes values using a curve fitted to the observed (x, y) pairs.
 * The curve is built lazily, on first use, in two steps:
 * <ol>
 * <li>The observations are sorted by x and condensed into a sample of
 *     smoothed points by {@code WbMathUtils.smooth}.</li>
 * <li>A {@link LoessInterpolator} is fitted to those smoothed points.</li>
 * </ol>
 * Values that fall outside the fitted x-range are extrapolated with
 * {@code WbMathUtils.toAsymptote}.
 *
 * <p>Adding observations, clearing state in {@link #reset()} and building the
 * fit all lock on this instance.</p>
 */
public class LoessNormalizer extends BaseNormalizer {

    private static final Logger LOG = LoggerFactory.getLogger(LoessNormalizer.class);
    public static final long serialVersionUID = -34232429;

    private static final double EPSILON = 1E-10;

    // Observed points. init() replaces both lists with pruned copies sorted by x.
    // The field names are part of the serialized form and must not be renamed.
    private TDoubleList X = new TDoubleArrayList();
    private TDoubleList Y = new TDoubleArrayList();
    private boolean logTransform = false;
    private boolean monotonic = false;

    // Derived state: not serialized, rebuilt on demand by init().
    transient private double interpolatorMin;
    transient private double interpolatorMax;
    transient private UnivariateFunction interpolator = null;

    /**
     * Discards all observations and the fitted curve.
     */
    @Override
    public void reset() {
        super.reset();
        // Same monitor as observe() and init(), so a concurrent fit never
        // sees half-cleared state.
        synchronized (this) {
            X.clear();
            Y.clear();
            interpolatorMin = 0;
            interpolatorMax = 0;
            interpolator = null;
        }
    }

    /**
     * Records one observation. Points whose x is NaN or infinite are not
     * stored here, but every point is still forwarded to the superclass.
     *
     * @param x observed raw value
     * @param y observed target value
     */
    @Override
    public void observe(double x, double y) {
        if (!Double.isNaN(x) && !Double.isInfinite(x)) {
            // Lock on this rather than on X: init() reassigns X, so a monitor
            // taken on the list could belong to a stale object and would not
            // exclude init(), which locks on this.
            synchronized (this) {
                X.add(x);
                Y.add(y);
            }
        }
        super.observe(x, y);
    }

    /**
     * Signals the end of the observation phase. The curve itself is not built
     * here; it is initialized lazily to overcome problems with
     * PolynomialSplineFunction serialization.
     */
    @Override
    public void observationsFinished() {
        super.observationsFinished();
    }

    /**
     * Maps a raw value onto the fitted curve.
     *
     * @param x raw value to normalize
     * @return the inherited {@code missingMean} when x is NaN or infinite;
     *         the fitted curve's value when the (optionally log-transformed) x
     *         lies inside the fitted range; otherwise the result of
     *         {@code WbMathUtils.toAsymptote} starting from the nearest boundary
     * @throws IllegalStateException if the transformed x compares as neither
     *         inside nor outside the fitted range (for example, NaN), or if no
     *         curve can be fitted because there are no usable observations
     */
    @Override
    public double normalize(double x) {
        if (Double.isNaN(x) || Double.isInfinite(x)) {
            return missingMean;
        }
        // Fetching the function also triggers the lazy fit, which must happen
        // before logIfNeeded() because that relies on X being sorted.
        UnivariateFunction f = getInterpolationFunction();
        x = logIfNeeded(x);
        double sMin = interpolatorMin;
        double sMax = interpolatorMax;

        if (sMin <= x && x <= sMax) {
            return f.value(x);
        }

        // Outside the fitted range: start from the boundary value and move
        // toward a target that lies 10% of the fitted y-range beyond it,
        // using a quarter of the fitted x-range as the half-life argument.
        double yMin = f.value(sMin);
        double yMax = f.value(sMax);
        double halfLife = (sMax - sMin) / 4.0;
        double yDelta = 0.1 * (yMax - yMin);
        if (x < sMin) {
            return WbMathUtils.toAsymptote(sMin - x, halfLife, yMin, yMin - yDelta);
        }
        if (x > sMax) {
            return WbMathUtils.toAsymptote(x - sMax, halfLife, yMax, yMax + yDelta);
        }
        // Every comparison above was false, which happens when x is NaN.
        throw new IllegalStateException("" + x + " not in [" + sMin + "," + sMax + "]");
    }


    /**
     * Returns the fitted function, building it first if necessary.
     */
    private synchronized UnivariateFunction getInterpolationFunction() {
        init();
        return interpolator;
    }

    /**
     * Builds the interpolator from the stored observations if it has not been
     * built yet. As a side effect, X and Y are replaced by pruned copies
     * sorted by x.
     *
     * @throws IllegalStateException if no observations remain after pruning,
     *         or if the computed ranks do not match the number of points
     */
    private synchronized void init() {
        if (interpolator != null) {
            return;
        }

        // remove infinite or nan values
        TDoubleList[] pruned = WbMathUtils.removeNotNumberPoints(X, Y);
        X = pruned[0];
        Y = pruned[1];
        if (X.size() == 0) {
            // Fail with a clear message instead of an index error further down.
            throw new IllegalStateException("cannot fit loess normalizer: no valid observations");
        }

        // sort points by X coordinate; sequential tie-breaking gives every
        // point a distinct rank in 1..n, which is used as its target slot.
        double[] ranks = new NaturalRanking(NaNStrategy.REMOVED, TiesStrategy.SEQUENTIAL).rank(X.toArray());
        if (ranks.length != X.size()) {
            throw new IllegalStateException("invalid sizes: " + ranks.length + " and " + X.size());
        }
        // spots in these arrays will be replaced.
        TDoubleList sortedX = new TDoubleArrayList(X);
        TDoubleList sortedY = new TDoubleArrayList(Y);
        for (int i = 0; i < X.size(); i++) {
            int r = (int) Math.round(ranks[i]) - 1;
            sortedX.set(r, X.get(i));
            sortedY.set(r, Y.get(i));
        }
        X = sortedX;
        Y = sortedY;

        // create the smoothed points.
        int windowSize = Math.min(20, X.size() / 10);
        double[][] smoothed = WbMathUtils.smooth(
                logIfNeeded(X.toArray()),
                Y.toArray(),
                windowSize,
                10);
        double[] smoothedX = smoothed[0];
        double[] smoothedY = smoothed[1];

        // The fitted range is captured before the adjustment below is applied
        // to smoothedX.
        interpolatorMin = smoothedX[0];
        interpolatorMax = smoothedX[smoothedX.length - 1];

        WbMathUtils.makeMonotonicIncreasing(smoothedX, EPSILON);
        if (monotonic) {
            WbMathUtils.makeMonotonicIncreasing(smoothedY, EPSILON);
        }

        // create the interpolator
        interpolator = new LoessInterpolator().interpolate(smoothedX, smoothedY);
    }

    /**
     * Applies the optional log transform to a single x value. With the
     * transform enabled, values below the first stored observation map to 0
     * and all others map to {@code log(1 + X[0] + x)}; with it disabled, x is
     * returned unchanged.
     */
    private double logIfNeeded(double x) {
        if (!logTransform) {
            return x;
        }
        // X.get(0) is the smallest observed x once init() has sorted the points.
        // NOTE(review): the offset is added rather than subtracted, so the
        // result is NaN whenever 1 + X[0] + x is negative - confirm intended.
        double first = X.get(0);
        if (x < first) {
            return 0;
        }
        return Math.log(1 + first + x);
    }

    /**
     * Applies the optional log transform to every element. Returns a new array
     * when the transform is enabled and the input array itself otherwise.
     */
    private double[] logIfNeeded(double X[]) {
        if (!logTransform) {
            return X;
        }
        double[] transformed = new double[X.length];
        for (int i = 0; i < X.length; i++) {
            transformed[i] = logIfNeeded(X[i]);
        }
        return transformed;
    }

    /**
     * Describes the fitted curve as 21 {@code <x,normalized>} pairs sampled at
     * evenly spaced positions of the sorted observations.
     *
     * @return a one-line, human-readable summary
     */
    @Override
    public String dump() {
        init();
        StringBuilder buff = new StringBuilder("loess normalizer");
        if (logTransform) {
            buff.append(" (log'ed)");
        }
        DecimalFormat df = new DecimalFormat("#.##");
        int n = X.size();
        for (int i = 0; i <= 20; i++) {
            // long arithmetic: i * n can exceed the int range for very large samples.
            int j = (int) Math.min(n - 1, (long) i * n / 20);
            double x = X.get(j);
            buff.append(" <")
                    .append(df.format(x))
                    .append(',')
                    .append(df.format(normalize(x)))
                    .append('>');
        }
        return buff.toString();
    }

    /**
     * Enables or disables the log transform applied to x values.
     */
    public void setLogTransform(boolean b) {
        this.logTransform = b;
    }

    /**
     * @return whether the log transform is applied to x values
     */
    public boolean getLogTransform() {
        return logTransform;
    }

    /**
     * Sets whether the smoothed y values are forced to be increasing before
     * the curve is fitted.
     */
    public void setMonotonic(boolean b) {
        this.monotonic = b;
    }

    /**
     * Configuration provider that builds a {@link LoessNormalizer} for
     * configurations whose {@code type} is {@code "loess"}.
     */
    public static class Provider extends org.wikibrain.conf.Provider {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class getType() {
            return Normalizer.class;
        }

        @Override
        public String getPath() {
            return "sr.normalizer";
        }

        @Override
        public Scope getScope() {
            return Scope.INSTANCE;
        }

        /**
         * @return a new normalizer, with the log transform taken from the
         *         optional {@code log} setting, or {@code null} when the
         *         configured type is not {@code "loess"}
         */
        @Override
        public LoessNormalizer get(String name, Config config, Map runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("loess")) {
                return null;
            }
            LoessNormalizer ln = new LoessNormalizer();
            if (config.hasPath("log")) {
                ln.setLogTransform(config.getBoolean("log"));
            }
            return ln;
        }

    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy