
org.numenta.nupic.util.LowMemorySparseBinaryMatrix Maven / Gradle / Ivy
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.util;
import java.lang.reflect.Array;
import java.util.Arrays;
import org.numenta.nupic.model.Persistable;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
/**
 * Low memory implementation of {@link SparseBinaryMatrix} without
 * a backing array. Only the flat indices of cells whose value is
 * {@code 1} are stored; every other cell reads as {@code 0}.
 *
 * @author Jose Luis Martin
 */
public class LowMemorySparseBinaryMatrix extends AbstractSparseBinaryMatrix implements Persistable {
    /** keep it simple */
    private static final long serialVersionUID = 1L;

    /** Flat indices (as produced by {@code computeIndex}) of every cell set to 1. */
    private TIntSet sparseSet = new TIntHashSet();

    /**
     * Creates an all-zero matrix; equivalent to
     * {@code new LowMemorySparseBinaryMatrix(dimensions, false)}.
     *
     * @param dimensions the size of each axis
     */
    public LowMemorySparseBinaryMatrix(int[] dimensions) {
        this(dimensions, false);
    }

    /**
     * Creates an all-zero matrix.
     *
     * @param dimensions             the size of each axis
     * @param useColumnMajorOrdering forwarded to the superclass; presumably selects how
     *                               coordinates map to flat indices (TODO confirm)
     */
    public LowMemorySparseBinaryMatrix(int[] dimensions, boolean useColumnMajorOrdering) {
        super(dimensions, useColumnMajorOrdering);
    }

    /**
     * Materializes the sub-matrix addressed by {@code coordinates} as a newly
     * allocated {@code int} array of rank
     * {@code getDimensions().length - coordinates.length}.
     *
     * @param coordinates leading coordinates to fix; must leave at least one axis free,
     *                    otherwise {@code sliceError} is invoked
     * @return a new {@code int[]}, {@code int[][]}, ... holding the slice's cell values
     */
    @Override
    public Object getSlice(int... coordinates) {
        int[] dimensions = getDimensions();
        // A slice needs at least one unfixed axis.
        if (coordinates.length >= dimensions.length) {
            sliceError(coordinates);
        }

        int depth = coordinates.length;
        int[] sliceDimensions = Arrays.copyOfRange(dimensions, depth, dimensions.length);
        Object slice = Array.newInstance(int.class, sliceDimensions);
        // Child coordinates: the fixed prefix plus one slot for the first free axis.
        int[] elementCoordinates = Arrays.copyOf(coordinates, depth + 1);
        boolean lastAxis = depth + 1 == dimensions.length;

        // Walk the first free axis, dimensions[depth] (== sliceDimensions[0]), which is
        // the length of the outermost array of the slice.
        for (int i = 0; i < dimensions[depth]; i++) {
            elementCoordinates[depth] = i;
            if (lastAxis) {
                Array.set(slice, i, get(elementCoordinates));
            }
            else {
                Array.set(slice, i, getSlice(elementCoordinates));
            }
        }
        return slice;
    }

    /**
     * For each row {@code r}, adds to {@code results[r]} the sum of
     * {@code inputVector[c]} over the columns {@code c} where row {@code r} holds a 1.
     * For a one-dimensional matrix the single dot product is accumulated into
     * {@code results[0]} and then copied into {@code results[1..dimensions[0])}.
     *
     * @param inputVector the vector to sum, indexed by column
     * @param results     accumulator indexed by row; must have at least
     *                    {@code dimensions[0]} entries
     */
    @Override
    public void rightVecSumAtNZ(int[] inputVector, int[] results) {
        accumulateRowSums(inputVector, results);
    }

    /**
     * Same as {@link #rightVecSumAtNZ(int[], int[])}, then zeroes every entry
     * {@code results[r]} (for {@code r < dimensions[0]}) that is below
     * {@code stimulusThreshold}.
     *
     * @param inputVector       the vector to sum, indexed by column
     * @param results           accumulator indexed by row; must have at least
     *                          {@code dimensions[0]} entries
     * @param stimulusThreshold rows whose total is strictly below this become 0
     */
    @Override
    public void rightVecSumAtNZ(int[] inputVector, int[] results, double stimulusThreshold) {
        accumulateRowSums(inputVector, results);
        // Every row is thresholded, including rows that have no stored cells.
        for (int row = 0; row < this.dimensions[0]; row++) {
            if (results[row] < stimulusThreshold) {
                results[row] = 0;
            }
        }
    }

    /**
     * Shared summation used by both {@code rightVecSumAtNZ} overloads (no thresholding).
     */
    private void accumulateRowSums(int[] inputVector, int[] results) {
        if (this.dimensions.length > 1) {
            // Visit only the cells reported by getSparseIndices() (presumably the 1-cells);
            // coordinates[0] is the row and coordinates[1] the column.
            for (int index : getSparseIndices()) {
                int[] coordinates = computeCoordinates(index);
                results[coordinates[0]] += inputVector[coordinates[1]];
            }
        }
        else {
            // One-dimensional matrix: a single dot product, replicated into every slot.
            int length = this.dimensions[0];
            for (int i = 0; i < length; i++) {
                results[0] += inputVector[i] * get(i);
            }
            for (int i = 1; i < length; i++) {
                results[i] = results[0];
            }
        }
    }

    /**
     * Sets the cell at {@code coordinates}: a value of 1 stores it, any other value
     * clears it. Afterwards the true count for {@code coordinates[0]} is recomputed.
     *
     * @param value       1 to set the cell, anything else to clear it
     * @param coordinates the cell's coordinates
     * @return this matrix
     */
    @Override
    public LowMemorySparseBinaryMatrix set(int value, int... coordinates) {
        int index = computeIndex(coordinates);
        if (value == 1) {
            this.sparseSet.add(index);
        }
        else {
            this.sparseSet.remove(index);
        }
        updateTrueCounts(coordinates);
        return this;
    }

    /**
     * Sets the cell at flat {@code index} (1 stores it, anything else clears it)
     * WITHOUT recomputing true counts.
     *
     * @param index flat cell index
     * @param value 1 to set the cell, anything else to clear it
     * @return this matrix
     */
    @Override
    public LowMemorySparseBinaryMatrix setForTest(int index, int value) {
        if (value == 1) {
            this.sparseSet.add(index);
        }
        else {
            this.sparseSet.remove(index);
        }
        return this;
    }

    /**
     * Recomputes the true count for {@code coordinates[0]} by summing the slice
     * at that leading coordinate.
     *
     * @param coordinates coordinates of the cell that just changed
     */
    private void updateTrueCounts(int... coordinates) {
        Object slice = getSlice(coordinates[0]);
        int sum = ArrayUtils.aggregateArray(slice);
        setTrueCount(coordinates[0], sum);
    }

    /**
     * Sets the cell at flat {@code index}; {@code value} must be an {@link Integer}.
     *
     * @param index flat cell index
     * @param value the new value, unboxed and forwarded to the superclass setter
     * @return this matrix
     */
    @Override
    public LowMemorySparseBinaryMatrix set(int index, Object value) {
        super.set(index, ((Integer) value).intValue());
        return this;
    }

    /**
     * Returns the value of the cell at flat {@code index}.
     *
     * @param index flat cell index
     * @return 1 if the index is stored, otherwise 0
     */
    @Override
    public Integer get(int index) {
        return this.sparseSet.contains(index) ? 1 : 0;
    }

    /**
     * Hash over the matrix shape and the set of stored indices, consistent with
     * {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + Arrays.hashCode(getDimensions());
        result = prime * result + ((sparseSet == null) ? 0 : sparseSet.hashCode());
        return result;
    }

    /**
     * Two matrices are equal when they have the same runtime class, the same
     * dimensions and the same set of stored indices.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        LowMemorySparseBinaryMatrix other = (LowMemorySparseBinaryMatrix) obj;
        // Same flat indices mean different cells when the shapes differ.
        if (!Arrays.equals(getDimensions(), other.getDimensions())) {
            return false;
        }
        if (sparseSet == null) {
            return other.sparseSet == null;
        }
        return sparseSet.equals(other.sparseSet);
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy