All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.wikibrain.matrix.MemoryMappedMatrix Maven / Gradle / Ivy

There is a newer version: 0.9.1
Show newest version
package org.wikibrain.matrix;

import gnu.trove.map.hash.TIntLongHashMap;

import java.io.File;
import java.io.IOException;
import java.lang.ref.WeakReference;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * A wrapper around a file channel that contains a matrix in row major order.
 * Returns rows at a particular offset in the form of ByteBuffers backed by a memory mapped file.
 */
public class MemoryMappedMatrix {
    public static final Logger LOG = LoggerFactory.getLogger(MemoryMappedMatrix.class);

    /**
     * Target span of row start offsets within one mapped page. A page can still be larger when it
     * holds a single oversized row or the trailing row of the file.
     */
    public static final int PAGE_SIZE = 1024 * 1024 * 1024;     // 1GB

    private final FileChannel channel;

    /** Mapped pages, appended in increasing and non-overlapping file order by pageInRows(). */
    protected List<MappedBufferWrapper> buffers = new ArrayList<MappedBufferWrapper>();
    private final File path;

    private final IntBuffer rowIds;           // row ids, validated to be strictly ascending
    private final LongBuffer rowOffsets;      // byte offset of each row, parallel to rowIds
    private final int numRows;

    /** Cache for getRowIdsInDiskOrder(); weakly held so the GC may reclaim it. */
    private WeakReference<int[]> rowIdsInDiskOrder = null;

    /**
     * Not supported; use {@link #MemoryMappedMatrix(File, FileChannel, IntBuffer, LongBuffer)}.
     *
     * @throws UnsupportedOperationException always
     */
    public MemoryMappedMatrix(File path, FileChannel channel,TIntLongHashMap rowOffsets) throws IOException {
        throw new UnsupportedOperationException(
                "use MemoryMappedMatrix(File, FileChannel, IntBuffer, LongBuffer) instead");
    }

    /**
     * Creates a matrix view over {@code channel} and splits the rows into pages that are mapped
     * lazily on first access.
     *
     * @param path       file the channel was opened on; used only in log messages
     * @param channel    channel holding the matrix data; closed by {@link #close()}
     * @param rowIds     row ids in strictly increasing order (all {@code capacity()} entries are used)
     * @param rowOffsets byte offset in {@code channel} of each row, parallel to {@code rowIds}
     * @throws IllegalArgumentException if the buffers differ in length or the ids are not strictly
     *                                  increasing
     * @throws IOException              if the channel size cannot be read
     */
    public MemoryMappedMatrix(File path, FileChannel channel, IntBuffer rowIds, LongBuffer rowOffsets) throws IOException {
        if (rowIds.capacity() != rowOffsets.capacity()) {
            throw new IllegalArgumentException("rowIds has " + rowIds.capacity()
                    + " entries but rowOffsets has " + rowOffsets.capacity());
        }
        this.path = path;
        this.channel = channel;
        this.rowIds = rowIds;
        this.rowOffsets = rowOffsets;
        this.numRows = rowIds.capacity();
        // getRowOffset() binary-searches rowIds, so ids must be unique and ascending.
        for (int i = 1; i < numRows; i++) {
            if (rowIds.get(i) <= rowIds.get(i - 1)) {
                throw new IllegalArgumentException("Row ids must be in strictly increasing order");
            }
        }
        pageInRows();
    }

    /**
     * Drops the references to every mapped page and closes the underlying channel.
     *
     * <p>A {@link MappedByteBuffer} stays mapped until it is garbage collected. After clearing the
     * references, this therefore requests a few collections as a best-effort attempt to release
     * the mappings before the channel is closed.
     */
    public void close() throws IOException {
        for (MappedBufferWrapper buffer : buffers) {
            buffer.close();
        }
        // try to garbage collect any freed buffers
        System.gc();
        System.gc();
        System.gc();
        channel.close();
    }

    /**
     * Splits the file region starting at the smallest row offset into pages whose boundaries
     * always fall on row starts, and registers them (unmapped) in {@link #buffers}.
     */
    private void pageInRows() throws IOException {
        if (numRows == 0) {
            return;
        }
        // Pages must align with row boundaries, so work from the row starts in file order.
        long[] sortedOffsets = new long[numRows];
        for (int i = 0; i < numRows; i++) {
            sortedOffsets[i] = rowOffsets.get(i);
        }
        Arrays.sort(sortedOffsets);

        long startPos = sortedOffsets[0];   // start of the page being built
        long lastPos = startPos;            // start of the most recent row seen

        for (int i = 1; i < numRows; i++) {
            long pos = sortedOffsets[i];
            // Once the current row start lies more than PAGE_SIZE past the page start, end the page
            // at the previous row's start; that row then begins the next page. If every row seen
            // so far starts at startPos, splitting would create an empty page, so let it grow.
            if (pos - startPos > PAGE_SIZE && lastPos > startPos) {
                addBuffer(startPos, lastPos);
                startPos = lastPos;
            }
            lastPos = pos;
        }
        // The final page runs to the end of the file so that it contains the last row's bytes.
        addBuffer(startPos, channel.size());
    }


    /** Registers a not-yet-mapped page covering the file range [startPos, endPos). */
    private void addBuffer(long startPos, long endPos) throws IOException {
        long length = endPos - startPos;
        debug("adding page at " + startPos + " of length " + length);
        buffers.add(new MappedBufferWrapper(channel, startPos, endPos));
    }

    /**
     * Returns a buffer positioned at the start of the given row, or {@code null} if the row id is
     * not present.
     *
     * <p>The returned buffer is a slice that begins at the row and extends to the end of the
     * containing page, so it may include bytes of later rows. The row's own length is presumably
     * encoded in its data; verify against the reader.
     *
     * @throws IllegalArgumentException if the row's offset lies outside every mapped page
     * @throws IOException              if mapping the page fails
     */
    public ByteBuffer getRow(int rowId) throws IOException {
        long targetOffset = getRowOffset(rowId);
        if (targetOffset < 0) {
            return null;
        }
        MappedBufferWrapper row = findBuffer(targetOffset);
        if (row == null) {
            throw new IllegalArgumentException("did not find row " + rowId + " with offset " + targetOffset);
        }
        return row.get(targetOffset);
    }

    /**
     * Binary-searches {@link #buffers}, which pageInRows() fills with sorted, disjoint ranges, for
     * the page containing {@code offset}; returns {@code null} if none does.
     */
    private MappedBufferWrapper findBuffer(long offset) {
        int lo = 0;
        int hi = buffers.size() - 1;
        while (lo <= hi) {
            int mid = (lo + hi) >>> 1;
            MappedBufferWrapper wrapper = buffers.get(mid);
            if (offset < wrapper.start) {
                hi = mid - 1;
            } else if (offset >= wrapper.end) {
                lo = mid + 1;
            } else {
                return wrapper;
            }
        }
        return null;
    }

    /** Returns the byte offset of {@code rowId} via binary search over rowIds, or -1 if absent. */
    private long getRowOffset(int rowId) {
        int lo = 0;
        int hi = numRows - 1;
        while (lo <= hi) {
            int mid = (lo + hi) >>> 1;   // unsigned shift avoids int overflow of lo + hi
            int midId = rowIds.get(mid);

            if (rowId < midId) {
                hi = mid - 1;
            } else if (rowId > midId) {
                lo = mid + 1;
            } else {
                return rowOffsets.get(mid);
            }
        }
        return -1;
    }

    /** One read-only page covering the file range [start, end), mapped lazily on first use. */
    static class MappedBufferWrapper {
        FileChannel channel;
        MappedByteBuffer buffer;
        long start;
        long end;

        public MappedBufferWrapper(FileChannel channel, long start, long end) {
            this.channel = channel;
            this.start = start;
            this.end = end;
        }

        /**
         * Returns a view of this page that starts at the absolute file offset {@code position} and
         * ends at the end of the page, mapping the page first if needed.
         */
        public synchronized ByteBuffer get(long position) throws IOException {
            if (buffer == null) {
                buffer = channel.map(FileChannel.MapMode.READ_ONLY, start, end - start);
            }
            buffer.position((int) (position - start));
            return buffer.slice();
        }

        /** Drops the mapping reference; a later get() maps the page again. */
        public synchronized void close() {
            buffer = null;
        }
    }

    /**
     * Returns all row ids ordered by their byte offset in the file.
     *
     * <p>The array is cached behind a weak reference and the same instance is returned to every
     * caller while it survives, so callers must not modify it.
     */
    public synchronized  int[] getRowIdsInDiskOrder() {
        // Read the referent exactly once: a second get() could observe it already cleared.
        int[] cached = (rowIdsInDiskOrder == null) ? null : rowIdsInDiskOrder.get();
        if (cached != null) {
            return cached;
        }
        int[] ids = new int[numRows];
        long[] offsets = new long[numRows];
        for (int i = 0; i < numRows; i++) {
            ids[i] = rowIds.get(i);
            offsets[i] = rowOffsets.get(i);
        }
        sortIdsByOffset(ids, offsets, 0, numRows - 1);
        this.rowIdsInDiskOrder = new WeakReference<int[]>(ids);
        return ids;
    }

    // Adapted from http://www.programcreek.com/2012/11/quicksort-array-in-java/
    /**
     * Quicksorts {@code ids[low..high]} by the parallel {@code offsets} array, permuting both arrays
     * together so that each comparison is an array read rather than a row-id lookup.
     */
    private static void sortIdsByOffset(int[] ids, long[] offsets, int low, int high) {
        if (low >= high) {
            return;
        }

        // pick the pivot
        long pivot = offsets[(low + high) >>> 1];

        // partition around the pivot
        int i = low;
        int j = high;
        while (i <= j) {
            while (offsets[i] < pivot) {
                i++;
            }
            while (offsets[j] > pivot) {
                j--;
            }
            if (i <= j) {
                int tempId = ids[i];
                ids[i] = ids[j];
                ids[j] = tempId;
                long tempOffset = offsets[i];
                offsets[i] = offsets[j];
                offsets[j] = tempOffset;
                i++;
                j--;
            }
        }

        // recursively sort the two sub parts
        sortIdsByOffset(ids, offsets, low, j);
        sortIdsByOffset(ids, offsets, i, high);
    }

    /** Logs at INFO, prefixed with this matrix's file path. */
    private void info(String message) {
        LOG.info("sparse matrix " + path + ": " + message);
    }

    /** Logs at DEBUG, prefixed with this matrix's file path. */
    private void debug(String message) {
        LOG.debug("sparse matrix " + path + ": " + message);
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy