
org.numenta.nupic.encoders.AdaptiveScalarEncoder Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of htm.java Show documentation
Show all versions of htm.java Show documentation
The Java version of Numenta's HTM technology
package org.numenta.nupic.encoders;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* This is an implementation of the scalar encoder that adapts the min and
* max of the scalar encoder dynamically. This is essential to the streaming
* model of the online prediction framework.
*
* Initialization of an adaptive encoder using resolution or radius is not
* supported; it must be initialized with n. This n is kept constant while
* the min and max of the encoder changes.
*
* The adaptive encoder must be have periodic set to false.
*
* The adaptive encoder may be initialized with a minval and maxval or with
* `None` for each of these. In the latter case, the min and max are set as
* the 1st and 99th percentile over a window of the past 100 records.
*
* *Note:** the sliding window may record duplicates of the values in the
* data set, and therefore does not reflect the statistical distribution of
* the input data and may not be used to calculate the median, mean etc.
*/
public class AdaptiveScalarEncoder extends ScalarEncoder {
private static final Logger LOGGER = LoggerFactory.getLogger(AdaptiveScalarEncoder.class);
public int recordNum = 0;
public boolean learningEnabled = true;
public Double[] slidingWindow = new Double[0];
public int windowSize = 300;
public Double bucketValues;
/**
* {@inheritDoc}
*
* @see org.numenta.nupic.encoders.ScalarEncoder#init()
*/
@Override
public void init() {
this.setPeriodic(false);
super.init();
}
/**
* {@inheritDoc}
*
* @see org.numenta.nupic.encoders.ScalarEncoder#initEncoder(int, double,
* double, int, double, double)
*/
@Override
public void initEncoder(int w, double minVal, double maxVal, int n,
double radius, double resolution) {
this.setPeriodic(false);
this.encLearningEnabled = true;
if (this.periodic) {
throw new IllegalStateException(
"Adaptive scalar encoder does not encode periodic inputs");
}
assert n != 0;
super.initEncoder(w, minVal, maxVal, n, radius, resolution);
}
/**
* Constructs a new {@code AdaptiveScalarEncoder}
*/
public AdaptiveScalarEncoder() {
}
/**
* Returns a builder for building AdaptiveScalarEncoder. This builder may be
* reused to produce multiple builders
*
* @return a {@code AdaptiveScalarEncoder.Builder}
*/
public static AdaptiveScalarEncoder.Builder adaptiveBuilder() {
return new AdaptiveScalarEncoder.Builder();
}
/**
* Constructs a new {@link Builder} suitable for constructing
* {@code AdaptiveScalarEncoder}s.
*/
public static class Builder extends Encoder.Builder {
private Builder() {}
@Override
public AdaptiveScalarEncoder build() {
encoder = new AdaptiveScalarEncoder();
super.build();
((AdaptiveScalarEncoder) encoder).init();
return (AdaptiveScalarEncoder) encoder;
}
}
/**
* {@inheritDoc}
*/
@Override
public List topDownCompute(int[] encoded) {
if (this.getMinVal() == 0 || this.getMaxVal() == 0) {
List res = new ArrayList();
int[] enArray = new int[this.getN()];
Arrays.fill(enArray, 0);
EncoderResult ecResult = new EncoderResult(0, 0, enArray);
res.add(ecResult);
return res;
}
return super.topDownCompute(encoded);
}
/**
* {@inheritDoc}
*/
@Override
public void encodeIntoArray(Double input, int[] output) {
this.recordNum += 1;
boolean learn = false;
if (!this.encLearningEnabled) {
learn = true;
}
if (input == AdaptiveScalarEncoder.SENTINEL_VALUE_FOR_MISSING_DATA) {
Arrays.fill(output, 0);
} else if (!Double.isNaN(input)) {
this.setMinAndMax(input, learn);
}
super.encodeIntoArray(input, output);
}
private void setMinAndMax(Double input, boolean learn) {
if (slidingWindow.length >= windowSize) {
slidingWindow = deleteItem(slidingWindow, 0);
}
slidingWindow = appendItem(slidingWindow, input);
if (this.minVal == this.maxVal) {
this.minVal = input;
this.maxVal = input + 1;
setEncoderParams();
} else {
Double[] sorted = Arrays.copyOf(slidingWindow, slidingWindow.length);
Arrays.sort(sorted);
double minOverWindow = sorted[0];
double maxOverWindow = sorted[sorted.length - 1];
if (minOverWindow < this.minVal) {
LOGGER.trace("Input {}={} smaller than minVal {}. Adjusting minVal to {}",
this.name, input, this.minVal, minOverWindow);
this.minVal = minOverWindow;
setEncoderParams();
}
if (maxOverWindow > this.maxVal) {
LOGGER.trace("Input {}={} greater than maxVal {}. Adjusting maxVal to {}",
this.name, input, this.minVal, minOverWindow);
this.maxVal = maxOverWindow;
setEncoderParams();
}
}
}
private void setEncoderParams() {
this.rangeInternal = this.maxVal - this.minVal;
this.resolution = this.rangeInternal / (this.n - this.w);
this.radius = this.w * this.resolution;
this.range = this.rangeInternal + this.resolution;
this.nInternal = this.n - 2 * this.padding;
this.bucketValues = null;
}
private Double[] appendItem(Double[] a, Double input) {
a = Arrays.copyOf(a, a.length + 1);
a[a.length - 1] = input;
return a;
}
private Double[] deleteItem(Double[] a, int i) {
a = Arrays.copyOfRange(a, 1, a.length - 1);
return a;
}
/**
* {@inheritDoc}
*/
@Override
public int[] getBucketIndices(String inputString) {
double input = Double.parseDouble(inputString);
return calculateBucketIndices(input);
}
/**
* {@inheritDoc}
*/
@Override
public int[] getBucketIndices(double input) {
return calculateBucketIndices(input);
}
private int[] calculateBucketIndices(double input) {
this.recordNum += 1;
boolean learn = false;
if (!this.encLearningEnabled) {
learn = true;
}
if ((Double.isNaN(input)) && (Double.valueOf(input) instanceof Double)) {
input = AdaptiveScalarEncoder.SENTINEL_VALUE_FOR_MISSING_DATA;
}
if (input == AdaptiveScalarEncoder.SENTINEL_VALUE_FOR_MISSING_DATA) {
return new int[this.n];
} else {
this.setMinAndMax(input, learn);
}
return super.getBucketIndices(input);
}
/**
* {@inheritDoc}
*/
@Override
public List getBucketInfo(int[] buckets) {
if (this.minVal == 0 || this.maxVal == 0) {
int[] initialBuckets = new int[this.n];
Arrays.fill(initialBuckets, 0);
List encoderResultList = new ArrayList();
EncoderResult encoderResult = new EncoderResult(0, 0, initialBuckets);
encoderResultList.add(encoderResult);
return encoderResultList;
}
return super.getBucketInfo(buckets);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy