
commonMain.org.jetbrains.kotlinx.multik.kotlin.KEStatistics.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of multik-kotlin-jvm Show documentation
Multidimensional array library for Kotlin.
/*
* Copyright 2020-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/
package org.jetbrains.kotlinx.multik.kotlin
import org.jetbrains.kotlinx.multik.api.Statistics
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.kotlin.math.KEMath
import org.jetbrains.kotlinx.multik.kotlin.math.remove
import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.div
import org.jetbrains.kotlinx.multik.ndarray.operations.first
import org.jetbrains.kotlinx.multik.ndarray.operations.sorted
import org.jetbrains.kotlinx.multik.ndarray.operations.times
/**
 * [Statistics] implementation for this module.
 *
 * Scalar reductions ([mean] without an axis, and [average]) accumulate in [Double], so integer-typed
 * arrays cannot overflow an intermediate sum. The axis-wise [mean] likewise accumulates into a
 * `Double` buffer.
 */
public object KEStatistics : Statistics {

    /**
     * Returns the median of all elements of [a], or `null` when [a] is empty.
     *
     * - With an odd number of elements, the result is the middle value of the sorted elements.
     * - With an even number, it is the arithmetic mean of the two middle values.
     *
     * NOTE(review): this relies on `sorted()` returning a new array whose `data` holds all elements
     * in ascending order, starting at index 0. Confirm this for multi-dimensional inputs.
     */
    override fun median(a: MultiArray): Double? {
        val size = a.size
        return when {
            // A single element needs no sorting.
            size == 1 -> a.first().toDouble()
            size > 1 -> {
                val sorted = a.sorted()
                val mid = size / 2
                if (size % 2 != 0) {
                    sorted.data[mid].toDouble()
                } else {
                    (sorted.data[mid - 1].toDouble() + sorted.data[mid].toDouble()) / 2
                }
            }
            else -> null
        }
    }

    /**
     * Returns the weighted average `sum(a[i] * weights[i]) / sum(weights[i])`.
     * When [weights] is `null`, returns [mean] of [a] instead.
     *
     * Products and both sums are accumulated in [Double], so integer arrays cannot overflow.
     * A total weight of zero follows `Double` division and yields `NaN` or an infinity.
     *
     * @throws IllegalArgumentException if [weights] does not have the same shape as [a].
     */
    override fun average(a: MultiArray, weights: MultiArray?): Double {
        if (weights == null) return mean(a)
        require(a.shape.contentEquals(weights.shape)) {
            "Shapes mismatch: [${a.shape.joinToString()}] and [${weights.shape.joinToString()}]"
        }
        var weightedSum = 0.0
        var weightSum = 0.0
        // Equal shapes mean both iterators visit corresponding positions in the same order.
        val weightIter = weights.iterator()
        for (value in a) {
            val w = weightIter.next().toDouble()
            weightedSum += value.toDouble() * w
            weightSum += w
        }
        return weightedSum / weightSum
    }

    /**
     * Returns the arithmetic mean of all elements of [a].
     *
     * The sum is accumulated in [Double] rather than in the element type, so integer arrays cannot
     * overflow. An empty array yields `NaN` (`0.0 / 0`).
     */
    override fun mean(a: MultiArray): Double {
        var total = 0.0
        for (element in a) total += element.toDouble()
        return total / a.size
    }

    /**
     * Returns the arithmetic mean of [a] along [axis].
     * The result has [a]'s shape with [axis] removed.
     *
     * For each position along [axis], the slice fixed at that position is added elementwise into a
     * `Double` buffer. The buffer is then divided in place by the length of [axis].
     *
     * @throws IllegalArgumentException if [a] is one-dimensional or [axis] is out of range.
     */
    override fun mean(a: MultiArray, axis: Int): NDArray {
        require(a.dim.d > 1) { "NDArray of dimension one, use the `mean` function without axis." }
        require(axis in 0 until a.dim.d) { "axis $axis is out of bounds for this ndarray of dimension ${a.dim.d}." }
        val newShape = a.shape.remove(axis)
        val outSize = newShape.fold(1, Int::times)
        val retData = initMemoryView(outSize, DataType.DoubleDataType)
        // Take every index on the non-reduced axes; the reduced axis is fixed per iteration below.
        val indexMap: MutableMap = mutableMapOf()
        for (i in a.shape.indices) {
            if (i == axis) continue
            indexMap[i] = 0.r..a.shape[i]
        }
        for (index in 0 until a.shape[axis]) {
            indexMap[axis] = index.r
            val t = a.slice(indexMap)
            // NOTE(review): assumes the slice iterates in the same row-major order as the flat retData layout.
            var count = 0
            for (element in t) {
                retData[count] += element.toDouble()
                count++
            }
        }
        // Divide in place instead of allocating a second array through the `/` operator.
        val axisLength = a.shape[axis].toDouble()
        for (i in 0 until outSize) {
            retData[i] /= axisLength
        }
        return NDArray(retData, 0, newShape, dim = dimensionOf(newShape.size))
    }

    /** Dimension-specific entry point; delegates to [mean] with an axis. */
    override fun meanD2(a: MultiArray, axis: Int): NDArray = mean(a, axis)

    /** Dimension-specific entry point; delegates to [mean] with an axis. */
    override fun meanD3(a: MultiArray, axis: Int): NDArray = mean(a, axis)

    /** Dimension-specific entry point; delegates to [mean] with an axis. */
    override fun meanD4(a: MultiArray, axis: Int): NDArray = mean(a, axis)

    /** Dimension-specific entry point; delegates to [mean] with an axis. */
    override fun meanDN(a: MultiArray, axis: Int): NDArray = mean(a, axis)
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy