All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.wikibrain.matrix.SparseMatrixRow Maven / Gradle / Ivy

There is a newer version: 0.9.1
Show newest version
package org.wikibrain.matrix;

import gnu.trove.map.TIntFloatMap;
import org.apache.commons.lang3.ArrayUtils;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.nio.ShortBuffer;
import java.util.LinkedHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A single sparse matrix row backed by a byte buffer. The row contains:
 * - a row id (int),
 * - a set of n columns, each with an id (int) and value (float packed into two bytes)
 *
 * The byte layout is:
 * <pre>
 *   bytes 0-3    header marker ({@link #HEADER})
 *   bytes 4-7    row index
 *   bytes 8-11   number of columns (n)
 *   next 4*n     column ids
 *   next 2*n     packed column values
 * </pre>
 *
 * The row can either be created from the component data, or from a byte buffer.
 * This means that the object can wrap data from an mmap'd file in the correct format.
 *
 * Newly created rows are reordered so that the columns appear in sorted order.
 * The reordering is done in place on the arrays handed to
 * {@link #createBuffer(int, int[], short[])}.
 */
public final class SparseMatrixRow extends BaseMatrixRow implements MatrixRow {
    private static final Logger LOG = LoggerFactory.getLogger(SparseMatrixRow.class);

    public static final Float MIN_SCORE = -1.1f;
    public static final Float MAX_SCORE = 1.1f;

    /** Width of the interval between {@link #MIN_SCORE} and {@link #MAX_SCORE}. */
    public static final Float SCORE_RANGE = (MAX_SCORE - MIN_SCORE);

    /** Distance between the smallest and the largest {@code short} value. */
    public static final int PACKED_RANGE = (Short.MAX_VALUE - Short.MIN_VALUE);

    /** Marker stored in the first four bytes of every row. */
    public static final int HEADER = 0xfefefefe;

    /** Size of the fixed prefix: header marker, row index and column count (three ints). */
    private static final int HEADER_BYTES = 3 * 4;

    /** Byte offset of the column count inside the fixed prefix. */
    private static final int NUM_COLS_OFFSET = 2 * 4;

    /** Bytes used by one column id. */
    private static final int ID_BYTES = 4;

    /** Bytes used by one packed column value. */
    private static final int VALUE_BYTES = 2;

    /**
     * The main "source" buffer.
     */
    private ByteBuffer buffer;

    /**
     * A view buffer that points to the header.
     */
    private IntBuffer headerBuffer;

    /**
     * A view buffer that points to the ids.
     */
    private IntBuffer idBuffer;

    /**
     * A view buffer that points to the values.
     */
    private ShortBuffer valBuffer;
    private ValueConf vconf;

    /**
     * Creates a row from a Trove map of column id to value.
     *
     * @param vconf    configuration used to pack and unpack values
     * @param rowIndex id of the row
     * @param row      column id to column value
     */
    public SparseMatrixRow(ValueConf vconf, int rowIndex, TIntFloatMap row) {
        this(vconf, rowIndex, row.keys(), row.values());
    }

    /**
     * Creates a row from a map of column id to value.
     *
     * @param vconf    configuration used to pack and unpack values
     * @param rowIndex id of the row
     * @param row      column id to column value; keys and values must not be null
     */
    public SparseMatrixRow(ValueConf vconf, int rowIndex, LinkedHashMap<Integer, Float> row) {
        this(vconf, rowIndex,
            ArrayUtils.toPrimitive(row.keySet().toArray(new Integer[0])),
            ArrayUtils.toPrimitive(row.values().toArray(new Float[0]))
        );
    }

    /**
     * Creates a row from parallel arrays of column ids and unpacked values.
     * Every value is packed with {@code vconf.pack} before being stored.
     * {@code colIds} is sorted in place if it is not already non-decreasing.
     *
     * @param vconf    configuration used to pack and unpack values
     * @param rowIndex id of the row
     * @param colIds   column ids
     * @param colVals  column values, parallel to {@code colIds}
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public SparseMatrixRow(ValueConf vconf, int rowIndex, int[] colIds, float[] colVals) {
        this.vconf = vconf;
        short[] packed = new short[colVals.length];
        for (int i = 0; i < colVals.length; i++) {
            packed[i] = vconf.pack(colVals[i]);
        }
        createBuffer(rowIndex, colIds, packed);
    }

    /**
     * Creates a row from parallel arrays of column ids and already packed values.
     * Both arrays are reordered in place if {@code colIds} is not already non-decreasing.
     *
     * @param vconf    configuration used to unpack values
     * @param rowIndex id of the row
     * @param colIds   column ids
     * @param colVals  packed column values, parallel to {@code colIds}
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public SparseMatrixRow(ValueConf vconf, int rowIndex, int[] colIds, short[] colVals) {
        this.vconf = vconf;
        createBuffer(rowIndex, colIds, colVals);
    }

    /**
     * Allocates a fresh backing buffer and fills it with the given row.
     * If {@code colIds} is not non-decreasing, both arrays are sorted together,
     * in place, by column id before being written.
     *
     * @param rowIndex id of the row
     * @param colIds   column ids
     * @param colVals  packed column values, parallel to {@code colIds}
     * @throws IllegalArgumentException if the arrays differ in length or the row
     *                                  would not fit in a single byte buffer
     * @throws IllegalStateException    if the ids are still unsorted after sorting
     */
    public void createBuffer(int rowIndex, int[] colIds, short[] colVals) {
        // An assert would be disabled in production and let a mismatched row through.
        if (colIds.length != colVals.length) {
            throw new IllegalArgumentException(
                    "Column id and value arrays differ in length: "
                    + colIds.length + " ids vs " + colVals.length + " values");
        }
        if (!isNonDecreasing(colIds)) {
            quickSort(colIds, colVals, 0, colIds.length - 1);
            if (!isNonDecreasing(colIds)) {
                throw new IllegalStateException("Column ids are not sorted after sorting");
            }
        }

        int numCols = colVals.length;
        // Computed in long so that a huge row cannot wrap around to a small or negative size.
        long numBytes = HEADER_BYTES + (long) numCols * (ID_BYTES + VALUE_BYTES);
        if (numBytes > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(
                    "Row with " + numCols + " columns does not fit in a byte buffer");
        }

        buffer = ByteBuffer.allocate((int) numBytes);
        createViewBuffers(numCols);

        headerBuffer.put(0, HEADER);
        headerBuffer.put(1, rowIndex);
        headerBuffer.put(2, numCols);
        idBuffer.put(colIds, 0, numCols);
        valBuffer.put(colVals, 0, numCols);
    }

    /**
     * Sorts {@code colIds[low..high]} in ascending order and applies the same
     * permutation to {@code colVals}.
     *
     * Only the smaller partition is handled recursively; the larger one is
     * handled by the enclosing loop, which keeps the recursion depth logarithmic
     * in the size of the range.
     */
    // Adapted from http://www.programcreek.com/2012/11/quicksort-array-in-java/
    private static void quickSort(int[] colIds, short[] colVals, int low, int high) {
        while (low < high) {
            // pick the pivot (unsigned shift avoids overflow of low + high)
            int pivot = colIds[(low + high) >>> 1];

            // partition around the pivot
            int i = low, j = high;
            while (i <= j) {
                while (colIds[i] < pivot) {
                    i++;
                }
                while (colIds[j] > pivot) {
                    j--;
                }
                if (i <= j) {
                    int temp = colIds[i];
                    short tempV = colVals[i];
                    colIds[i] = colIds[j];
                    colVals[i] = colVals[j];
                    colIds[j] = temp;
                    colVals[j] = tempV;
                    i++;
                    j--;
                }
            }

            // recurse into the smaller part, keep looping on the larger one
            if (j - low < high - i) {
                quickSort(colIds, colVals, low, j);
                low = i;
            } else {
                quickSort(colIds, colVals, i, high);
                high = j;
            }
        }
    }


    /**
     * @param A array to inspect
     * @return true if no element of {@code A} is smaller than its predecessor
     *         (an empty array counts as non-decreasing)
     */
    static boolean isNonDecreasing(int[] A) {
        int lastId = Integer.MIN_VALUE;
        for (int i = 0; i < A.length; i++) {
            if (A[i] < lastId) {
                return false;
            }
            lastId = A[i];
        }
        return true;
    }

    /**
     * Creates the header, id and value views over {@link #buffer}.
     * The position of {@link #buffer} is left at the start of the value section.
     *
     * @param numColumns number of columns stored in the row
     */
    private void createViewBuffers(int numColumns) {
        buffer.position(0);
        headerBuffer = buffer.asIntBuffer();
        buffer.position(HEADER_BYTES);
        idBuffer = buffer.asIntBuffer();
        buffer.position(HEADER_BYTES + numColumns * ID_BYTES);
        valBuffer = buffer.asShortBuffer();
    }

    /**
     * Wrap an existing byte buffer that contains a row.
     * The buffer is used directly (not copied) and its position is changed.
     *
     * @param vconf  configuration used to unpack values
     * @param buffer buffer whose first bytes hold a row in the layout described on this class
     * @throws IllegalArgumentException if the buffer is too short for the fixed prefix,
     *                                  does not start with {@link #HEADER}, or declares a
     *                                  column count that is negative or does not fit
     *                                  within the buffer's limit
     */
    public SparseMatrixRow(ValueConf vconf, ByteBuffer buffer) {
        this.vconf = vconf;
        this.buffer = buffer;
        if (buffer.limit() < HEADER_BYTES || buffer.getInt(0) != HEADER) {
            throw new IllegalArgumentException("Invalid header in byte buffer");
        }
        int numCols = buffer.getInt(NUM_COLS_OFFSET);
        // A corrupt count must not be allowed to produce overlapping or truncated views.
        long numBytes = HEADER_BYTES + (long) numCols * (ID_BYTES + VALUE_BYTES);
        if (numCols < 0 || numBytes > buffer.limit()) {
            throw new IllegalArgumentException(
                    "Invalid column count in byte buffer: " + numCols
                    + " (buffer limit is " + buffer.limit() + " bytes)");
        }
        createViewBuffers(numCols);
    }

    /**
     * @param i position of the column within the row (not checked against {@link #getNumCols()})
     * @return the id of the i-th column
     */
    @Override
    public final int getColIndex(int i) {
        return idBuffer.get(i);
    }

    /**
     * @param i position of the column within the row
     * @return the i-th value, unpacked with {@code vconf.unpack}
     */
    @Override
    public final float getColValue(int i) {
        return vconf.unpack(valBuffer.get(i));
    }

    /**
     * @param i position of the column within the row
     * @return the i-th value in its stored, packed form
     */
    public final short getPackedColValue(int i) {
        return valBuffer.get(i);
    }

    /**
     * @return the row index stored in the buffer
     */
    @Override
    public final int getRowIndex() {
        return headerBuffer.get(1);
    }

    /**
     * @return the number of columns stored in the buffer
     */
    @Override
    public final int getNumCols() {
        return headerBuffer.get(2);
    }

    /**
     * @return the backing buffer itself (not a copy)
     */
    public ByteBuffer getBuffer() {
        return buffer;
    }

    /**
     * @return the value configuration this row was created with
     */
    public ValueConf getValueConf() {
        return vconf;
    }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy