All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.numenta.nupic.algorithms.CLAClassifier Maven / Gradle / Ivy

/* ---------------------------------------------------------------------
 * Numenta Platform for Intelligent Computing (NuPIC)
 * Copyright (C) 2014, Numenta, Inc.  Unless you have an agreement
 * with Numenta, Inc., for a separate license for this software code, the
 * following terms and conditions apply:
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see http://www.gnu.org/licenses.
 *
 * http://numenta.org/licenses/
 * ---------------------------------------------------------------------
 */

package org.numenta.nupic.algorithms;

import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Deque;
import org.numenta.nupic.util.Tuple;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;

/**
 * A CLA classifier accepts a binary input from the level below (the
 * "activationPattern") and information from the sensor and encoders (the
 * "classification") describing the input to the system at that time step.
 *
 * When learning, for every bit in activation pattern, it records a history of the
 * classification each time that bit was active. The history is weighted so that
 * more recent activity has a bigger impact than older activity. The alpha
 * parameter controls this weighting.
 *
 * For inference, it takes an ensemble approach. For every active bit in the
 * activationPattern, it looks up the most likely classification(s) from the
 * history stored for that bit and then votes across these to get the resulting
 * classification(s).
 *
 * This classifier can learn and infer a number of simultaneous classifications
 * at once, each representing a shift of a different number of time steps. For
 * example, say you are doing multi-step prediction and want the predictions for
 * 1 and 3 time steps in advance. The CLAClassifier would learn the associations
 * between the activation pattern for time step T and the classifications for
 * time step T+1, as well as the associations between activation pattern T and
 * the classifications for T+3. The 'steps' constructor argument specifies the
 * list of time-steps you want.
 * 
 * @author Numenta
 * @author David Ray
 * @see BitHistory
 */
@JsonSerialize(using=CLAClassifierSerializer.class)
@JsonDeserialize(using=CLAClassifierDeserializer.class)
public class CLAClassifier {
	int verbosity = 0;
	/**
	 * The alpha used to compute running averages of the bucket duty
     * cycles for each activation pattern bit. A lower alpha results
     * in longer term memory.
	 */
	double alpha = 0.001;
	double actValueAlpha = 0.3;
	/** 
	 * The bit's learning iteration. This is updated each time store() gets
     * called on this bit.
	 */
	int learnIteration;
	/**
	 * This contains the offset between the recordNum (provided by caller) and
     * learnIteration (internal only, always starts at 0).
	 */
	int recordNumMinusLearnIteration = -1;
	/**
	 * This contains the value of the highest bucket index we've ever seen
     * It is used to pre-allocate fixed size arrays that hold the weights of
     * each bucket index during inference 
	 */
	int maxBucketIdx;
	/** The sequence different steps of multi-step predictions */
	TIntList steps = new TIntArrayList();
	/**
	 * History of the last _maxSteps activation patterns. We need to keep
     * these so that we can associate the current iteration's classification
     * with the activationPattern from N steps ago
	 */
	Deque patternNZHistory;
	/**
	 * These are the bit histories. Each one is a BitHistory instance, stored in
     * this dict, where the key is (bit, nSteps). The 'bit' is the index of the
     * bit in the activation pattern and nSteps is the number of steps of
     * prediction desired for that bit.
	 */
	Map activeBitHistory = new HashMap();
	/**
	 * This keeps track of the actual value to use for each bucket index. We
     * start with 1 bucket, no actual value so that the first infer has something
     * to return
	 */
	List actualValues = new ArrayList();
	
	String g_debugPrefix = "CLAClassifier";
	
	
	/**
	 * CLAClassifier no-arg constructor with defaults
	 */
	public CLAClassifier() {
		this(new TIntArrayList(new int[] { 1 }), 0.001, 0.3, 0);
	}
	
	/**
	 * Constructor for the CLA classifier
	 * 
	 * @param steps				sequence of the different steps of multi-step predictions to learn
	 * @param alpha				The alpha used to compute running averages of the bucket duty
               					cycles for each activation pattern bit. A lower alpha results
               					in longer term memory.
	 * @param actValueAlpha
	 * @param verbosity			verbosity level, can be 0, 1, or 2
	 */
	public CLAClassifier(TIntList steps, double alpha, double actValueAlpha, int verbosity) {
		this.steps = steps;
		this.alpha = alpha;
		this.actValueAlpha = actValueAlpha;
		this.verbosity = verbosity;
		actualValues.add(null);
		patternNZHistory = new Deque(ArrayUtils.max(steps.toArray()) + 1);
	}
	
	/**
	 * Process one input sample.
     * This method is called by outer loop code outside the nupic-engine. We
     * use this instead of the nupic engine compute() because our inputs and
     * outputs aren't fixed size vectors of reals.
     * 
	 * @param recordNum			Record number of this input pattern. Record numbers should
     *           				normally increase sequentially by 1 each time unless there
     *           				are missing records in the dataset. Knowing this information
     *           				insures that we don't get confused by missing records.
	 * @param classification	{@link Map} of the classification information:
     *                 			bucketIdx: index of the encoder bucket
     *                 			actValue:  actual value going into the encoder
	 * @param patternNZ			list of the active indices from the output below
	 * @param learn				if true, learn this sample
	 * @param infer				if true, perform inference
	 * 
	 * @return					dict containing inference results, there is one entry for each
     *           				step in steps, where the key is the number of steps, and
     *           				the value is an array containing the relative likelihood for
     *           				each bucketIdx starting from bucketIdx 0.
	 *
     *           				There is also an entry containing the average actual value to
     *           				use for each bucket. The key is 'actualValues'.
	 *
     *           				for example:
     *             				{	
     *             					1 :             [0.1, 0.3, 0.2, 0.7],
     *              				4 :             [0.2, 0.4, 0.3, 0.5],
     *              				'actualValues': [1.5, 3,5, 5,5, 7.6],
     *             				}
	 */
	@SuppressWarnings("unchecked")
	public  ClassifierResult compute(int recordNum, Map classification, int[] patternNZ, boolean learn, boolean infer) {
		ClassifierResult retVal = new ClassifierResult();
		List actualValues = (List)this.actualValues;
		
		// Save the offset between recordNum and learnIteration if this is the first
	    // compute
		if(recordNumMinusLearnIteration == -1) {
			recordNumMinusLearnIteration = recordNum - learnIteration;
		}
		
		// Update the learn iteration
		learnIteration = recordNum - recordNumMinusLearnIteration;
		
		if(verbosity >= 1) {
			System.out.println(String.format("\n%s: compute ", g_debugPrefix));
			System.out.println(" recordNum: " + recordNum);
			System.out.println(" learnIteration: " + learnIteration);
			System.out.println(String.format(" patternNZ(%d): ", patternNZ.length, patternNZ));
			System.out.println(" classificationIn: " + classification);
		}
		
		patternNZHistory.append(new Tuple(learnIteration, patternNZ));
		
		//------------------------------------------------------------------------
	    // Inference:
	    // For each active bit in the activationPattern, get the classification
	    // votes
		//
		// Return value dict. For buckets which we don't have an actual value
	    // for yet, just plug in any valid actual value. It doesn't matter what
	    // we use because that bucket won't have non-zero likelihood anyways.
		if(infer) {
			// NOTE: If doing 0-step prediction, we shouldn't use any knowledge
		    //		 of the classification input during inference.
			Object defaultValue = null;
			if(steps.get(0) == 0) {
				defaultValue = 0;
			}else{
				defaultValue = classification.get("actValue");
			}
			
			T[] actValues = (T[])new Object[this.actualValues.size()];
			for(int i = 0;i < actualValues.size();i++) {
				actValues[i] = (T)(actualValues.get(i) == null ? defaultValue : actualValues.get(i));
			}
			
			retVal.setActualValues(actValues);
			
			// For each n-step prediction...
			for(int nSteps : steps.toArray()) {
				// Accumulate bucket index votes and actValues into these arrays
				double[] sumVotes = new double[maxBucketIdx + 1];
				double[] bitVotes = new double[maxBucketIdx + 1];
				
				for(int bit : patternNZ) {
					Tuple key = new Tuple(bit, nSteps);
					BitHistory history = activeBitHistory.get(key);
					if(history == null) continue;
					
					history.infer(learnIteration, bitVotes);
					
					sumVotes = ArrayUtils.d_add(sumVotes, bitVotes);
				}
				
				// Return the votes for each bucket, normalized
				double total = ArrayUtils.sum(sumVotes);
				if(total > 0) {
					sumVotes = ArrayUtils.divide(sumVotes, total);
				}else{
					// If all buckets have zero probability then simply make all of the
			        // buckets equally likely. There is no actual prediction for this
			        // timestep so any of the possible predictions are just as good.
					if(sumVotes.length > 0) {
						Arrays.fill(sumVotes, 1.0 / (double)sumVotes.length);
					}
				}
				
				retVal.setStats(nSteps, sumVotes);
			}
		}
		
		// ------------------------------------------------------------------------
	    // Learning:
	    // For each active bit in the activationPattern, store the classification
	    // info. If the bucketIdx is None, we can't learn. This can happen when the
	    // field is missing in a specific record.
		if(learn && classification.get("bucketIdx") != null) {
			// Get classification info
			int bucketIdx = (int)classification.get("bucketIdx");
			Object actValue = classification.get("actValue");
			
			// Update maxBucketIndex
			maxBucketIdx = (int) Math.max(maxBucketIdx, bucketIdx);
			
			// Update rolling average of actual values if it's a scalar. If it's
		    // not, it must be a category, in which case each bucket only ever
		    // sees one category so we don't need a running average.
			while(maxBucketIdx > actualValues.size() - 1) {
				actualValues.add(null);
			}
			if(actualValues.get(bucketIdx) == null) {
				actualValues.set(bucketIdx, (T)actValue);
			}else{
				if(Number.class.isAssignableFrom(actValue.getClass())) {
					Double val = ((1.0 - actValueAlpha) * ((Number)actualValues.get(bucketIdx)).doubleValue() + 
						actValueAlpha * ((Number)actValue).doubleValue());
					actualValues.set(bucketIdx, (T)val);
				}else{
					actualValues.set(bucketIdx, (T)actValue);
				}
			}
			
			// Train each pattern that we have in our history that aligns with the
		    // steps we have in steps
			int nSteps = -1;
			int iteration = 0;
			int[] learnPatternNZ = null;
			for(int n : steps.toArray()) {
				nSteps = n;
				// Do we have the pattern that should be assigned to this classification
		        // in our pattern history? If not, skip it
				boolean found = false;
				for(Tuple t : patternNZHistory) {
					iteration = (int)t.get(0);
					learnPatternNZ = (int[]) t.get(1);
					if(iteration == learnIteration - nSteps) {
						found = true;
						break;
					}
					iteration++;
				}
				if(!found) continue;
				
				// Store classification info for each active bit from the pattern
		        // that we got nSteps time steps ago.
				for(int bit : learnPatternNZ) {
					// Get the history structure for this bit and step
					Tuple key = new Tuple(bit, nSteps);
					BitHistory history = activeBitHistory.get(key);
					if(history == null) {
						activeBitHistory.put(key, history = new BitHistory(this, bit, nSteps));
					}
					history.store(learnIteration, bucketIdx);
				}
			}
		}
		
		if(infer && verbosity >= 1) {
			System.out.println(" inference: combined bucket likelihoods:");
			System.out.println("   actual bucket values: " + Arrays.toString((T[])retVal.getActualValues()));
			
			for(int key : retVal.stepSet()) {
				if(retVal.getActualValue(key) == null) continue;
				
				Object[] actual = new Object[] { (T)retVal.getActualValue(key) };
				System.out.println(String.format("  %d steps: ", key, pFormatArray(actual)));
				int bestBucketIdx = retVal.getMostProbableBucketIndex(key);
				System.out.println(String.format("   most likely bucket idx: %d, value: %s ", bestBucketIdx, 
					retVal.getActualValue(bestBucketIdx)));
				
			}
		}
		
		return retVal;
	}
	
	/**
	 * Return a string with pretty-print of an array using the given format
  	 * for each element
  	 * 
	 * @param arr
	 * @return
	 */
	private  String pFormatArray(T[] arr) {
		if(arr == null) return "";
		
		StringBuilder sb = new StringBuilder("[ ");
		for(T t : arr) {
			sb.append(String.format("%.2s", t));
		}
		sb.append(" ]");
		return sb.toString();
	}
	
	public String serialize() {
		String json = null;
		ObjectMapper mapper = new ObjectMapper();
		try {
			json = mapper.writeValueAsString(this);
		}catch(Exception e) {
			e.printStackTrace();
		}
		
		return json;
	}
	
	public static CLAClassifier deSerialize(String jsonStrategy) {
		ObjectMapper om = new ObjectMapper();
		CLAClassifier c = null;
		try {
			Object o = om.readValue(jsonStrategy, CLAClassifier.class);
			c = (CLAClassifier)o;
		} catch (Exception e) {
			e.printStackTrace();
		}
		
		return c;
	}
}