All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.numenta.nupic.util.LowMemorySparseBinaryMatrix Maven / Gradle / Ivy

There is a newer version: 0.6.13
Show newest version
/* ---------------------------------------------------------------------
 * Numenta Platform for Intelligent Computing (NuPIC)
 * Copyright (C) 2014, Numenta, Inc.  Unless you have an agreement
 * with Numenta, Inc., for a separate license for this software code, the
 * following terms and conditions apply:
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see http://www.gnu.org/licenses.
 *
 * http://numenta.org/licenses/
 * ---------------------------------------------------------------------
 */

package org.numenta.nupic.util;

import java.lang.reflect.Array;
import java.util.Arrays;

import org.numenta.nupic.model.Persistable;

import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;

/**
 * Low Memory implementation of {@link SparseBinaryMatrix} without 
 * a backing array.
 * 
 * @author Jose Luis Martin
 */
public class LowMemorySparseBinaryMatrix extends AbstractSparseBinaryMatrix implements Persistable {
    /** keep it simple */
    private static final long serialVersionUID = 1L;

    /** Flat indices of the cells currently holding 1; any index not present reads as 0. */
    private TIntSet sparseSet = new TIntHashSet();

    /**
     * Creates an empty matrix with the given dimensions. Delegates to
     * {@link #LowMemorySparseBinaryMatrix(int[], boolean)} with
     * {@code useColumnMajorOrdering} set to {@code false}.
     *
     * @param dimensions size of each dimension of the matrix
     */
    public LowMemorySparseBinaryMatrix(int[] dimensions) {
        this(dimensions, false);
    }

    /**
     * Creates an empty matrix with the given dimensions and ordering flag;
     * both arguments are passed straight to the superclass constructor.
     *
     * @param dimensions             size of each dimension of the matrix
     * @param useColumnMajorOrdering ordering flag forwarded to the superclass
     */
    public LowMemorySparseBinaryMatrix(int[] dimensions, boolean useColumnMajorOrdering) {
        super(dimensions, useColumnMajorOrdering);
    }

    /**
     * Materializes the sub-array addressed by a coordinate prefix.
     * <p>
     * Because there is no backing array, a fresh (possibly multi-dimensional)
     * {@code int} array is allocated and filled from the sparse set on every call.
     *
     * @param coordinates leading coordinates selecting the slice; must contain
     *                    fewer entries than the matrix has dimensions, otherwise
     *                    the request is handed to {@code sliceError}
     * @return a newly allocated {@code int} array whose shape is the remaining
     *         (trailing) dimensions of the matrix
     */
    @Override
    public Object getSlice(int... coordinates) {
        int[] dimensions = getDimensions();
        // check for valid coordinates
        if (coordinates.length >= dimensions.length) {
            sliceError(coordinates);
        }

        int depth = coordinates.length;
        // The slice is shaped like the dimensions that follow the given prefix.
        int[] sliceDimensions = Arrays.copyOfRange(dimensions, depth, dimensions.length);
        // One extra trailing slot which is varied over the next axis below.
        int[] elementCoordinates = Arrays.copyOf(coordinates, depth + 1);
        Object slice = Array.newInstance(int.class, sliceDimensions);

        // The axis being iterated is always the one right after the prefix, i.e.
        // dimensions[depth]. (Using dimensions[dimensions.length - depth] here is
        // only correct when both happen to be equal and breaks e.g. for a
        // {2, 3, 4} matrix sliced with a single coordinate.)
        int axisLength = dimensions[depth];
        boolean lastSlice = depth + 1 == dimensions.length;

        for (int i = 0; i < axisLength; i++) {
            elementCoordinates[depth] = i;
            if (lastSlice) {
                // innermost axis: store the individual cell values
                Array.set(slice, i, get(elementCoordinates));
            }
            else {
                // still more than one axis left: recurse to build the nested array
                Array.set(slice, i, getSlice(elementCoordinates));
            }
        }

        return slice;
    }

    /**
     * Adds, for every row, the number of set cells whose column is non-zero in
     * {@code inputVector} to the corresponding entry of {@code results}.
     * <p>
     * For a one-dimensional matrix the dot product of {@code inputVector} and
     * the matrix is added to {@code results[0]} and that value is then copied
     * into the first {@code dimensions[0]} entries of {@code results}.
     *
     * @param inputVector vector to test against the set cells
     * @param results     accumulator updated in place
     */
    @Override
    public void rightVecSumAtNZ(int[] inputVector, int[] results) {
        accumulateRightVecSum(inputVector, results);
    }

    /**
     * Same accumulation as {@link #rightVecSumAtNZ(int[], int[])}, followed by a
     * threshold pass: every row entry of {@code results} that ends up below
     * {@code stimulusThreshold} is reset to 0.
     * <p>
     * The threshold is applied to all rows (indices {@code 0..dimensions[0]-1},
     * limited to the length of {@code results}), not just to the row touched last
     * during accumulation.
     *
     * @param inputVector       vector to test against the set cells
     * @param results           accumulator updated in place
     * @param stimulusThreshold minimum value an entry must reach to be kept
     */
    @Override
    public void rightVecSumAtNZ(int[] inputVector, int[] results, double stimulusThreshold) {
        accumulateRightVecSum(inputVector, results);

        int rows = Math.min(results.length, this.dimensions[0]);
        for (int row = 0; row < rows; row++) {
            if (results[row] < stimulusThreshold) {
                results[row] = 0;
            }
        }
    }

    /**
     * Shared accumulation step of both {@code rightVecSumAtNZ} overloads.
     *
     * @param inputVector vector to test against the set cells
     * @param results     accumulator updated in place
     */
    private void accumulateRightVecSum(int[] inputVector, int[] results) {
        if (this.dimensions.length > 1) {
            // Only visit cells reported by getSparseIndices() instead of scanning rows.
            for (int value : getSparseIndices()) {
                int[] coordinates = computeCoordinates(value);

                // coordinates[0] is used as the row, coordinates[1] as the column
                if (inputVector[coordinates[1]] != 0) {
                    results[coordinates[0]] += 1;
                }
            }
        }
        else {
            for (int i = 0; i < this.dimensions[0]; i++) {
                results[0] += (inputVector[i] * (int) get(i));
            }

            // every entry of a 1-D result carries the same sum
            for (int i = 0; i < this.dimensions[0]; i++) {
                results[i] = results[0];
            }
        }
    }

    /**
     * Sets the cell at the given coordinates and refreshes the true count of
     * the affected row.
     *
     * @param value       1 marks the cell as set; any other value clears it
     * @param coordinates full coordinates of the cell
     * @return this matrix
     */
    @Override
    public LowMemorySparseBinaryMatrix set(int value, int... coordinates) {
        writeSparseBit(computeIndex(coordinates), value);
        updateTrueCounts(coordinates);

        return this;
    }

    /**
     * Sets the cell at a flat index without touching the true counts.
     *
     * @param index flat index of the cell
     * @param value 1 marks the cell as set; any other value clears it
     * @return this matrix
     */
    @Override
    public LowMemorySparseBinaryMatrix setForTest(int index, int value) {
        writeSparseBit(index, value);

        return this;
    }

    /**
     * Records or clears a single flat index in the sparse set.
     *
     * @param index flat index of the cell
     * @param value 1 adds the index; any other value removes it
     */
    private void writeSparseBit(int index, int value) {
        if (value == 1) {
            this.sparseSet.add(index);
        }
        else {
            this.sparseSet.remove(index);
        }
    }

    /**
     * Recomputes the true count of the row given by {@code coordinates[0]} by
     * materializing that row with {@link #getSlice(int...)} and aggregating it.
     *
     * @param coordinates coordinates of the cell that was just changed; only the
     *                    first entry is used
     */
    private void updateTrueCounts(int... coordinates) {
        Object slice = getSlice(coordinates[0]);
        int sum = ArrayUtils.aggregateArray(slice);
        setTrueCount(coordinates[0], sum);
    }

    /**
     * Sets the cell at a flat index from a boxed value by delegating to the
     * superclass {@code set(int, int)}.
     *
     * @param index flat index of the cell
     * @param value must be an {@link Integer}; a {@link ClassCastException} is
     *              raised by the cast otherwise
     * @return this matrix
     */
    @Override
    public LowMemorySparseBinaryMatrix set(int index, Object value) {
        super.set(index, ((Integer) value).intValue());
        return this;
    }

    /**
     * Reads the cell at a flat index.
     *
     * @param index flat index of the cell
     * @return 1 if the index is present in the sparse set, 0 otherwise
     */
    @Override
    public Integer get(int index) {
        return this.sparseSet.contains(index) ? 1 : 0;
    }

    /* (non-Javadoc)
     * @see java.lang.Object#hashCode()
     */
    @Override
    public int hashCode() {
        // NOTE(review): derived from the sparse set only; dimensions are not included.
        final int prime = 31;
        int result = 1;
        result = prime * result + ((sparseSet == null) ? 0 : sparseSet.hashCode());
        return result;
    }

    /* (non-Javadoc)
     * @see java.lang.Object#equals(java.lang.Object)
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        // NOTE(review): compares the sparse sets only; dimensions are not compared.
        LowMemorySparseBinaryMatrix other = (LowMemorySparseBinaryMatrix) obj;
        if (sparseSet == null) {
            return other.sparseSet == null;
        }
        return sparseSet.equals(other.sparseSet);
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy