
org.numenta.nupic.util.LowMemorySparseBinaryMatrix Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of htm.java Show documentation
Show all versions of htm.java Show documentation
The Java version of Numenta's HTM technology
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.util;
import java.lang.reflect.Array;
import java.util.Arrays;
import org.numenta.nupic.model.Persistable;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
/**
* Low Memory implementation of {@link SparseBinaryMatrix} without
* a backing array.
*
* @author Jose Luis Martin
*/
public class LowMemorySparseBinaryMatrix extends AbstractSparseBinaryMatrix implements Persistable {
/** keep it simple */
private static final long serialVersionUID = 1L;
private TIntSet sparseSet = new TIntHashSet();
public LowMemorySparseBinaryMatrix(int[] dimensions) {
this(dimensions, false);
}
public LowMemorySparseBinaryMatrix(int[] dimensions, boolean useColumnMajorOrdering) {
super(dimensions, useColumnMajorOrdering);
}
@Override
public Object getSlice(int... coordinates) {
int[] dimensions = getDimensions();
// check for valid coordinates
if (coordinates.length >= dimensions.length)
sliceError(coordinates);
int sliceDimensionsLength = dimensions.length - coordinates.length;
int[] sliceDimensions = (int[]) Array.newInstance(int.class, sliceDimensionsLength);
for (int i = coordinates.length ; i < dimensions.length; i++)
sliceDimensions[i - coordinates.length] = dimensions[i];
int[] elementCoordinates = Arrays.copyOf(coordinates, coordinates.length + 1);
Object slice = Array.newInstance(int.class, sliceDimensions);
if (coordinates.length + 1 == dimensions.length) {
// last slice
for (int i = 0; i < dimensions[coordinates.length]; i++) {
elementCoordinates[coordinates.length] = i;
Array.set(slice, i, get(elementCoordinates));
}
}
else {
for (int i = 0; i < dimensions[sliceDimensionsLength]; i++) {
elementCoordinates[coordinates.length] = i;
Array.set(slice, i, getSlice(elementCoordinates));
}
}
return slice;
}
@Override
public void rightVecSumAtNZ(int[] inputVector, int[] results) {
if (this.dimensions.length > 1) {
for (int value : getSparseIndices()) {
int[] coordinates = computeCoordinates(value);
if (inputVector[coordinates[1]] != 0)
results[coordinates[0]] += 1;
}
}
else {
for(int i = 0; i < this.dimensions[0]; i++) {
results[0] += (inputVector[i] * (int) get(i));
}
for (int i = 0; i < this.dimensions[0]; i++) {
results[i] = results[0];
}
}
}
@Override
public void rightVecSumAtNZ(int[] inputVector, int[] results, double stimulusThreshold) {
if (this.dimensions.length > 1) {
int[] values = getSparseIndices();
for (int i = 0;i < values.length;i++) {
int[] coordinates = computeCoordinates(values[i]);
if(inputVector[coordinates[1]] != 0)
results[coordinates[0]] += 1;
if(i == values.length - 1 && results[coordinates[0]] < stimulusThreshold) {
results[coordinates[0]] = 0;
}
}
}
else {
for(int i = 0; i < this.dimensions[0]; i++) {
results[0] += (inputVector[i] * (int) get(i));
}
for (int i = 0; i < this.dimensions[0]; i++) {
results[i] = results[0];
if(i == this.dimensions[0] - 1 && results[i] < stimulusThreshold) {
results[i] = 0;
}
}
}
}
@Override
public LowMemorySparseBinaryMatrix set(int value, int... coordinates) {
int index = computeIndex(coordinates);
if (value == 1) {
this.sparseSet.add(index);
}
else {
this.sparseSet.remove(index);
}
updateTrueCounts(coordinates);
return this;
}
@Override
public LowMemorySparseBinaryMatrix setForTest(int index, int value) {
if (value == 1) {
this.sparseSet.add(index);
}
else {
this.sparseSet.remove(index);
}
return this;
}
/**
* Update the true counts for a coordinates.
* @param coordinates
*/
private void updateTrueCounts(int... coordinates) {
Object slice = getSlice(coordinates[0]);
int sum = ArrayUtils.aggregateArray(slice);
setTrueCount(coordinates[0],sum);
}
@Override
public LowMemorySparseBinaryMatrix set(int index, Object value) {
super.set(index, ((Integer) value).intValue());
return this;
}
@Override
public Integer get(int index) {
return this.sparseSet.contains(index) ? 1 : 0;
}
/* (non-Javadoc)
* @see java.lang.Object#hashCode()
*/
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((sparseSet == null) ? 0 : sparseSet.hashCode());
return result;
}
/* (non-Javadoc)
* @see java.lang.Object#equals(java.lang.Object)
*/
@Override
public boolean equals(Object obj) {
if(this == obj)
return true;
if(obj == null)
return false;
if(getClass() != obj.getClass())
return false;
LowMemorySparseBinaryMatrix other = (LowMemorySparseBinaryMatrix)obj;
if(sparseSet == null) {
if(other.sparseSet != null)
return false;
} else if(!sparseSet.equals(other.sparseSet))
return false;
return true;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy