All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.numenta.nupic.research.TemporalMemory Maven / Gradle / Ivy

There is a newer version: 0.6.13
Show newest version
/* ---------------------------------------------------------------------
 * Numenta Platform for Intelligent Computing (NuPIC)
 * Copyright (C) 2014, Numenta, Inc.  Unless you have an agreement
 * with Numenta, Inc., for a separate license for this software code, the
 * following terms and conditions apply:
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see http://www.gnu.org/licenses.
 *
 * http://numenta.org/licenses/
 * ---------------------------------------------------------------------
 */

package org.numenta.nupic.research;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.numenta.nupic.Connections;
import org.numenta.nupic.model.Cell;
import org.numenta.nupic.model.Column;
import org.numenta.nupic.model.DistalDendrite;
import org.numenta.nupic.model.Synapse;
import org.numenta.nupic.util.SparseObjectMatrix;

/**
 * Temporal Memory implementation in Java
 * 
 * @author Chetan Surpur
 * @author David Ray
 */
public class TemporalMemory {
    
    /**
     * Constructs a new {@code TemporalMemory}
     */
    public TemporalMemory() {}
    
    /**
     * Uses the specified {@link Connections} object to Build the structural 
     * anatomy needed by this {@code TemporalMemory} to implement its algorithms.
     * 
     * The connections object holds the {@link Column} and {@link Cell} infrastructure,
     * and is used by both the {@link SpatialPooler} and {@link TemporalMemory}. Either of
     * these can be used separately, and therefore this Connections object may have its
     * Columns and Cells initialized by either the init method of the SpatialPooler or the
     * init method of the TemporalMemory. We check for this so that complete initialization
     * of both Columns and Cells occurs, without either being redundant (initialized more than
     * once). However, {@link Cell}s only get created when initializing a TemporalMemory, because
     * they are not used by the SpatialPooler.
     * 
     * @param	c		{@link Connections} object
     */
    public void init(Connections c) {
    	SparseObjectMatrix matrix = c.getMemory() == null ?
    		new SparseObjectMatrix(c.getColumnDimensions()) :
    			c.getMemory();
    	c.setMemory(matrix);
    	
    	int numColumns = matrix.getMaxIndex() + 1;
    	int cellsPerColumn = c.getCellsPerColumn();
        Cell[] cells = new Cell[numColumns * cellsPerColumn];
        
        //Used as flag to determine if Column objects have been created.
        Column colZero = matrix.getObject(0);
        for(int i = 0;i < numColumns;i++) {
            Column column = colZero == null ? 
            	new Column(cellsPerColumn, i) : matrix.getObject(i);
            for(int j = 0;j < cellsPerColumn;j++) {
                cells[i * cellsPerColumn + j] = column.getCell(j);
            }
            //If columns have not been previously configured
            if(colZero == null) matrix.set(i, column);
        }
        //Only the TemporalMemory initializes cells so no need to test 
        c.setCells(cells);
    }
    
    /////////////////////////// CORE FUNCTIONS /////////////////////////////
    
    /**
     * Feeds input record through TM, performing inferencing and learning
     * 
     * @param connections		the connection memory
     * @param activeColumns     direct proximal dendrite input
     * @param learn             learning mode flag
     * @return                  {@link ComputeCycle} container for one cycle of inference values.
     */
    public ComputeCycle compute(Connections connections, int[] activeColumns, boolean learn) {
        ComputeCycle result = computeFn(connections, connections.getColumnSet(activeColumns), new LinkedHashSet(connections.getPredictiveCells()), 
            new LinkedHashSet(connections.getActiveSegments()), new LinkedHashMap>(connections.getActiveSynapsesForSegment()), 
                new LinkedHashSet(connections.getWinnerCells()), learn);
        
        connections.setActiveCells(result.activeCells());
        connections.setWinnerCells(result.winnerCells());
        connections.setPredictiveCells(result.predictiveCells());
        connections.setSuccessfullyPredictedColumns(result.successfullyPredictedColumns());
        connections.setActiveSegments(result.activeSegments());
        connections.setLearningSegments(result.learningSegments());
        connections.setActiveSynapsesForSegment(result.activeSynapsesForSegment());
        
        return result; 
    }
    
    /**
     * Functional version of {@link #compute(int[], boolean)}. 
     * This method is stateless and concurrency safe.
     * 
     * @param c                             {@link Connections} object containing state of memory members
     * @param activeColumns                 proximal dendrite input
     * @param prevPredictiveCells           cells predicting in t-1
     * @param prevActiveSegments            active segments in t-1
     * @param prevActiveSynapsesForSegment  {@link Synapse}s active in t-1
     * @param prevWinnerCells   `           previous winners
     * @param learn                         whether mode is "learning" mode
     * @return
     */
    public ComputeCycle computeFn(Connections c, Set activeColumns, Set prevPredictiveCells, Set prevActiveSegments,
        Map> prevActiveSynapsesForSegment, Set prevWinnerCells, boolean learn) {
        
        ComputeCycle cycle = new ComputeCycle();
        
        activateCorrectlyPredictiveCells(cycle, prevPredictiveCells, activeColumns);
        
        burstColumns(cycle, c, activeColumns, cycle.successfullyPredictedColumns, prevActiveSynapsesForSegment);
        
        if(learn) {
            learnOnSegments(c, prevActiveSegments, cycle.learningSegments, prevActiveSynapsesForSegment, cycle.winnerCells, prevWinnerCells);
        }
        
        cycle.activeSynapsesForSegment = computeActiveSynapses(c, cycle.activeCells);
        
        computePredictiveCells(c, cycle, cycle.activeSynapsesForSegment);
        
        return cycle;
    }

    /**
     * Phase 1: Activate the correctly predictive cells
     * 
     * Pseudocode:
     *
     * - for each previous predictive cell
     *   - if in active column
     *     - mark it as active
     *     - mark it as winner cell
     *     - mark column as predicted
     *     
     * @param c                     ComputeCycle interim values container
     * @param prevPredictiveCells   predictive {@link Cell}s predictive cells in t-1
     * @param activeColumns         active columns in t
     */
    public void activateCorrectlyPredictiveCells(ComputeCycle c, Set prevPredictiveCells, Set activeColumns) {
        for(Cell cell : prevPredictiveCells) {
            Column column = cell.getParentColumn();
            if(activeColumns.contains(column)) {
                c.activeCells.add(cell);
                c.winnerCells.add(cell);
                c.successfullyPredictedColumns.add(column);
            }
        }
    }
    
    /**
     * Phase 2: Burst unpredicted columns.
     * 
     * Pseudocode:
     *
     * - for each unpredicted active column
     *   - mark all cells as active
     *   - mark the best matching cell as winner cell
     *     - (learning)
     *       - if it has no matching segment
     *         - (optimization) if there are previous winner cells
     *           - add a segment to it
     *       - mark the segment as learning
     * 
     * @param cycle                         ComputeCycle interim values container
     * @param c                             Connections temporal memory state
     * @param activeColumns                 active columns in t
     * @param predictedColumns              predicted columns in t
     * @param prevActiveSynapsesForSegment  LinkedHashMap of previously active segments which
     *                                      have had synapses marked as active in t-1     
     */
    public void burstColumns(ComputeCycle cycle, Connections c, Set activeColumns, Set predictedColumns, 
        Map> prevActiveSynapsesForSegment) {
        
    	activeColumns.removeAll(predictedColumns);
    	for(Column column : activeColumns) {
            List cells = column.getCells();
            cycle.activeCells.addAll(cells);
            
            Object[] bestSegmentAndCell = getBestMatchingCell(c, column, prevActiveSynapsesForSegment);
            DistalDendrite bestSegment = (DistalDendrite)bestSegmentAndCell[0];
            Cell bestCell = (Cell)bestSegmentAndCell[1];
            if(bestCell != null) {
                cycle.winnerCells.add(bestCell);
            }
            
            int segmentCounter = c.getSegmentCount();
            if(bestSegment == null) {
                bestSegment = bestCell.createSegment(c, segmentCounter);
                c.setSegmentCount(segmentCounter + 1);
            }
            
            cycle.learningSegments.add(bestSegment);
        }
    }
    
    /**
     * Phase 3: Perform learning by adapting segments.
     * 
     * Pseudocode:
     *
     * - (learning) for each previously active or learning segment
     *   - if learning segment or from winner cell
     *     - strengthen active synapses
     *     - weaken inactive synapses
     *   - if learning segment
     *     - add some synapses to the segment
     *     - sub sample from previous winner cells
     * 
* * @param c the Connections state of the temporal memory * @param prevActiveSegments the Set of segments active in the previous cycle. * @param learningSegments the Set of segments marked as learning {@link #burstColumns(ComputeCycle, Connections, Set, Set, Map)} * @param prevActiveSynapseSegments the map of segments which were previously active to their associated {@link Synapse}s. * @param winnerCells the Set of all winning cells ({@link Cell}s with the most active synapses) * @param prevWinnerCells the Set of cells which were winners during the last compute cycle */ public void learnOnSegments(Connections c, Set prevActiveSegments, Set learningSegments, Map> prevActiveSynapseSegments, Set winnerCells, Set prevWinnerCells) { double permanenceIncrement = c.getPermanenceIncrement(); double permanenceDecrement = c.getPermanenceDecrement(); List prevAndLearning = new ArrayList(prevActiveSegments); prevAndLearning.addAll(learningSegments); for(DistalDendrite dd : prevAndLearning) { boolean isLearningSegment = learningSegments.contains(dd); boolean isFromWinnerCell = winnerCells.contains(dd.getParentCell()); Set activeSynapses = dd.getConnectedActiveSynapses(prevActiveSynapseSegments, 0); if(isLearningSegment || isFromWinnerCell) { dd.adaptSegment(c, activeSynapses, permanenceIncrement, permanenceDecrement); } int synapseCounter = c.getSynapseCount(); int n = c.getMaxNewSynapseCount() - activeSynapses.size(); if(isLearningSegment && n > 0) { Set learnCells = dd.pickCellsToLearnOn(c, n, prevWinnerCells, c.getRandom()); for(Cell sourceCell : learnCells) { dd.createSynapse(c, sourceCell, c.getInitialPermanence(), synapseCounter); synapseCounter += 1; } c.setSynapseCount(synapseCounter); } } } /** * Phase 4: Compute predictive cells due to lateral input on distal dendrites. * * Pseudocode: * * - for each distal dendrite segment with activity >= activationThreshold * - mark the segment as active * - mark the cell as predictive * * @param c the Connections state of the temporal memory * @param cycle the state during the current compute cycle * @param activeSegments */ public void computePredictiveCells(Connections c, ComputeCycle cycle, Map> activeDendrites) { for(DistalDendrite dd : activeDendrites.keySet()) { Set connectedActive = dd.getConnectedActiveSynapses(activeDendrites, c.getConnectedPermanence()); if(connectedActive.size() >= c.getActivationThreshold()) { cycle.activeSegments.add(dd); cycle.predictiveCells.add(dd.getParentCell()); } } } /** * Forward propagates activity from active cells to the synapses that touch * them, to determine which synapses are active. * * @param c the connections state of the temporal memory * @param cellsActive * @return */ public Map> computeActiveSynapses(Connections c, Set cellsActive) { Map> activesSynapses = new LinkedHashMap>(); for(Cell cell : cellsActive) { for(Synapse s : cell.getReceptorSynapses(c)) { Set set = null; if((set = activesSynapses.get(s.getSegment())) == null) { activesSynapses.put((DistalDendrite)s.getSegment(), set = new LinkedHashSet()); } set.add(s); } } return activesSynapses; } /** * Called to start the input of a new sequence. * * @param connections the Connections state of the temporal memory */ public void reset(Connections connections) { connections.getActiveCells().clear(); connections.getPredictiveCells().clear(); connections.getActiveSegments().clear(); connections.getActiveSynapsesForSegment().clear(); connections.getWinnerCells().clear(); } /////////////////////////// HELPER FUNCTIONS /////////////////////////// /** * Gets the cell with the best matching segment * (see `TM.getBestMatchingSegment`) that has the largest number of active * synapses of all best matching segments. * * @param c encapsulated memory and state * @param column {@link Column} within which to search for best cell * @param prevActiveSynapsesForSegment a {@link DistalDendrite}'s previously active {@link Synapse}s * @return an object array whose first index contains a segment, and the second contains a cell */ public Object[] getBestMatchingCell(Connections c, Column column, Map> prevActiveSynapsesForSegment) { Object[] retVal = new Object[2]; Cell bestCell = null; DistalDendrite bestSegment = null; int maxSynapses = 0; for(Cell cell : column.getCells()) { DistalDendrite dd = getBestMatchingSegment(c, cell, prevActiveSynapsesForSegment); if(dd != null) { Set connectedActiveSynapses = dd.getConnectedActiveSynapses(prevActiveSynapsesForSegment, 0); if(connectedActiveSynapses.size() > maxSynapses) { maxSynapses = connectedActiveSynapses.size(); bestCell = cell; bestSegment = dd; } } } if(bestCell == null) { bestCell = column.getLeastUsedCell(c, c.getRandom()); } retVal[0] = bestSegment; retVal[1] = bestCell; return retVal; } /** * Gets the segment on a cell with the largest number of activate synapses, * including all synapses with non-zero permanences. * * @param c encapsulated memory and state * @param column {@link Column} within which to search for best cell * @param activeSynapseSegments a {@link DistalDendrite}'s active {@link Synapse}s * @return the best segment */ public DistalDendrite getBestMatchingSegment(Connections c, Cell cell, Map> activeSynapseSegments) { int maxSynapses = c.getMinThreshold(); DistalDendrite bestSegment = null; for(DistalDendrite dd : cell.getSegments(c)) { Set activeSyns = dd.getConnectedActiveSynapses(activeSynapseSegments, 0); if(activeSyns.size() >= maxSynapses) { maxSynapses = activeSyns.size(); bestSegment = dd; } } return bestSegment; } /** * Returns the column index given the cells per column and * the cell index passed in. * * @param c {@link Connections} memory * @param cellIndex the index where the requested cell resides * @return */ protected int columnForCell(Connections c, int cellIndex) { return cellIndex / c.getCellsPerColumn(); } /** * Returns the cell at the specified index. * @param index * @return */ public Cell getCell(Connections c, int index) { return c.getCells()[index]; } /** * Returns a {@link LinkedHashSet} of {@link Cell}s from a * sorted array of cell indexes. * * @param`c the {@link Connections} object * @param cellIndexes indexes of the {@link Cell}s to return * @return */ public LinkedHashSet getCells(Connections c, int[] cellIndexes) { LinkedHashSet cellSet = new LinkedHashSet(); for(int cell : cellIndexes) { cellSet.add(getCell(c, cell)); } return cellSet; } /** * Returns a {@link LinkedHashSet} of {@link Column}s from a * sorted array of Column indexes. * * @param cellIndexes indexes of the {@link Column}s to return * @return */ public LinkedHashSet getColumns(Connections c, int[] columnIndexes) { return c.getColumnSet(columnIndexes); } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy