
org.numenta.nupic.algorithms.Classification Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of htm.java Show documentation
Show all versions of htm.java Show documentation
The Java version of Numenta's HTM technology
The newest version!
package org.numenta.nupic.algorithms;
import java.util.Arrays;
import org.numenta.nupic.model.Persistable;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.hash.TIntObjectHashMap;
/**
 * Container for the results of a classification computation by the
 * {@link CLAClassifier}.
 * <p>
 * Holds an array of actual values indexed by bucket, together with a map from
 * step count to the per-bucket probability array recorded for that step.
 *
 * @author David Ray
 *
 * @param <T> the type of the actual values stored per bucket
 */
public class Classification<T> implements Persistable {
    private static final long serialVersionUID = 1L;

    /** Array of actual values, indexed by bucket index. */
    private T[] actualValues;

    /** Map of step count -to- probabilities (one value per bucket index). */
    TIntObjectMap<double[]> probabilities = new TIntObjectHashMap<double[]>();

    /**
     * Utility method to copy the contents of this {@code Classification}.
     * <p>
     * Both the actual-values array and every per-step probability array are
     * copied, so mutating the original's arrays (for example through
     * {@link #getStats(int)}) does not affect the copy. The actual value
     * objects themselves are shared, not cloned.
     *
     * @return a copy of this {@code Classification} which will not be affected
     *         by changes to the original's arrays.
     */
    public Classification<T> copy() {
        Classification<T> retVal = new Classification<T>();
        // Tolerate an instance whose actual values were never set.
        retVal.actualValues = actualValues == null
            ? null : Arrays.copyOf(actualValues, actualValues.length);
        if(probabilities == null) {
            retVal.probabilities = null;
        } else {
            TIntObjectMap<double[]> copied = new TIntObjectHashMap<double[]>();
            for(int step : probabilities.keys()) {
                double[] stats = probabilities.get(step);
                // Clone each array; a shallow map copy would share them with the original.
                copied.put(step, stats == null ? null : stats.clone());
            }
            retVal.probabilities = copied;
        }
        return retVal;
    }

    /**
     * Returns the actual value for the specified bucket index.
     *
     * @param bucketIndex index of the bucket whose actual value is requested
     * @return the actual value, or {@code null} if no actual values are set or
     *         {@code bucketIndex} is outside the array bounds
     */
    public T getActualValue(int bucketIndex) {
        // Compare against length directly: "bucketIndex + 1" overflows at Integer.MAX_VALUE,
        // and negative indices must not reach the array access.
        if(actualValues == null || bucketIndex < 0 || bucketIndex >= actualValues.length) {
            return null;
        }
        return actualValues[bucketIndex];
    }

    /**
     * Returns all actual values entered.
     * <p>
     * The internal array is returned directly (not copied).
     *
     * @return array of type {@code T}, or {@code null} if never set
     */
    public T[] getActualValues() {
        return actualValues;
    }

    /**
     * Sets the array of actual values being entered.
     * <p>
     * The array is stored by reference (not copied).
     *
     * @param values the actual values, indexed by bucket index
     */
    public void setActualValues(T[] values) {
        actualValues = values;
    }

    /**
     * Returns a count of actual values entered.
     *
     * @return the number of actual values, or 0 if none have been set
     */
    public int getActualValueCount() {
        return actualValues == null ? 0 : actualValues.length;
    }

    /**
     * Returns the probability at the specified index for the given step.
     *
     * @param step        the step whose probabilities are consulted
     * @param bucketIndex the bucket index within that step's probability array
     * @return the probability stored for {@code bucketIndex} at {@code step}
     * @throws NullPointerException if no probabilities are recorded for {@code step}
     * @throws ArrayIndexOutOfBoundsException if {@code bucketIndex} is out of range
     */
    public double getStat(int step, int bucketIndex) {
        return probabilities.get(step)[bucketIndex];
    }

    /**
     * Sets the array of probabilities for the specified step.
     * <p>
     * The array is stored by reference (not copied).
     *
     * @param step  the step key
     * @param votes the per-bucket probabilities for {@code step}
     */
    public void setStats(int step, double[] votes) {
        probabilities.put(step, votes);
    }

    /**
     * Returns the probabilities for the specified step.
     * <p>
     * The internal array is returned directly (not copied).
     *
     * @param step the step key
     * @return the per-bucket probabilities, or {@code null} if none are recorded
     */
    public double[] getStats(int step) {
        return probabilities.get(step);
    }

    /**
     * Returns the input value corresponding with the highest probability
     * for the specified step.
     *
     * @param step the step key under which the most probable value will be returned.
     * @return the most probable actual value, or {@code null} if there is no
     *         probability entry for {@code step}, no bucket has a probability
     *         greater than 0, or no actual value exists at that bucket
     */
    public T getMostProbableValue(int step) {
        int idx = getMostProbableBucketIndex(step);
        // -1 covers both "no entry for step" and "no positive probability".
        return idx == -1 ? null : getActualValue(idx);
    }

    /**
     * Returns the bucket index corresponding with the highest probability
     * for the specified step.
     * <p>
     * Only probabilities strictly greater than 0 are considered; on ties the
     * lowest index wins.
     *
     * @param step the step key under which the most probable index will be returned.
     * @return the index of the highest probability, or -1 if there is no such
     *         entry or no probability is greater than 0
     */
    public int getMostProbableBucketIndex(int step) {
        double[] stats = probabilities.get(step);
        if(stats == null) return -1;
        double max = 0;
        int bucketIdx = -1;
        for(int i = 0; i < stats.length; i++) {
            if(stats[i] > max) {
                max = stats[i];
                bucketIdx = i;
            }
        }
        return bucketIdx;
    }

    /**
     * Returns the count of steps.
     *
     * @return the number of steps that have probabilities recorded
     */
    public int getStepCount() {
        return probabilities.size();
    }

    /**
     * Returns the count of probabilities for the specified step.
     *
     * @param step the step indexing the probability values
     * @return the length of that step's probability array, or 0 if none is recorded
     */
    public int getStatCount(int step) {
        double[] stats = probabilities.get(step);
        return stats == null ? 0 : stats.length;
    }

    /**
     * Returns the set of steps being recorded.
     * <p>
     * No particular ordering is imposed by this method.
     *
     * @return an array of the recorded step keys
     */
    public int[] stepSet() {
        return probabilities.keySet().toArray();
    }

    /**
     * Hashes the actual values and the probability arrays by content, so that
     * objects considered equal by {@link #equals(Object)} share a hash code.
     */
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + Arrays.hashCode(actualValues);
        int probHash = 0;
        if(probabilities != null) {
            // double[] uses identity hashing, so hash each entry by content; summing
            // keeps the result independent of the map's iteration order.
            for(int key : probabilities.keys()) {
                probHash += key ^ Arrays.hashCode(probabilities.get(key));
            }
        }
        result = prime * result + probHash;
        return result;
    }

    /**
     * Two {@code Classification}s are equal when their actual-value arrays are
     * element-wise equal and they record the same steps with element-wise equal
     * probability arrays.
     */
    @Override
    public boolean equals(Object obj) {
        if(this == obj)
            return true;
        if(obj == null)
            return false;
        if(getClass() != obj.getClass())
            return false;
        Classification<?> other = (Classification<?>)obj;
        if(!Arrays.equals(actualValues, other.actualValues))
            return false;
        if(probabilities == null) {
            return other.probabilities == null;
        }
        // Size check keeps equals symmetric when "other" has extra steps.
        if(other.probabilities == null || probabilities.size() != other.probabilities.size())
            return false;
        for(int key : probabilities.keys()) {
            if(!other.probabilities.containsKey(key)
                || !Arrays.equals(probabilities.get(key), other.probabilities.get(key))) {
                return false;
            }
        }
        return true;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy