All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.numenta.nupic.util.FastConnectionsMatrix Maven / Gradle / Ivy

The newest version!
/* ---------------------------------------------------------------------
 * Numenta Platform for Intelligent Computing (NuPIC)
 * Copyright (C) 2014, Numenta, Inc.  Unless you have an agreement
 * with Numenta, Inc., for a separate license for this software code, the
 * following terms and conditions apply:
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see http://www.gnu.org/licenses.
 *
 * http://numenta.org/licenses/
 * ---------------------------------------------------------------------
 */

package org.numenta.nupic.util;

import org.numenta.nupic.model.Connections;
import org.numenta.nupic.model.Persistable;

import gnu.trove.set.hash.TIntHashSet;

/**
 * Fast alternative to {@link SparseBinaryMatrix} for use as the connected matrix in
 * {@link Connections}.
 * <p>
 * Instead of a dense backing array, each row (indexed by dimension 0) keeps a
 * {@link TIntHashSet} of the dimension-1 indices whose value is 1. Rows are allocated
 * lazily on first write, so a {@code null} row is treated as an all-zero row.
 * <p>
 * NOTE(review): only {@code dimensions[0]} and {@code dimensions[1]} are used here, so
 * this class appears to assume a two-dimensional matrix — confirm with callers.
 * 
 * @author Jose Luis Martin
 */
public class FastConnectionsMatrix extends AbstractSparseBinaryMatrix implements Persistable {
    /** keep it simple */
    private static final long serialVersionUID = 1L;

    /** Returned for rows that were never written; never mutated. */
    private static final int[] NO_INDEXES = new int[0];

    /**
     * One set per row (dimension 0) holding the dimension-1 indices whose value is 1.
     * An entry stays {@code null} until its row is first set to a non-zero value.
     */
    private TIntHashSet[] columns;

    /**
     * Creates an all-zero matrix using the default ordering flag ({@code false}).
     *
     * @param dimensions the matrix dimensions; {@code dimensions[0]} is the row count
     */
    public FastConnectionsMatrix(int[] dimensions) {
        this(dimensions, false);
    }

    /**
     * Creates an all-zero matrix.
     *
     * @param dimensions             the matrix dimensions; {@code dimensions[0]} is the row count
     * @param useColumnMajorOrdering forwarded unchanged to the superclass constructor
     */
    public FastConnectionsMatrix(int[] dimensions, boolean useColumnMajorOrdering) {
        super(dimensions, useColumnMajorOrdering);
        this.columns = new TIntHashSet[dimensions[0]];
    }

    /**
     * Returns row {@code coordinates[0]} as a new dense {@code int[]} of length
     * {@code dimensions[1]}, holding 1 where the bit is set and 0 elsewhere.
     *
     * @param coordinates the row index; passing more than {@code numDimensions - 1}
     *                    coordinates is reported through {@code sliceError}
     * @return a freshly allocated {@code int[]} copy of the row
     */
    @Override
    public Object getSlice(int... coordinates) {
        if (coordinates.length > this.numDimensions - 1) {
            sliceError(coordinates);
        }

        int[] slice = new int[this.dimensions[1]];
        // A row that was never written is all zeros; reading it must not throw.
        TIntHashSet input = this.columns[coordinates[0]];
        if (input != null) {
            for (int i = 0; i < slice.length; i++) {
                slice[i] = input.contains(i) ? 1 : 0;
            }
        }

        return slice;
    }

    /**
     * For every row {@code i}, adds to {@code results[i]} the number of set positions
     * {@code j} in that row for which {@code inputVector[j] != 0}.
     *
     * @param inputVector the vector to test against each row's set positions
     * @param results     accumulator, one entry per row; counts are added to existing values
     */
    @Override
    public void rightVecSumAtNZ(int[] inputVector, int[] results) {
        for (int i = 0; i < dimensions[0]; i++) {
            for (int index : connectedIndexes(i)) {
                if (inputVector[index] != 0) {
                    results[i] += 1;
                }
            }
        }
    }

    /**
     * Same accumulation as {@link #rightVecSumAtNZ(int[], int[])}, after which any row
     * whose total is below {@code stimulusThreshold} is reset to 0.
     * <p>
     * The threshold is applied to every row, including rows with no set positions.
     *
     * @param inputVector       the vector to test against each row's set positions
     * @param results           accumulator, one entry per row; counts are added to existing values
     * @param stimulusThreshold totals strictly below this value are zeroed
     */
    @Override
    public void rightVecSumAtNZ(int[] inputVector, int[] results, double stimulusThreshold) {
        for (int i = 0; i < dimensions[0]; i++) {
            int total = results[i];
            for (int index : connectedIndexes(i)) {
                if (inputVector[index] != 0) {
                    total++;
                }
            }
            results[i] = total < stimulusThreshold ? 0 : total;
        }
    }

    /**
     * Sets the cell at flat {@code index} from a boxed value.
     *
     * @param index flat index of the cell
     * @param value must be an {@link Integer}; 0 clears the cell, anything else sets it
     * @return this matrix
     */
    @Override
    public FastConnectionsMatrix set(int index, Object value) {
        set(index, ((Integer)value).intValue());
        return this;
    }

    /**
     * Sets or clears the cell at {@code (coordinates[0], coordinates[1])}.
     *
     * @param value       0 clears the cell, any other value sets it
     * @param coordinates row index followed by the dimension-1 index
     * @return this matrix
     */
    @Override
    public AbstractSparseBinaryMatrix set(int value, int... coordinates) {
        if (value == 0) {
            // Clearing a never-written row is a no-op; don't allocate a set for it.
            TIntHashSet input = this.columns[coordinates[0]];
            if (input != null) {
                input.remove(coordinates[1]);
            }
        }
        else {
            getColumnInput(coordinates[0]).add(coordinates[1]);
        }

        return this;
    }

    /**
     * Returns the value (0 or 1) of the cell at flat {@code index}.
     *
     * @param index flat index, converted with {@code computeCoordinates}
     * @return 1 if the cell is set, otherwise 0
     */
    @Override
    public Integer get(int index) {
        int[] coordinates = computeCoordinates(index);
        TIntHashSet input = this.columns[coordinates[0]];
        return input != null && input.contains(coordinates[1]) ? 1 : 0;
    }

    /**
     * Returns the set for row {@code i}, allocating it on first use. Intended for
     * write paths only.
     *
     * @param i row index
     * @return the (possibly newly created) set of set positions in row {@code i}
     */
    private TIntHashSet getColumnInput(int i) {
        if (this.columns[i] == null) {
            this.columns[i] = new TIntHashSet();
        }

        return this.columns[i];
    }

    /**
     * Returns the set positions of row {@code row} without allocating a set for
     * rows that were never written.
     *
     * @param row row index
     * @return a snapshot array of the row's set positions; empty if the row is unset
     */
    private int[] connectedIndexes(int row) {
        TIntHashSet input = this.columns[row];
        return input == null ? NO_INDEXES : input.toArray();
    }

    /**
     * Clears every bit in the given row.
     *
     * @param row row index
     */
    @Override
    public void clearStatistics(int row) {
        TIntHashSet input = this.columns[row];
        if (input != null) {
            input.clear();
        }
    }

    /**
     * Returns the number of set bits in row {@code index}.
     *
     * @param index row index
     * @return count of set positions in that row (0 for a never-written row)
     */
    @Override
    public int getTrueCount(int index) {
        TIntHashSet input = this.columns[index];
        return input == null ? 0 : input.size();
    }

    /**
     * Returns the per-row set-bit counts.
     *
     * @return a new array of length {@code dimensions[0]}
     */
    @Override
    public int[] getTrueCounts() {
        int[] trueCounts = new int[this.dimensions[0]];
        for (int i = 0; i < this.dimensions[0]; i++) {
            trueCounts[i] = getTrueCount(i);
        }

        return trueCounts;
    }

    /**
     * Test hook; delegates to {@code set(index, value)}.
     *
     * @param index flat index of the cell
     * @param value 0 clears the cell, any other value sets it
     * @return the result of the delegated {@code set} call
     */
    @Override
    public AbstractSparseBinaryMatrix setForTest(int index, int value) {
        return set(index, value);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy