All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.numenta.nupic.util.AbstractSparseBinaryMatrix Maven / Gradle / Ivy

There is a newer version: 0.6.13
Show newest version
/* ---------------------------------------------------------------------
 * Numenta Platform for Intelligent Computing (NuPIC)
 * Copyright (C) 2014, Numenta, Inc.  Unless you have an agreement
 * with Numenta, Inc., for a separate license for this software code, the
 * following terms and conditions apply:
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see http://www.gnu.org/licenses.
 *
 * http://numenta.org/licenses/
 * ---------------------------------------------------------------------
 */

package org.numenta.nupic.util;

import java.lang.reflect.Array;
import java.util.Arrays;

import org.numenta.nupic.model.Persistable;

import gnu.trove.TIntCollection;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;

/**
 * Base class for matrices containing specifically binary (0 or 1) integer values.
 * <p>
 * Besides the cell values stored by concrete subclasses, this class keeps one
 * "true count" (number of 1 bits) per outer index, in an array sized from
 * {@code dimensions[0]} at construction time. Per the {@link #setForTest(int, int)}
 * contract, the regular {@code set} methods are expected to keep these counts in
 * step with the cell values (NOTE(review): the tallying itself lives in the concrete
 * subclasses — confirm there).
 * 
 * @author David Ray
 * @author Jose Luis Martin
 */
@SuppressWarnings("rawtypes")
public abstract class AbstractSparseBinaryMatrix extends AbstractSparseMatrix implements Persistable {
    /** keep it simple */
    private static final long serialVersionUID = 1L;
    
    /** Number of 1 bits per outer index (row); length is {@code dimensions[0]}. */
    private int[] trueCounts;

    /**
     * Constructs a new {@code AbstractSparseBinaryMatrix} with the specified
     * dimensions (defaults to row major ordering)
     * 
     * @param dimensions    each indexed value is a dimension size
     */
    public AbstractSparseBinaryMatrix(int[] dimensions) {
        this(dimensions, false);
    }

    /**
     * Constructs a new {@code AbstractSparseBinaryMatrix} with the specified dimensions,
     * allowing the specification of column major ordering if desired. 
     * (defaults to row major ordering)
     * 
     * @param dimensions                each indexed value is a dimension size
     * @param useColumnMajorOrdering    if true, indicates column first iteration, otherwise
     *                                  row first iteration is the default (if false).
     */
    public AbstractSparseBinaryMatrix(int[] dimensions, boolean useColumnMajorOrdering) {
        super(dimensions, useColumnMajorOrdering);
        // One true-count slot per outer index.
        this.trueCounts = new int[dimensions[0]];
    }

    /**
     * Returns the slice specified by the passed in coordinates.
     * The array is returned as an object, therefore it is the caller's
     * responsibility to cast the array to the appropriate dimensions.
     * 
     * @param coordinates   the coordinates which specify the returned array
     * @return  the array specified
     * @throws  IllegalArgumentException if the specified coordinates address
     *          an actual value instead of the array holding it.
     */
    public abstract Object getSlice(int... coordinates);

    /**
     * Throws the shared {@code getSlice} error, so that subclass
     * {@link #getSlice(int...)} implementations report it consistently.
     * 
     * @param coordinates   the offending coordinates (not included in the message)
     * @throws IllegalArgumentException always
     */
    protected void sliceError(int... coordinates) {
        throw new IllegalArgumentException(
            "This method only returns the array holding the specified maximum index: " + 
                    Arrays.toString(dimensions));
    }
    
    /**
     * Calculates the flat indexes of every cell in the slice addressed by the
     * given leading coordinates.
     * <p>
     * For a matrix with dimensions {@code [d0, d1, ..., dn]} and {@code k} leading
     * coordinates, the result holds {@code dk * ... * dn} flat indexes, ordered with
     * the last dimension varying fastest. Each index comes from {@code computeIndex},
     * so it follows whatever ordering the matrix was configured with.
     * 
     * @param coordinates   the leading coordinates of the slice; there must be fewer
     *                      of them than the matrix has dimensions
     * @return the flat indexes array
     * @throws IllegalArgumentException if {@code coordinates} addresses a single
     *                                  value rather than a slice
     */
    protected int[] getSliceIndexes(int[] coordinates) {
        int[] dims = getDimensions();
        // check for valid coordinates
        if (coordinates.length >= dims.length) {
            sliceError(coordinates);
        }

        // The axis iterated at this recursion level is the first one
        // not already fixed by the caller's coordinates.
        int axis = coordinates.length;

        int sliceSize = 1;
        for (int d = axis; d < dims.length; d++) {
            sliceSize *= dims[d];
        }
        int[] slice = new int[sliceSize];

        int[] elementCoordinates = Arrays.copyOf(coordinates, axis + 1);

        if (axis + 1 == dims.length) {
            // Innermost axis: each element coordinate addresses exactly one cell.
            for (int i = 0; i < dims[axis]; i++) {
                elementCoordinates[axis] = i;
                slice[i] = computeIndex(elementCoordinates);
            }
        } else {
            // Walk the current axis (dims[axis]) and concatenate each
            // equally-sized sub-slice into its block of the result.
            for (int i = 0; i < dims[axis]; i++) {
                elementCoordinates[axis] = i;
                int[] indexes = getSliceIndexes(elementCoordinates);
                System.arraycopy(indexes, 0, slice, i * indexes.length, indexes.length);
            }
        }

        return slice;
    }

    /**
     * Fills the specified results array with the result of the 
     * matrix vector multiplication.
     * 
     * @param inputVector       the right side vector
     * @param results           the results array
     */
    public abstract void rightVecSumAtNZ(int[] inputVector, int[] results);
    
    /**
     * Fills the specified results array with the result of the 
     * matrix vector multiplication.
     * 
     * @param inputVector       the right side vector
     * @param results           the results array
     * @param stimulusThreshold threshold applied by the implementation
     *                          (NOTE(review): exact semantics are defined by subclasses)
     */
    public abstract void rightVecSumAtNZ(int[] inputVector, int[] results, double stimulusThreshold);
        
    /**
     * Sets the value at the specified flat index by converting the index to
     * coordinates and delegating to {@link #set(int, int...)}.
     * 
     * @param index     the flat index of the cell to set
     * @param value     the value to store (expected to be 0 or 1)
     * @return this matrix
     */
    @Override
    public AbstractSparseBinaryMatrix set(int index, int value) {
        int[] coordinates = computeCoordinates(index);
        return set(value, coordinates);
    }

    /**
     * Sets the value to be indexed at the index
     * computed from the specified coordinates.
     * 
     * @param value         the value to store (expected to be 0 or 1)
     * @param coordinates   the row major coordinates [outer --> ,...,..., inner]
     * @return this matrix
     */
    @Override
    public abstract AbstractSparseBinaryMatrix set(int value, int... coordinates);

    /**
     * Sets the specified values at the specified indexes.
     * <p>
     * Iteration is driven by {@code indexes.length}. Surplus entries in
     * {@code values} are ignored. A shorter {@code values} array results in an
     * {@link ArrayIndexOutOfBoundsException}.
     * 
     * @param indexes   indexes of the values to be set
     * @param values    the values to be indexed.
     * 
     * @return this {@code SparseMatrix} implementation
     */
    public AbstractSparseBinaryMatrix set(int[] indexes, int[] values) { 
        for(int i = 0;i < indexes.length;i++) {
            set(indexes[i], values[i]);
        }
        return this;
    }

    /**
     * Returns the value at the flat index computed from the given coordinates.
     * 
     * @param coordinates   the coordinates of the cell
     * @return the cell value
     */
    public Integer get(int... coordinates) {
        return get(computeIndex(coordinates));
    }

    /**
     * Returns the value stored at the specified flat index.
     * 
     * @param index the flat index of the cell
     * @return the cell value
     */
    public abstract Integer get(int index);
    
    /**
     * Sets the value at the specified index skipping the automatic
     * truth statistic tallying of the real method.
     * 
     * @param index     the index the value will occupy
     * @param value     the value to store
     * @return this matrix
     */
    public abstract AbstractSparseBinaryMatrix setForTest(int index, int value);

    /**
     * Call This for TEST METHODS ONLY
     * Sets the specified values at the specified indexes.
     * 
     * @param indexes   indexes of the values to be set
     * @param values    the values to be indexed.
     * @param isTest    if true, uses {@link #setForTest(int, int)}, which skips the
     *                  true-count tallying; otherwise uses {@link #set(int, int)}
     * 
     * @return this {@code SparseMatrix} implementation
     */
    public AbstractSparseBinaryMatrix set(int[] indexes, int[] values, boolean isTest) { 
        for(int i = 0;i < indexes.length;i++) {
            if(isTest) setForTest(indexes[i], values[i]);
            else set(indexes[i], values[i]);
        }
        return this;
    }

    /**
     * Returns the count of 1's set on the specified row.
     * 
     * @param index the row (outer index)
     * @return the recorded count of 1's on that row
     */
    public int getTrueCount(int index) {
        return trueCounts[index];
    }

    /**
     * Sets the count of 1's on the specified row.
     * 
     * @param index the row (outer index)
     * @param count the count to record
     */
    public void setTrueCount(int index, int count) {
        this.trueCounts[index] = count;
    }

    /**
     * Gets the true counts for all outer indexes.
     * <p>
     * The returned array is the live internal array, not a copy.
     * Writes to it change this matrix's statistics.
     * 
     * @return the per-row true counts
     */
    public int[] getTrueCounts() {
        return trueCounts;
    }

    /**
     * Clears the given row prior to a cycle where it is being set.
     * Resets the row's true count to zero and writes 0 to every cell of that row
     * through {@link #set(int, int)}.
     * 
     * @param row   the outer index to clear
     */
    public void clearStatistics(int row) {
        trueCounts[row] = 0;
        
        for (int index : getSliceIndexes(new int[] { row })) {
            set(index, 0);
        }
    }

    /**
     * Returns the int value at the index computed from the specified coordinates
     * @param coordinates   the coordinates from which to retrieve the indexed object
     * @return  the indexed object
     */
    public int getIntValue(int... coordinates) {
        return get(computeIndex(coordinates));
    }

    /**
     * Returns the int value at the specified index.
     * 
     * @param index     the flat index of the value to return
     * @return  the value at the specified index.
     */
    @Override
    public int getIntValue(int index) {
        return get(index);
    }

    /**
     * Returns a sorted array of occupied indexes (cells whose value is > 0).
     * @return  a sorted array of occupied indexes.
     */
    @Override
    public int[] getSparseIndices() {
        TIntList indexes = new TIntArrayList();
        // Loop bound is invariant; read it once.
        int maxIndex = getMaxIndex();
        for (int i = 0; i <= maxIndex; i++) {
            if (get(i) > 0) {
                indexes.add(i);
            }
        }
        
        return indexes.toArray();
    }

    /**
     * This {@code SparseBinaryMatrix} will contain the operation of or-ing
     * the inputMatrix with the contents of this matrix; returning this matrix
     * as the result.
     * 
     * @param inputMatrix   the matrix containing the "on" bits to or
     * @return  this matrix
     */
    public AbstractSparseBinaryMatrix or(AbstractSparseBinaryMatrix inputMatrix) {
        int[] mask = inputMatrix.getSparseIndices();
        int[] ones = new int[mask.length];
        Arrays.fill(ones, 1);
        return set(mask, ones);
    }

    /**
     * This {@code SparseBinaryMatrix} will contain the operation of or-ing
     * the sparse list with the contents of this matrix; returning this matrix
     * as the result.
     * 
     * @param onBitIndexes  the collection containing the "on" bit indexes to or
     * @return  this matrix
     */
    public AbstractSparseBinaryMatrix or(TIntCollection onBitIndexes) {
        int[] ones = new int[onBitIndexes.size()];
        Arrays.fill(ones, 1);
        return set(onBitIndexes.toArray(), ones);
    }

    /**
     * This {@code SparseBinaryMatrix} will contain the operation of or-ing
     * the sparse array with the contents of this matrix; returning this matrix
     * as the result.
     * 
     * @param onBitIndexes  the int array containing the "on" bits to or
     * @return  this matrix
     */
    public AbstractSparseBinaryMatrix or(int[] onBitIndexes) {
        int[] ones = new int[onBitIndexes.length];
        Arrays.fill(ones, 1);
        return set(onBitIndexes, ones);
    }
    
    /**
     * Returns a freshly built hash set of this matrix's occupied indexes.
     * 
     * @return the set of indexes returned by {@link #getSparseIndices()}
     */
    protected TIntSet getSparseSet() {
        return new TIntHashSet(getSparseIndices());
    }

    /**
     * Returns true if the on bits of the specified matrix are
     * matched by the on bits of this matrix. It is allowed that 
     * this matrix have more on bits than the specified matrix.
     * 
     * @param matrix    the matrix whose on bits must all be present here
     * @return true if every on bit of {@code matrix} is also on in this matrix
     */
    public boolean all(AbstractSparseBinaryMatrix matrix) {
        return getSparseSet().containsAll(matrix.getSparseIndices());
    }

    /**
     * Returns true if the on bits of the specified list are
     * matched by the on bits of this matrix. It is allowed that 
     * this matrix have more on bits than the specified collection.
     * 
     * @param onBits    the on bit indexes that must all be present here
     * @return true if every index in {@code onBits} is on in this matrix
     */
    public boolean all(TIntCollection onBits) {
        return getSparseSet().containsAll(onBits);
    }

    /**
     * Returns true if the on bits of the specified array are
     * matched by the on bits of this matrix. It is allowed that 
     * this matrix have more on bits than the specified array.
     * 
     * @param onBits    the on bit indexes that must all be present here
     * @return true if every index in {@code onBits} is on in this matrix
     */
    public boolean all(int[] onBits) {
        return getSparseSet().containsAll(onBits);
    }

    /**
     * Returns true if any of the on bits of the specified matrix are
     * matched by the on bits of this matrix. It is allowed that 
     * this matrix have more on bits than the specified matrix.
     * 
     * @param matrix    the matrix whose on bits are checked
     * @return true if at least one on bit of {@code matrix} is on in this matrix
     */
    public boolean any(AbstractSparseBinaryMatrix matrix) {
        TIntSet keySet = getSparseSet();
        
        for(int i : matrix.getSparseIndices()) {
            if(keySet.contains(i)) return true;
        }
        return false;
    }

    /**
     * Returns true if any of the on bit indexes of the specified collection are
     * matched by the on bits of this matrix. It is allowed that 
     * this matrix have more on bits than the specified list.
     * 
     * @param onBits    the on bit indexes to check
     * @return true if at least one index in {@code onBits} is on in this matrix
     */
    public boolean any(TIntList onBits) {
        TIntSet keySet = getSparseSet();
        
        for(TIntIterator i = onBits.iterator();i.hasNext();) {
            if(keySet.contains(i.next())) return true;
        }
        return false;
    }

    /**
     * Returns true if any of the on bit indexes of the specified array are
     * matched by the on bits of this matrix. It is allowed that 
     * this matrix have more on bits than the specified array.
     * 
     * @param onBits    the on bit indexes to check
     * @return true if at least one index in {@code onBits} is on in this matrix
     */
    public boolean any(int[] onBits) {
        TIntSet keySet = getSparseSet();
        
        for(int i : onBits) {
            if(keySet.contains(i)) return true;
        }
        return false;
    }

    /* (non-Javadoc)
     * @see java.lang.Object#hashCode()
     */
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = super.hashCode();
        result = prime * result + Arrays.hashCode(trueCounts);
        return result;
    }

    /* (non-Javadoc)
     * @see java.lang.Object#equals(java.lang.Object)
     */
    @Override
    public boolean equals(Object obj) {
        if(this == obj)
            return true;
        if(!super.equals(obj))
            return false;
        if(getClass() != obj.getClass())
            return false;
        AbstractSparseBinaryMatrix other = (AbstractSparseBinaryMatrix)obj;
        if(!Arrays.equals(trueCounts, other.trueCounts))
            return false;
        return true;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy