All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.mayabot.blas.FloatMatrix.kt Maven / Gradle / Ivy

The newest version!
package com.mayabot.blas

import java.io.File
import java.io.Serializable
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.util.*
import kotlin.math.min


/**
 * 不可变的Float矩阵
 */
interface FloatMatrix : Serializable {
    fun rows(): Int
    fun cols(): Int
    operator fun get(row: Int): Vector
    operator fun get(i: Int, j: Int): Float

    fun write(channel: FileChannel)

    fun save(file: File) {
        file.outputStream().channel.use {
            it.writeInt(rows())
            it.writeInt(cols())
            this.write(it)
        }
    }

    fun toMutableFloatMatrix(): MutableFloatMatrix

    companion object {
        fun byteBufferMatrix(rows: Int, cols: Int) = MutableByteBufferMatrix(rows, cols, false)
        fun directByteBufferMatrix(rows: Int, cols: Int) = MutableByteBufferMatrix(rows, cols, true)
        fun floatArrayMatrix(rows: Int, cols: Int) = MutableFloatArrayMatrix(rows, cols)
        fun readOnlyFloatArrayMatrix(rows: Int, cols: Int, data: FloatArray) = FloatArrayMatrix(rows, cols, data)


        fun loadMatrix(file: File, mmap: Boolean): FloatMatrix {
            fun pages(total: Long, size: Int): Int = ((total + size.toLong() - 1) / size.toLong()).toInt()
            return if (mmap) {
                file.inputStream().channel.use {
                    val rows = it.readInt()
                    val cols = it.readInt()

                    //一个区域可以容纳多少行
                    var areaRows = 0
                    while (areaRows * cols < 268435456) {
                        areaRows += 10
                    }

                    val fileSize = it.size()
                    val arrayBytes = fileSize - 8
                    val areaCount = pages(arrayBytes, 4 * areaRows * cols)
                    val areaBytes = areaRows * cols * 4
                    val lastBytes = arrayBytes % (areaRows * cols * 4)

                    val list = ArrayList()
                    for (a in 0 until areaCount) {
                        val len = if (a == areaCount - 1) lastBytes else areaBytes.toLong()
                        list += it.map(FileChannel.MapMode.READ_ONLY, 8 + a.toLong() * areaBytes, len)
                    }
                    AreaByteBufferMatrix(rows, cols, list)
                }
            } else {
                val dataInput = AutoDataInput.open(file)
                val rows = dataInput.readInt()
                val cols = dataInput.readInt()
                val floatArray = FloatArray(rows * cols)
                for (i in 0 until rows * cols) {
                    floatArray[i] = dataInput.readFloat()
                }
                FloatMatrix.readOnlyFloatArrayMatrix(rows, cols, floatArray)
            }

        }

        private fun FileChannel.writeInt(value: Int) {
            this.write(ByteBuffer.allocate(4).putInt(value).apply { flip() })
        }

        private fun FileChannel.readInt(): Int {
            val b = ByteBuffer.allocate(4)
            this.read(b)
            b.flip()
            return b.int
        }
    }
}

abstract class BaseMatrix(val rows: Int, val cols: Int) : FloatMatrix {

    override fun rows(): Int {
        return rows
    }

    override fun cols(): Int {
        return cols
    }

    override fun toString(): String {
        if (rows == 0) {
            return ""
        }

        val b = StringBuilder()

        b.append("-".repeat(cols * 12))
        b.append("\n")

        for (i in 0 until min(20, rows)) {
            val row = get(i)

            for (j in 0 until cols) {
                b.append(row[j]).append("\t")
            }

            b.append("\n")
        }

        if (rows > 20) {
            b.append("....more....")
        }
        b.append("\n")

        return b.toString()
    }
}

/**
 * 行存储的只读矩阵。内存实现
 */
class FloatArrayMatrix(rows: Int, cols: Int, val data: FloatArray) : BaseMatrix(rows, cols), FloatMatrix {

    override fun write(channel: FileChannel) {
        val byteBuffer = ByteBuffer.allocate(cols * 4)
        val asFloatBuffer = byteBuffer.asFloatBuffer()
        for (row in 0 until rows) {
            asFloatBuffer.clear()
            asFloatBuffer.put(data, row * cols, cols)

            byteBuffer.position(0)
            byteBuffer.limit(cols * 4)
            channel.write(byteBuffer)
        }
    }

    override fun toMutableFloatMatrix(): MutableFloatMatrix {
        val to = MutableFloatArrayMatrix(rows, cols)

        System.arraycopy(this.data, 0, to.data, 0, rows * cols)

        return to
    }

    //private var length = rows * cols

    private val rowView = (0 until rows).mapIndexed { index, _ -> FloatArrayVector(data, index * cols, cols) }.toTypedArray()

    /**
     * 行视图
     */
    override operator fun get(row: Int): Vector {
        return rowView[row]
    }

    override operator fun get(i: Int, j: Int): Float {
        return data[i * cols + j]
    }
}

/**
 * 底层是一个ByteBuffer。实现只读版本的FloatMatrix
 */
class ByteBufferMatrix(rows: Int, cols: Int, val data: ByteBuffer) : BaseMatrix(rows, cols), FloatMatrix {

    // private var length = rows * cols
    //private val rowView = (0 until rows).mapIndexed { index, _ -> VectorDefault(data, index * cols, cols) }.toTypedArray()

    private fun index(i: Int, j: Int): Int {
        return i * cols + j
    }

    /**
     * 行视图
     */
    override operator fun get(row: Int): Vector {
        return ByteBufferVector(data, row * cols, cols)
    }

    override operator fun get(i: Int, j: Int): Float {
        return data.getFloat(index(i, j) shl 2)
    }

    override fun write(channel: FileChannel) {
        data.position(0)
        data.limit(data.capacity())

        channel.write(data)
    }

    override fun toMutableFloatMatrix(): MutableFloatMatrix {
        val inputMatrix = MutableFloatArrayMatrix(rows, cols)
        for (i in 0 until rows) {
            inputMatrix[i] += this[i]
        }
        return inputMatrix
    }
}


/**
 * 特殊版本的只读二维矩阵。
 * 在使用内存映射读取文件时,Java规定每个ByteBuffer不能超过2G大小
 */
class AreaByteBufferMatrix(rows: Int, cols: Int, val data: List) : BaseMatrix(rows, cols), FloatMatrix {

    // private var length = rows * cols
    //private val rowView = (0 until rows).mapIndexed { index, _ -> VectorDefault(data, index * cols, cols) }.toTypedArray()

    val areaRows = data[0].capacity() / 4 / cols

    private fun index(i: Int, j: Int): Int {
        return i * cols + j
    }

    /**
     * 行视图
     */
    override operator fun get(row: Int): Vector {
        val area = row / areaRows
        val areaOffeet = row % areaRows
        return ByteBufferVector(data[area], areaOffeet * cols, cols)
    }

    override operator fun get(i: Int, j: Int): Float {

        val area = i / areaRows
        val areaOffeet = i % areaRows

        return data[area].getFloat(index(areaOffeet, j) shl 2)
    }

    override fun write(channel: FileChannel) {
        for (x in data) {
            x.position(0)
            x.limit(x.capacity())
            channel.write(x)
        }
    }

    override fun toMutableFloatMatrix(): MutableFloatMatrix {
        val inputMatrix = MutableFloatArrayMatrix(rows, cols)
        for (i in 0 until rows) {
            inputMatrix[i] += this[i]
        }
        return inputMatrix
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy