All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.numenta.nupic.encoders.AdaptiveScalarEncoder Maven / Gradle / Ivy

package org.numenta.nupic.encoders;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * This is an implementation of the scalar encoder that adapts the min and
 * max of the scalar encoder dynamically. This is essential to the streaming
 * model of the online prediction framework.
 * <p>
 * Initialization of an adaptive encoder using resolution or radius is not
 * supported; it must be initialized with n. This n is kept constant while
 * the min and max of the encoder change.
 * <p>
 * The adaptive encoder must have periodic set to false; {@link #init()}
 * forces this and {@link #initEncoder(int, double, double, int, double, double)}
 * rejects periodic configurations.
 * <p>
 * The adaptive encoder may be initialized with an explicit minVal and
 * maxVal. If the two are equal when a value arrives, the range is seeded
 * from that value as {@code [input, input + 1]}. Afterwards the range only
 * grows: minVal is lowered to the smallest value, and maxVal raised to the
 * largest value, held in a sliding window of the most recent
 * {@link #windowSize} inputs.
 * <p>
 * <b>Note:</b> the sliding window may record duplicates of the values in the
 * data set, and therefore does not reflect the statistical distribution of
 * the input data and may not be used to calculate the median, mean etc.
 */
public class AdaptiveScalarEncoder extends ScalarEncoder {

    private static final long serialVersionUID = 1L;

    private static final Logger LOGGER = LoggerFactory.getLogger(AdaptiveScalarEncoder.class);

    /** Number of encode / bucket-index requests handled by this encoder so far. */
    public int recordNum = 0;
    /** Not read anywhere in this class; kept for API compatibility. */
    public boolean learningEnabled = true;
    /** Most recent non-missing inputs, oldest first. */
    public Double[] slidingWindow = new Double[0];
    /** Maximum number of inputs retained in {@link #slidingWindow}. */
    public int windowSize = 300;
    /** Reset to {@code null} whenever the encoder parameters are recomputed. */
    public Double bucketValues;

    /**
     * {@inheritDoc}
     * <p>
     * Forces the encoder to be non-periodic before delegating to the parent.
     *
     * @see org.numenta.nupic.encoders.ScalarEncoder#init()
     */
    @Override
    public void init() {
        this.setPeriodic(false);
        super.init();
    }

    /**
     * {@inheritDoc}
     * <p>
     * Enables encoder learning and refuses periodic configurations before
     * delegating to the parent.
     *
     * @throws IllegalStateException if this encoder is configured as periodic
     * @see org.numenta.nupic.encoders.ScalarEncoder#initEncoder(int, double,
     * double, int, double, double)
     */
    @Override
    public void initEncoder(int w, double minVal, double maxVal, int n,
        double radius, double resolution) {
        this.encLearningEnabled = true;
        if(this.periodic) {
            throw new IllegalStateException(
                "Adaptive scalar encoder does not encode periodic inputs");
        }
        // An adaptive encoder is sized by n only (see class documentation).
        assert n != 0;
        super.initEncoder(w, minVal, maxVal, n, radius, resolution);
    }

    /**
     * Constructs a new {@code AdaptiveScalarEncoder}
     */
    public AdaptiveScalarEncoder() {
    }

    /**
     * Returns a builder for building AdaptiveScalarEncoder. This builder may be
     * reused to produce multiple encoders
     *
     * @return a {@code AdaptiveScalarEncoder.Builder}
     */
    public static AdaptiveScalarEncoder.Builder adaptiveBuilder() {
        return new AdaptiveScalarEncoder.Builder();
    }

    /**
     * A {@link Builder} suitable for constructing
     * {@code AdaptiveScalarEncoder}s.
     */
    public static class Builder extends Encoder.Builder {
        private Builder() {}

        /**
         * Creates a new {@code AdaptiveScalarEncoder}, lets the parent builder
         * process it, then initializes it.
         *
         * @return the initialized encoder
         */
        @Override
        public AdaptiveScalarEncoder build() {
            encoder = new AdaptiveScalarEncoder();
            super.build();
            ((AdaptiveScalarEncoder) encoder).init();
            return (AdaptiveScalarEncoder) encoder;
        }
    }

    /**
     * {@inheritDoc}
     * <p>
     * While either bound reads as 0, a single placeholder
     * {@code Encoding(0, 0, int[n])} holding an all-zero array is returned
     * instead of delegating to the parent.
     */
    @Override
    public List topDownCompute(int[] encoded) {
        // NOTE(review): a bound that is legitimately 0 (e.g. a 0..100 range) also
        // takes this branch; setMinAndMax() uses minVal == maxVal as its "unset"
        // test instead. Confirm which check is intended.
        if (this.getMinVal() == 0 || this.getMaxVal() == 0) {
            List<Encoding> res = new ArrayList<Encoding>();
            // A freshly allocated int[] is already zero-filled.
            res.add(new Encoding(0, 0, new int[this.getN()]));
            return res;
        }
        return super.topDownCompute(encoded);
    }

    /**
     * {@inheritDoc}
     * <p>
     * Counts the record, then either zero-fills {@code output} (missing data,
     * i.e. NaN or the sentinel value) or folds {@code input} into the adaptive
     * range, and finally delegates the actual encoding to the parent.
     */
    @Override
    public void encodeIntoArray(Double input, int[] output) {
        this.recordNum += 1;
        final boolean learn = this.encLearningEnabled;
        // NaN never compares equal to anything with ==, so it is tested
        // explicitly rather than relying on the sentinel comparison alone.
        if (Double.isNaN(input)
            || input == AdaptiveScalarEncoder.SENTINEL_VALUE_FOR_MISSING_DATA) {
            Arrays.fill(output, 0);
        } else {
            this.setMinAndMax(input, learn);
        }
        super.encodeIntoArray(input, output);
    }

    /**
     * Records {@code input} in the sliding window and widens the encoder
     * range when the window holds values outside of it.
     *
     * @param input the value being encoded; must not be NaN
     * @param learn the encoder's learning flag. NOTE(review): not consulted
     *              here, the range is adapted on every call.
     */
    private void setMinAndMax(Double input, boolean learn) {
        if (slidingWindow.length >= windowSize) {
            slidingWindow = deleteItem(slidingWindow, 0);
        }
        slidingWindow = appendItem(slidingWindow, input);

        if (this.minVal == this.maxVal) {
            // Range not established yet: seed it from this input.
            this.minVal = input;
            this.maxVal = input + 1;
            setEncoderParams();
            return;
        }

        // Single pass over the window; Double.compare gives the same ordering
        // a sort of the boxed values would.
        double minOverWindow = slidingWindow[0];
        double maxOverWindow = slidingWindow[0];
        for (Double value : slidingWindow) {
            if (Double.compare(value, minOverWindow) < 0) {
                minOverWindow = value;
            }
            if (Double.compare(value, maxOverWindow) > 0) {
                maxOverWindow = value;
            }
        }

        boolean rangeChanged = false;
        if (minOverWindow < this.minVal) {
            LOGGER.trace("Input {}={} smaller than minVal {}. Adjusting minVal to {}",
                            this.name, input, this.minVal, minOverWindow);
            this.minVal = minOverWindow;
            rangeChanged = true;
        }
        if (maxOverWindow > this.maxVal) {
            LOGGER.trace("Input {}={} greater than maxVal {}. Adjusting maxVal to {}",
                            this.name, input, this.maxVal, maxOverWindow);
            this.maxVal = maxOverWindow;
            rangeChanged = true;
        }
        if (rangeChanged) {
            setEncoderParams();
        }
    }

    /**
     * Recomputes the parameters derived from minVal / maxVal (internal range,
     * resolution, radius, range, nInternal) and clears {@link #bucketValues}.
     */
    private void setEncoderParams() {
        this.rangeInternal = this.maxVal - this.minVal;
        this.resolution = this.rangeInternal / (this.n - this.w);
        this.radius = this.w * this.resolution;
        this.range = this.rangeInternal + this.resolution;
        this.nInternal = this.n - 2 * this.padding;
        this.bucketValues = null;
    }

    /**
     * Returns a copy of {@code a} that is one element longer, with
     * {@code input} stored in the last slot.
     */
    private Double[] appendItem(Double[] a, Double input) {
        Double[] grown = Arrays.copyOf(a, a.length + 1);
        grown[grown.length - 1] = input;
        return grown;
    }

    /**
     * Returns a copy of {@code a} without the element at index {@code i}.
     * If {@code i} is not a valid index of {@code a} (which includes an empty
     * array), {@code a} is returned unchanged.
     */
    private Double[] deleteItem(Double[] a, int i) {
        if (i < 0 || i >= a.length) {
            return a;
        }
        Double[] trimmed = new Double[a.length - 1];
        System.arraycopy(a, 0, trimmed, 0, i);
        System.arraycopy(a, i + 1, trimmed, i, a.length - i - 1);
        return trimmed;
    }

    /**
     * {@inheritDoc}
     * <p>
     * The string is parsed with {@link Double#parseDouble(String)}.
     *
     * @throws NumberFormatException if {@code inputString} is not a parsable double
     * @throws NullPointerException  if {@code inputString} is {@code null}
     */
    @Override
    public int[] getBucketIndices(String inputString) {
        double input = Double.parseDouble(inputString);
        return calculateBucketIndices(input);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public int[] getBucketIndices(double input) {
        return calculateBucketIndices(input);
    }

    /**
     * Shared implementation of both {@code getBucketIndices} overloads.
     * Counts the record; for missing data (NaN or the sentinel value) returns
     * a zero-filled array of length n, otherwise adapts the range to
     * {@code input} and delegates to the parent.
     */
    private int[] calculateBucketIndices(double input) {
        this.recordNum += 1;
        final boolean learn = this.encLearningEnabled;
        // NaN is checked explicitly because NaN == x is false for every x,
        // including a NaN-valued sentinel.
        if (Double.isNaN(input)
            || input == AdaptiveScalarEncoder.SENTINEL_VALUE_FOR_MISSING_DATA) {
            return new int[this.n];
        }
        this.setMinAndMax(input, learn);
        return super.getBucketIndices(input);
    }

    /**
     * {@inheritDoc}
     * <p>
     * While either bound is 0, a single placeholder
     * {@code Encoding(0, 0, int[n])} holding an all-zero array is returned
     * instead of delegating to the parent.
     */
    @Override
    public List getBucketInfo(int[] buckets) {
        // NOTE(review): same zero-bound caveat as in topDownCompute().
        if (this.minVal == 0 || this.maxVal == 0) {
            List<Encoding> encoderResultList = new ArrayList<Encoding>();
            // A freshly allocated int[] is already zero-filled.
            encoderResultList.add(new Encoding(0, 0, new int[this.n]));
            return encoderResultList;
        }
        return super.getBucketInfo(buckets);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy