
org.numenta.nupic.util.FastConnectionsMatrix Maven / Gradle / Ivy
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.util;
import org.numenta.nupic.model.Connections;
import org.numenta.nupic.model.Persistable;
import gnu.trove.set.hash.TIntHashSet;
/**
* Fast implementation of {@link SparseBinaryMatrix} for use as ConnectedMatrix in
* {@link Connections}
*
* @author Jose Luis Martin
*/
public class FastConnectionsMatrix extends AbstractSparseBinaryMatrix implements Persistable {
/** keep it simple */
private static final long serialVersionUID = 1L;
private TIntHashSet[] columns;
/**
* @param dimensions
*/
public FastConnectionsMatrix(int[] dimensions) {
this(dimensions, false);
}
/**
* @param dimensions
* @param useColumnMajorOrdering
*/
public FastConnectionsMatrix(int[] dimensions, boolean useColumnMajorOrdering) {
super(dimensions, useColumnMajorOrdering);
this.columns = new TIntHashSet[dimensions[0]];
}
@Override
public Object getSlice(int... coordinates) {
if (coordinates.length > this.numDimensions - 1)
sliceError(coordinates);
int[] slice = new int[this.dimensions[1]];
for (int i = 0; i < this.dimensions[1]; i++)
slice[i] = this.columns[coordinates[0]].contains(i) ? 1 : 0;
return slice;
}
@Override
public void rightVecSumAtNZ(int[] inputVector, int[] results) {
for (int i = 0; i < dimensions[0]; i++) {
for (int index : getColumnInput(i).toArray()) {
if (inputVector[index] != 0)
results[i] += 1;
}
}
}
@Override
public void rightVecSumAtNZ(int[] inputVector, int[] results, double stimulusThreshold) {
for (int i = 0; i < dimensions[0]; i++) {
int[] columnIndexes = getColumnInput(i).toArray();
for (int j = 0;j < columnIndexes.length;j++) {
if (inputVector[columnIndexes[j]] != 0) {
results[i] += 1;
}
if(j == columnIndexes.length - 1 && results[i] < stimulusThreshold) {
results[i] = 0;
}
}
}
}
@Override
public FastConnectionsMatrix set(int index, Object value) {
set(index, ((Integer)value).intValue());
return this;
}
@Override
public AbstractSparseBinaryMatrix set(int value, int... coordinates) {
TIntHashSet input = getColumnInput(coordinates[0]);
if (value == 0) {
input.remove(coordinates[1]);
}
else {
input.add(coordinates[1]);
}
return this;
}
@Override
public Integer get(int index) {
int[] coordinates = computeCoordinates(index);
return getColumnInput(coordinates[0]).contains(coordinates[1]) ? 1 : 0;
}
/**
* @param i
*/
private TIntHashSet getColumnInput(int i) {
if (this.columns[i] == null)
this.columns[i] = new TIntHashSet();
return this.columns[i];
}
@Override
public void clearStatistics(int row) {
getColumnInput(row).clear();
}
@Override
public int getTrueCount(int index) {
return getColumnInput(index).size();
}
@Override
public int[] getTrueCounts() {
int[] trueCounts = new int[this.dimensions[0]];
for (int i = 0; i < this.dimensions[0]; i++)
trueCounts[i] = getTrueCount(i);
return trueCounts;
}
@Override
public AbstractSparseBinaryMatrix setForTest(int index, int value) {
return set(index, value);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy