
// org.numenta.nupic.encoders.RandomDistributedScalarEncoder (Maven / Gradle / Ivy)
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.encoders;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.numenta.nupic.FieldMetaType;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Tuple;
import org.numenta.nupic.util.UniversalRandom;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
*
* A scalar encoder encodes a numeric (floating point) value into an array of
* bits.
*
* This class maps a scalar value into a random distributed representation that
* is suitable as scalar input into the spatial pooler. The encoding scheme is
* designed to replace a simple ScalarEncoder. It preserves the important
* properties around overlapping representations. Unlike ScalarEncoder the min
* and max range can be dynamically increased without any negative effects. The
* only required parameter is resolution, which determines the resolution of
* input values.
*
* Scalar values are mapped to a bucket. The class maintains a random
* distributed encoding for each bucket. The following properties are maintained
* by RandomDistributedScalarEncoder:
*
* <ol>
* <li>Similar scalars should have high overlap. Overlap should decrease
* smoothly as scalars become less similar. Specifically, neighboring bucket
* indices must overlap by a linearly decreasing number of bits.</li>
*
* <li>Dissimilar scalars should have very low overlap so that the SP does not
* confuse representations. Specifically, buckets that are more than w indices
* apart should have at most maxOverlap bits of overlap. We arbitrarily (and
* safely) define "very low" to be 2 bits of overlap or lower.
*
* Properties 1 and 2 lead to the following overlap rules for buckets i and j:
*
* <pre>
* {@code
* If abs(i-j) < w then:
*     overlap(i,j) = w - abs(i-j);
* else:
*     overlap(i,j) <= maxOverlap;
* }
* </pre>
* </li>
*
* <li>The representation for a scalar must not change during the lifetime of
* the object. Specifically, as new buckets are created and the min/max range is
* extended, the representation for previously in-range scalars and previously
* created buckets must not change.</li>
* </ol>
*
*
* @author Numenta
* @author Anubhav Chaturvedi
*/
public class RandomDistributedScalarEncoder extends Encoder {
private static final long serialVersionUID = 1L;
private static final Logger LOG = LoggerFactory.getLogger(RandomDistributedScalarEncoder.class);
public static final long DEFAULT_SEED = 42;
// UniversalRandom is faster than MersenneTwister, therefore is recommended
UniversalRandom rng;
int maxOverlap;
int maxBuckets;
Double offset;
long seed;
int minIndex;
int maxIndex;
int numRetry;
ConcurrentHashMap> bucketMap;
// Package-private no-arg constructor; Builder.build() instantiates the encoder
// through it and then populates the state via setters before calling init().
RandomDistributedScalarEncoder() {
}
/**
 * Creates a new builder for {@link RandomDistributedScalarEncoder} instances.
 *
 * @return a fresh {@link RandomDistributedScalarEncoder.Builder}
 */
public static Encoder.Builder builder() {
    final RandomDistributedScalarEncoder.Builder freshBuilder =
            new RandomDistributedScalarEncoder.Builder();
    return freshBuilder;
}
/**
 * Validates the configured state and, when it is acceptable, hands over to
 * {@link #initEncoder(double, int, int, Double, long)}.
 *
 * Checks are performed in this order: {@code w} must be odd and positive,
 * the resolution must be positive, and {@code n} must exceed {@code 6 * w}.
 *
 * @throws IllegalStateException
 *             with a descriptive message for the first check that fails
 */
public void init() throws IllegalStateException {
    final int width = getW();
    final boolean oddAndPositive = width > 0 && width % 2 != 0;
    if (!oddAndPositive) {
        throw new IllegalStateException(
                "W must be an odd positive integer (to eliminate centering difficulty)");
    }
    setHalfWidth((width - 1) / 2);

    if (getResolution() <= 0) {
        throw new IllegalStateException(
                "Resolution must be a positive number");
    }

    if (n <= 6 * getW()) {
        throw new IllegalStateException(
                "n must be strictly greater than 6*w. For good results we "
                        + "recommend n be strictly greater than 11*w.");
    }

    initEncoder(getResolution(), getW(), getN(), getOffset(), getSeed());
}
/**
 * Performs the initialization of the encoder: creates the random number
 * generator, builds the initial bucket map and assigns a default name when
 * none has been set.
 *
 * Only {@code seed} is read from the argument list. {@code resolution},
 * {@code w}, {@code n} and {@code offset} are currently ignored; the
 * corresponding values are taken from the encoder's own getters instead.
 *
 * @param resolution currently unused
 * @param w          currently unused
 * @param n          currently unused
 * @param offset     currently unused; {@link #getOffset()} is used instead
 * @param seed       seed for the random number generator; {@code -1} selects
 *                   {@link #DEFAULT_SEED}
 */
public void initEncoder(double resolution, int w, int n, Double offset, long seed) {
    rng = new UniversalRandom(seed == -1 ? DEFAULT_SEED : seed);
    initializeBucketMap(getMaxBuckets(), getOffset());
    if (getName() == null || getName().isEmpty()) {
        setName("[" + getResolution() + "]");
    }
    // Parameterized logging defers toString() (which walks every bucket)
    // until the logger has confirmed that debug output is enabled.
    LOG.debug("{}", this);
}
/**
* Initialize the bucket map assuming the given number of maxBuckets.
*
* @param maxBuckets
* @param offset
*/
public void initializeBucketMap(int maxBuckets, Double offset) {
/*
* The first bucket index will be _maxBuckets / 2 and bucket indices
* will be allowed to grow lower or higher as long as they don't become
* negative. _maxBuckets is required because the current CLA Classifier
* assumes bucket indices must be non-negative. This normally does not
* need to be changed but if altered, should be set to an even number.
*/
setMaxBuckets(maxBuckets);
setMinIndex(maxBuckets / 2);
setMaxIndex(maxBuckets / 2);
/*
* The scalar offset used to map scalar values to bucket indices. The
* middle bucket will correspond to numbers in the range
* [offset-resolution/2, offset+resolution/2). The bucket index for a
* number x will be: maxBuckets/2 + int( round( (x-offset)/resolution )
* )
*/
setOffset(offset);
/*
* This HashMap maps a bucket index into its bit representation We
* initialize the HashMap with a single bucket with index 0
*/
bucketMap = new ConcurrentHashMap>();
// generate the random permutation
int[] t = ArrayUtils.range(0, getN());
rng.shuffle(t);
bucketMap.put(getMinIndex(),
Arrays.stream(t).boxed().limit(getW()).collect(Collectors.toList()));
// How often we need to retry when generating valid encodings
setNumRetry(0);
}
/**
 * Creates the given bucket index together with every missing bucket between
 * it and the currently populated range.
 *
 * Buckets are generated one step at a time, moving outwards from the
 * current minimum (or maximum) index, so each new representation overlaps
 * its already existing neighbour by exactly w-1 bits. The work is done
 * iteratively, so the call-stack depth does not grow with the distance to
 * {@code index}. If {@code index} already lies inside the populated range
 * the call does nothing.
 *
 * @param index the index at which bucket needs to be created
 * @throws IllegalStateException propagated from {@link #newRepresentation(int, int)}
 */
public void createBucket(int index) throws IllegalStateException {
    // Extend downwards: derive each new bucket from the current minimum.
    while (index < getMinIndex()) {
        int below = getMinIndex() - 1;
        bucketMap.put(below, newRepresentation(getMinIndex(), below));
        setMinIndex(below);
    }
    // Extend upwards: derive each new bucket from the current maximum.
    while (index > getMaxIndex()) {
        int above = getMaxIndex() + 1;
        bucketMap.put(above, newRepresentation(getMaxIndex(), above));
        setMaxIndex(above);
    }
}
/**
 * Builds a new representation for {@code newIndex} that overlaps with the
 * representation at {@code index} by exactly w-1 bits.
 *
 * The new representation is a copy of the existing one with a single
 * position replaced by a randomly drawn bit. Candidates are redrawn until
 * the bit is not already part of the existing representation and
 * {@link #newRepresentationOK(List, int)} accepts the result; every
 * rejected candidate increments the retry counter.
 *
 * @param index    index of an existing bucket to derive from
 * @param newIndex index of the bucket being created
 * @return the list of 'on' bit positions for {@code newIndex}
 * @throws IllegalStateException propagated from {@link #newRepresentationOK(List, int)}
 */
public List<Integer> newRepresentation(int index, int newIndex)
        throws IllegalStateException {
    List<Integer> existing = bucketMap.get(index);
    List<Integer> newRepresentation = new ArrayList<>(existing);
    /*
     * Choose the bit we will replace in this representation. We need to
     * shift this bit deterministically. If this is always chosen randomly
     * then there is a 1 in w chance of the same bit being replaced in
     * neighboring representations, which is fairly high
     */
    int ri = newIndex % getW();
    // Now we choose a bit such that the overlap rules are satisfied.
    int newBit = rng.nextInt(getN());
    newRepresentation.set(ri, newBit);
    while (existing.contains(newBit)
            || !newRepresentationOK(newRepresentation, newIndex)) {
        setNumRetry(getNumRetry() + 1);
        newBit = rng.nextInt(getN());
        newRepresentation.set(ri, newBit);
    }
    return newRepresentation;
}
/**
 * Checks whether a candidate representation satisfies all overlap rules
 * against every bucket currently stored. Since neighboring representations
 * differ by at most one bit, the overlap is updated incrementally while
 * walking from the minimum to the maximum index instead of being recounted
 * for each bucket.
 *
 * @param newRep   Encoded SDR to be considered
 * @param newIndex The index being considered
 * @return {@code true} if newRep satisfies all our overlap rules
 * @throws IllegalStateException if {@code newIndex} is more than one step
 *             outside the range of existing indices
 */
public boolean newRepresentationOK(List<Integer> newRep, int newIndex) {
    if (newRep.size() != getW()) {
        return false;
    }
    if (newIndex < getMinIndex() - 1 || newIndex > getMaxIndex() + 1) {
        throw new IllegalStateException(
                "newIndex must be within one of existing indices");
    }
    // A binary representation of newRep, used to test containment.
    // A freshly allocated boolean[] is already all false.
    boolean[] newRepBinary = new boolean[getN()];
    for (int index : newRep) {
        newRepBinary[index] = true;
    }
    // Midpoint
    int midIdx = getMaxBuckets() / 2;
    // Start by checking the overlap at minIndex
    int runningOverlap = countOverlap(bucketMap.get(getMinIndex()), newRep);
    if (!overlapOK(getMinIndex(), newIndex, runningOverlap)) {
        return false;
    }
    // Compute running overlaps all the way to the midpoint. Below the
    // midpoint bucket i-1 was derived from bucket i, so the position that
    // differs between the two is (i-1) % w.
    for (int i = getMinIndex() + 1; i < midIdx + 1; i++) {
        // This is the bit that is going to change
        int newBit = (i - 1) % getW();
        // Update our running overlap
        if (newRepBinary[bucketMap.get(i - 1).get(newBit)]) {
            runningOverlap--;
        }
        if (newRepBinary[bucketMap.get(i).get(newBit)]) {
            runningOverlap++;
        }
        // Verify our rules
        if (!overlapOK(i, newIndex, runningOverlap)) {
            return false;
        }
    }
    // At this point, runningOverlap contains the overlap for midIdx.
    // Compute running overlaps all the way to maxIndex. Above the midpoint
    // bucket i was derived from bucket i-1, so the differing position is
    // i % w.
    for (int i = midIdx + 1; i <= getMaxIndex(); i++) {
        int newBit = i % getW();
        // Update our running overlap
        if (newRepBinary[bucketMap.get(i - 1).get(newBit)]) {
            runningOverlap--;
        }
        if (newRepBinary[bucketMap.get(i).get(newBit)]) {
            runningOverlap++;
        }
        // Verify our rules
        if (!overlapOK(i, newIndex, runningOverlap)) {
            return false;
        }
    }
    return true;
}
/**
 * Counts the overlap between two representations. rep1 and rep2 are
 * {@link List}s of non-zero indices. Every pair of equal elements
 * contributes one to the result.
 *
 * @param rep1 The first representation for overlap calculation
 * @param rep2 The second representation for overlap calculation
 * @return The number of 'on' bits that overlap
 */
public int countOverlap(List<Integer> rep1, List<Integer> rep2) {
    int overlap = 0;
    for (int index : rep1) {
        for (int index2 : rep2) {
            if (index == index2) {
                overlap++;
            }
        }
    }
    return overlap;
}
/**
 * Counts the overlap between two representations given as arrays of
 * non-zero indices. Every pair of equal elements contributes one to the
 * result.
 *
 * @param rep1 The first representation for overlap calculation
 * @param rep2 The second representation for overlap calculation
 * @return The number of 'on' bits that overlap
 */
public int countOverlap(int[] rep1, int[] rep2) {
    int matches = 0;
    for (int a = 0; a < rep1.length; a++) {
        final int bit = rep1[a];
        for (int b = 0; b < rep2.length; b++) {
            if (rep2[b] == bit) {
                matches++;
            }
        }
    }
    return matches;
}
/**
 * Decides whether a given overlap between bucket indices i and j is
 * acceptable. Buckets closer than w indices must overlap by exactly
 * {@code w - abs(i - j)} bits; buckets further apart may overlap by at most
 * maxOverlap bits.
 *
 * @param i The index of the bucket to be compared
 * @param j The index of the bucket to be compared
 * @param overlap The overlap between buckets at index i and j
 * @return {@code true} if overlap is acceptable, else {@code false}
 */
public boolean overlapOK(int i, int j, int overlap) {
    final int distance = Math.abs(i - j);
    final int width = getW();
    return distance < width
            ? overlap == width - distance
            : overlap <= getMaxOverlap();
}
/**
 * Decides whether the overlap between the buckets stored at indices i and j
 * is acceptable. The overlap itself is computed from the bucketMap.
 *
 * @param i The index of the bucket to be compared
 * @param j The index of the bucket to be compared
 * @return {@code true} if the overlap is acceptable, else {@code false}
 * @throws IllegalStateException if either bucket does not exist
 */
public boolean overlapOK(int i, int j) throws IllegalStateException {
    final int storedOverlap = countOverlapIndices(i, j);
    return overlapOK(i, j, storedOverlap);
}
/**
 * Computes the overlap between the buckets stored at indices i and j.
 *
 * @param i The index of the bucket
 * @param j The index of the bucket
 * @return the overlap between bucket at indices i and j
 * @throws IllegalStateException if one or both indices have no bucket
 */
public int countOverlapIndices(int i, int j) throws IllegalStateException {
    final boolean missingI = !bucketMap.containsKey(i);
    final boolean missingJ = !bucketMap.containsKey(j);
    if (missingI && missingJ) {
        throw new IllegalStateException("index " + i + " and " + j + " don't exist");
    }
    if (missingI) {
        throw new IllegalStateException("index " + i + " doesn't exist");
    }
    if (missingJ) {
        throw new IllegalStateException("index " + j + " doesn't exist");
    }
    return countOverlap(bucketMap.get(i), bucketMap.get(j));
}
/**
 * Given a bucket index, returns the list of indices of the 'on' bits. If
 * the bucket index does not exist, it is created. If the index falls
 * outside {@code [0, maxBuckets - 1]} it is clipped to that range first.
 *
 * @param index The bucket index
 * @return The list of active bits in the representation
 * @throws IllegalStateException propagated from {@link #createBucket(int)}
 */
public List<Integer> mapBucketIndexToNonZeroBits(int index)
        throws IllegalStateException {
    // Clip low first, then high, so the result never exceeds maxBuckets - 1.
    index = Math.min(Math.max(index, 0), getMaxBuckets() - 1);
    if (!bucketMap.containsKey(index)) {
        LOG.trace("Adding additional buckets to handle index={} ", index);
        createBucket(index);
    }
    return bucketMap.get(index);
}
/**
 * {@inheritDoc}
 *
 * Returns an empty array for missing data (NaN or the sentinel value).
 * Otherwise returns a single-element array holding
 * {@code maxBuckets/2 + round((x - offset) / resolution)}, clipped to
 * {@code [0, maxBuckets - 1]}. If no offset has been set yet, {@code x}
 * becomes the offset.
 */
@Override
public int[] getBucketIndices(double x) {
    if (Double.isNaN(x)
            || Double.compare(x, Encoder.SENTINEL_VALUE_FOR_MISSING_DATA) == 0) {
        return new int[0];
    }
    if (getOffset() == null) {
        setOffset(x);
    }
    /*
     * Python's round() rounds halves away from zero (round(-0.5) => -1.0)
     * whereas Math.round(-0.5) => 0. To match Python, the magnitude is
     * rounded and the sign is applied afterwards.
     */
    double deltaIndex = (x - getOffset()) / getResolution();
    // Math.round saturates at Long.MAX_VALUE; cap the magnitude so that the
    // arithmetic below cannot overflow. Working in long (instead of casting
    // the rounded value to int) makes very large or infinite inputs clip to
    // the first/last bucket rather than wrapping to an arbitrary index.
    long magnitude = Math.min(Math.round(Math.abs(deltaIndex)), Integer.MAX_VALUE);
    long signedDelta = deltaIndex < 0 ? -magnitude : magnitude;
    int maxBucketCount = getMaxBuckets();
    long unclipped = (maxBucketCount / 2) + signedDelta;
    int bucketIdx;
    if (unclipped < 0) {
        bucketIdx = 0;
    } else if (unclipped >= maxBucketCount) {
        bucketIdx = maxBucketCount - 1;
    } else {
        bucketIdx = (int) unclipped;
    }
    return new int[] { bucketIdx };
}
/**
 * {@inheritDoc}
 *
 * For this encoder the width equals {@code n}.
 */
@Override
public int getWidth() {
    return this.getN();
}
/**
 * {@inheritDoc}
 *
 * This encoder always returns {@code false}.
 */
@Override
public boolean isDelta() {
return false;
}
/**
 * {@inheritDoc}
 *
 * Forwards the flag to {@code setLearningEnabled}.
 */
@Override
public void setLearning(boolean learningEnabled) {
    this.setLearningEnabled(learningEnabled);
}
/**
 * {@inheritDoc}
 *
 * Returns a mutable single-element list holding a {@link Tuple} of the
 * encoder name and {@code 0}. If the encoder has no name yet, the name is
 * first set to {@code "[" + resolution + "]"}.
 */
@Override
public List<Tuple> getDescription() {
    String name = getName();
    if (name == null || name.isEmpty()) {
        setName("[" + getResolution() + "]");
        name = getName();
    }
    List<Tuple> description = new ArrayList<>(1);
    description.add(new Tuple(name, 0));
    return description;
}
/**
 * Returns the overlap limit applied to buckets at least w indices apart.
 *
 * @return maxOverlap for this RDSE
 */
public int getMaxOverlap() {
    return this.maxOverlap;
}
/**
 * Returns the upper bound on the number of bucket indices.
 *
 * @return the maxBuckets for this RDSE
 */
public int getMaxBuckets() {
    return this.maxBuckets;
}
/**
 * Returns the configured seed.
 *
 * @return the seed for the random number generator
 */
public long getSeed() {
    return this.seed;
}
/**
 * Returns the scalar offset that maps to the middle bucket.
 *
 * @return the offset, or {@code null} if none has been set yet
 */
public Double getOffset() {
    return this.offset;
}
/** @return the lowest bucket index that currently has a representation */
private int getMinIndex() {
    return this.minIndex;
}
/** @return the highest bucket index that currently has a representation */
private int getMaxIndex() {
    return this.maxIndex;
}
/**
 * Returns how many candidate bits have been rejected so far while creating
 * new buckets.
 *
 * @return the number of retry to create new bucket
 */
public int getNumRetry() {
    return this.numRetry;
}
/**
 * Sets the overlap limit that {@code overlapOK} applies to buckets at least
 * w indices apart.
 *
 * @param maxOverlap The maximum permissible overlap between representations
 */
public void setMaxOverlap(int maxOverlap) {
this.maxOverlap = maxOverlap;
}
/**
 * Sets the upper bound used when clipping bucket indices. This setter only
 * stores the value; it does not rebuild the bucket map.
 *
 * @param maxBuckets the new maximum number of buckets allowed
 */
public void setMaxBuckets(int maxBuckets) {
this.maxBuckets = maxBuckets;
}
/**
 * Stores the seed for the random number generator. This setter does not
 * re-create the generator itself.
 *
 * @param seed the seed value to store
 */
public void setSeed(long seed) {
this.seed = seed;
}
/**
 * Sets the scalar offset that maps to the middle bucket.
 *
 * @param offset the new offset; when {@code null}, {@code getBucketIndices}
 *               adopts the first value it is asked to encode as the offset
 */
public void setOffset(Double offset) {
this.offset = offset;
}
// Records the lowest bucket index that has a representation in bucketMap.
private void setMinIndex(int minIndex) {
this.minIndex = minIndex;
}
// Records the highest bucket index that has a representation in bucketMap.
private void setMaxIndex(int maxIndex) {
this.maxIndex = maxIndex;
}
/**
 * Overwrites the retry counter that {@code newRepresentation} increments
 * for every rejected candidate bit.
 *
 * @param numRetry New number of retries for new representation
 */
public void setNumRetry(int numRetry) {
this.numRetry = numRetry;
}
/**
 * Produces a multi-line dump of the encoder state: index range, w, n,
 * resolution, offset, retry count, name and every bucket with its 'on'
 * bits. If the bucket map has not been created yet, the bucket section is
 * left empty.
 */
@Override
public String toString() {
    StringBuilder dumpString = new StringBuilder(256);
    dumpString.append("RandomDistributedScalarEncoder:\n")
            .append(" minIndex: ").append(getMinIndex()).append("\n")
            .append(" maxIndex: ").append(getMaxIndex()).append("\n")
            .append(" w: ").append(getW()).append("\n")
            .append(" n: ").append(getWidth()).append("\n")
            .append(" resolution: ").append(getResolution()).append("\n")
            .append(" offset: ").append(getOffset()).append("\n")
            .append(" numTries: ").append(getNumRetry()).append("\n")
            .append(" name: ").append(getName()).append("\n")
            .append(" buckets : \n");
    // bucketMap is only assigned by initializeBucketMap(); guard so that
    // toString() is usable on an encoder that has not been initialised.
    if (bucketMap != null) {
        bucketMap.forEach((index, bits) -> dumpString
                .append(" [ ").append(index).append(" ]: ")
                .append(Arrays.deepToString(bits.toArray()))
                .append("\n"));
    }
    return dumpString.toString();
}
/**
 * {@link Encoder.Builder} for constructing
 * {@link RandomDistributedScalarEncoder}s.
 *
 * The base class architecture is put together in such a way where
 * boilerplate initialization can be kept to a minimum for implementing
 * subclasses, while avoiding the mistake-proneness of extremely long
 * argument lists.
 *
 * Defaults set by this builder: n = 400, w = 21, seed = {@link #DEFAULT_SEED},
 * maxBuckets = 1000, maxOverlap = 2 and no offset.
 *
 * @author Anubhav Chaturvedi
 */
public static class Builder
        extends
        Encoder.Builder<RandomDistributedScalarEncoder.Builder, RandomDistributedScalarEncoder> {
    private int maxOverlap;
    private int maxBuckets;
    private Double offset;
    private long seed;
    // Neither field is read or written by this builder.
    int minIndex;
    int maxIndex;

    private Builder() {
        this.n(400);
        this.w(21);
        seed = DEFAULT_SEED;
        maxBuckets = 1000;
        maxOverlap = 2;
        offset = null;
    }

    private Builder(int n, int w) {
        this();
        this.n(n);
        this.w(w);
    }

    /**
     * Creates the encoder, lets the super class apply the common settings,
     * copies the RDSE-specific settings and finally runs
     * {@link RandomDistributedScalarEncoder#init()}.
     *
     * @return the fully initialised encoder
     * @throws IllegalStateException if {@code init()} rejects the configuration
     */
    @Override
    public RandomDistributedScalarEncoder build() {
        // Must be instantiated so that super class can initialize
        // boilerplate variables.
        encoder = new RandomDistributedScalarEncoder();
        // Call super class here
        RandomDistributedScalarEncoder partialBuild = super.build();
        ////////////////////////////////////////////////////////
        // Implementing classes would do setting of specific  //
        // vars here together with any sanity checking        //
        ////////////////////////////////////////////////////////
        partialBuild.setSeed(seed);
        partialBuild.setMaxOverlap(maxOverlap);
        partialBuild.setMaxBuckets(maxBuckets);
        partialBuild.setOffset(offset);
        partialBuild.setNumRetry(0);
        partialBuild.init();
        return partialBuild;
    }

    /**
     * @param offset scalar that should map to the middle bucket
     * @return this builder
     */
    public RandomDistributedScalarEncoder.Builder setOffset(double offset) {
        this.offset = Double.valueOf(offset);
        return this;
    }

    /**
     * @param maxBuckets maximum number of buckets allowed
     * @return this builder
     */
    public RandomDistributedScalarEncoder.Builder setMaxBuckets(
            int maxBuckets) {
        this.maxBuckets = maxBuckets;
        return this;
    }

    /**
     * @param maxOverlap maximum permissible overlap between distant buckets
     * @return this builder
     */
    public RandomDistributedScalarEncoder.Builder setMaxOverlap(
            int maxOverlap) {
        this.maxOverlap = maxOverlap;
        return this;
    }

    /**
     * @param seed seed for the random number generator
     * @return this builder
     */
    public RandomDistributedScalarEncoder.Builder setSeed(long seed) {
        this.seed = seed;
        return this;
    }
}
/**
 * {@inheritDoc}
 *
 * Clears {@code output} and then sets the 'on' bits of the bucket that
 * {@code inputData} maps to. For missing data the output stays all zero.
 * If the bucket cannot be created, the failure is logged and the output is
 * likewise left all zero.
 */
@Override
public void encodeIntoArray(Double inputData, int[] output) {
    int[] bucketIdx = getBucketIndices(inputData);
    Arrays.fill(output, 0);
    if (bucketIdx.length == 0 || bucketIdx[0] == Integer.MIN_VALUE) {
        return;
    }
    try {
        List<Integer> indices = mapBucketIndexToNonZeroBits(bucketIdx[0]);
        for (int index : indices) {
            output[index] = 1;
        }
    } catch (IllegalStateException e) {
        // Best effort: keep the zeroed output, but report through the class
        // logger instead of dumping the stack trace to stderr.
        LOG.error("Unable to encode value {} into bucket {}", inputData, bucketIdx[0], e);
    }
}
/**
 * {@inheritDoc}
 *
 * This implementation does not read {@code returnType}; it returns a new
 * list holding the current values of the bucket map, i.e. one entry per
 * bucket, each entry being that bucket's list of 'on' bit positions.
 */
// NOTE(review): the generic type parameters of this signature appear to have
// been lost; confirm the intended signature against the Encoder base class.
@SuppressWarnings("unchecked")
@Override
public List getBucketValues(Class returnType) {
return new ArrayList<>((Collection)this.bucketMap.values());
}
/**
 * {@inheritDoc}
 *
 * Returns a new insertion-ordered set containing {@code FieldMetaType.FLOAT}
 * followed by {@code FieldMetaType.INTEGER}.
 */
@Override
public Set<FieldMetaType> getDecoderOutputFieldTypes() {
    return new LinkedHashSet<>(Arrays.asList(FieldMetaType.FLOAT, FieldMetaType.INTEGER));
}
}