All downloads are FREE. Search and download functionality uses the official Maven repository.

commonMain.org.jetbrains.kotlinx.multik.kotlin.KEStatistics.kt Maven / Gradle / Ivy

There is a newer version: 0.2.3
Show newest version
/*
 * Copyright 2020-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */

package org.jetbrains.kotlinx.multik.kotlin

import org.jetbrains.kotlinx.multik.api.Statistics
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.kotlin.math.KEMath
import org.jetbrains.kotlinx.multik.kotlin.math.remove
import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.div
import org.jetbrains.kotlinx.multik.ndarray.operations.first
import org.jetbrains.kotlinx.multik.ndarray.operations.sorted
import org.jetbrains.kotlinx.multik.ndarray.operations.times

/**
 * [Statistics] implementation for the `org.jetbrains.kotlinx.multik.kotlin` engine.
 *
 * Sums used by [mean] and [average] are accumulated in [Double]. Integer element types therefore
 * cannot overflow while being summed, which matches how [mean] with an axis already works.
 */
public object KEStatistics : Statistics {

    /**
     * Returns the median of all elements of [a], or `null` if [a] has no elements.
     *
     * For an even number of elements, the result is the arithmetic mean of the two middle
     * values of the sorted data.
     */
    override fun  median(a: MultiArray): Double? {
        val size = a.size
        return when {
            size == 1 -> a.first().toDouble()
            size > 1 -> {
                // NOTE(review): reads the flat backing `data` of the sorted result directly. This
                // assumes `sorted()` returns a fresh contiguous array with offset 0 — confirm.
                val sorted = a.sorted()
                val mid = size / 2
                if (size % 2 != 0) {
                    sorted.data[mid].toDouble()
                } else {
                    (sorted.data[mid - 1].toDouble() + sorted.data[mid].toDouble()) / 2
                }
            }
            else -> null
        }
    }

    /**
     * Returns the weighted average of [a]. When [weights] is `null`, this is the plain [mean].
     *
     * Computed as `sum(a[i] * weights[i]) / sum(weights[i])`. Both sums are accumulated in
     * [Double], so integer inputs cannot overflow. If the weights sum to zero, the result
     * follows IEEE division (NaN or ±Infinity).
     *
     * @throws IllegalArgumentException if [weights] does not have the same shape as [a].
     */
    override fun  average(a: MultiArray, weights: MultiArray?): Double {
        if (weights == null) return mean(a)
        require(a.shape.contentEquals(weights.shape)) {
            "Shapes of array ${a.shape.contentToString()} and weights ${weights.shape.contentToString()} don't match."
        }
        var weightedSum = 0.0
        var weightSum = 0.0
        // Both arrays have the same shape, so their iterators visit corresponding elements in step.
        val weightIterator = weights.iterator()
        for (element in a) {
            val w = weightIterator.next().toDouble()
            weightedSum += element.toDouble() * w
            weightSum += w
        }
        return weightedSum / weightSum
    }

    /**
     * Returns the arithmetic mean of all elements of [a].
     *
     * The sum is accumulated in [Double], so integer element types cannot overflow.
     * An empty array yields NaN (0.0 / 0).
     */
    override fun  mean(a: MultiArray): Double {
        var total = 0.0
        for (element in a) total += element.toDouble()
        return total / a.size
    }

    /**
     * Returns the mean of [a] along [axis]. The result has [axis] removed from the shape.
     *
     * @throws IllegalArgumentException if [a] is one-dimensional or [axis] is not in `0 until a.dim.d`.
     */
    override fun  mean(a: MultiArray, axis: Int): NDArray {
        require(a.dim.d > 1) { "NDArray of dimension one, use the `mean` function without axis." }
        require(axis in 0 until a.dim.d) { "axis $axis is out of bounds for this ndarray of dimension ${a.dim.d}." }
        val newShape = a.shape.remove(axis)
        val retData = initMemoryView(newShape.fold(1, Int::times), DataType.DoubleDataType)
        // Every axis except `axis` spans its full extent. `axis` is pinned to one index per pass below.
        val indexMap: MutableMap = mutableMapOf()
        for (i in a.shape.indices) {
            if (i == axis) continue
            indexMap[i] = 0.r..a.shape[i]
        }
        // Add each slice a[..., index, ...] element-wise into retData.
        // NOTE(review): assumes the slice iterates in the same order as the flat layout of
        // `newShape` — confirm.
        for (index in 0 until a.shape[axis]) {
            indexMap[axis] = index.r
            val t = a.slice(indexMap)
            var count = 0
            for (element in t) {
                retData[count] += element.toDouble()
                count++
            }
        }

        // Divide the accumulated sums by the number of slices to get the mean.
        return NDArray(
            retData, 0, newShape, dim = dimensionOf(newShape.size)
        ) / a.shape[axis].toDouble()
    }

    /** Dimension-specific entry point; delegates to [mean] with an axis. */
    override fun  meanD2(a: MultiArray, axis: Int): NDArray = mean(a, axis)

    /** Dimension-specific entry point; delegates to [mean] with an axis. */
    override fun  meanD3(a: MultiArray, axis: Int): NDArray = mean(a, axis)

    /** Dimension-specific entry point; delegates to [mean] with an axis. */
    override fun  meanD4(a: MultiArray, axis: Int): NDArray = mean(a, axis)

    /** Dimension-specific entry point; delegates to [mean] with an axis. */
    override fun  meanDN(a: MultiArray, axis: Int): NDArray = mean(a, axis)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy