All downloads are free. Search and download functionality uses the official Maven repository.

commonMain.jetbrains.datalore.plot.base.stat.Density2dStat.kt Maven / Gradle / Ivy

There is a newer version: 4.5.3-alpha1
Show newest version
/*
 * Copyright (c) 2019. JetBrains s.r.o.
 * Use of this source code is governed by the MIT license that can be found in the LICENSE file.
 */

package jetbrains.datalore.plot.base.stat

import jetbrains.datalore.plot.base.Aes
import jetbrains.datalore.plot.base.DataFrame
import jetbrains.datalore.plot.base.StatContext
import jetbrains.datalore.plot.base.data.TransformVar
import jetbrains.datalore.plot.base.stat.math3.BlockRealMatrix
import jetbrains.datalore.plot.common.data.SeriesUtil

/**
 * Two-dimensional kernel density estimate of the `x` / `y` aesthetics.
 *
 * The density is evaluated on a regular `nX` by `nY` grid of step values spanning the overall x and y
 * ranges reported by the [StatContext]. Depending on `isContour` the stat emits either:
 *  - the grid itself, as [Stats.X], [Stats.Y] and [Stats.DENSITY] columns, or
 *  - contour paths of the density surface, built from levels computed by [ContourStatUtil].
 *
 * @param bandWidthX kernel bandwidth along x; when `null`, [bandWidthMethod] is used instead.
 * @param bandWidthY kernel bandwidth along y; when `null`, [bandWidthMethod] is used instead.
 * @param bandWidthMethod bandwidth selection method, used when an explicit bandwidth is not set.
 * @param adjust multiplier passed to the raw kernel-matrix computation together with the bandwidth.
 * @param kernel kernel function family.
 * @param nX number of grid points along x.
 * @param nY number of grid points along y.
 * @param isContour `true` to output contour paths instead of the density grid.
 * @param binCount passed to the superclass; presumably drives the number of contour levels (verify in
 *   [AbstractDensity2dStat]).
 * @param binWidth passed to the superclass; presumably drives the contour level spacing (verify in
 *   [AbstractDensity2dStat]).
 */
class Density2dStat(
    bandWidthX: Double?,
    bandWidthY: Double?,
    bandWidthMethod: DensityStat.BandWidthMethod,  // Used if `bandWidthX` / `bandWidthY` is not set.
    adjust: Double,
    kernel: DensityStat.Kernel,
    nX: Int,
    nY: Int,
    isContour: Boolean,
    binCount: Int,
    binWidth: Double
) : AbstractDensity2dStat(
    bandWidthX = bandWidthX,
    bandWidthY = bandWidthY,
    bandWidthMethod = bandWidthMethod,
    adjust = adjust,
    kernel = kernel,
    nX = nX,
    nY = nY,
    isContour = isContour,
    binCount = binCount,
    binWidth = binWidth
) {

    /**
     * Computes the 2D density for [data].
     *
     * @return an empty-stat frame when `x` or `y` values are missing; an empty frame when there are no
     *   observations or no contour levels can be computed; otherwise the contour-path frame or the
     *   density-grid frame (see the class documentation).
     * @throws IllegalArgumentException if the `x` and `y` series have different lengths.
     * @throws IllegalStateException if the stat context provides no overall x or y range.
     */
    override fun apply(data: DataFrame, statCtx: StatContext, messageConsumer: (s: String) -> Unit): DataFrame {
        if (!hasRequiredValues(data, Aes.X, Aes.Y)) {
            return withEmptyStatValues()
        }

        val xVector = data.getNumeric(TransformVar.X)
        val yVector = data.getNumeric(TransformVar.Y)

        // No observations: nothing to estimate.
        if (xVector.isEmpty()) {
            return DataFrame.Builder.emptyFrame()
        }

        // Every observation must have both coordinates.
        require(xVector.size == yVector.size) {
            "len(x)= " + xVector.size + " and len(y)= " + yVector.size + " doesn't match!"
        }

        val xRange = checkNotNull(statCtx.overallXRange()) { "Density2dStat: overall X range is not available." }
        val yRange = checkNotNull(statCtx.overallYRange()) { "Density2dStat: overall Y range is not available." }

        val xBandWidth = getBandWidthX(xVector)
        val yBandWidth = getBandWidthY(yVector)

        val stepsX = DensityStatUtil.createStepValues(xRange, nX)
        val stepsY = DensityStatUtil.createStepValues(yRange, nY)

        // weight aesthetics
        val groupWeight = BinStatUtil.weightVector(xVector.size, data)
        // Normalization divisor; loop-invariant, so compute it once rather than per grid cell.
        val totalWeight = SeriesUtil.sum(groupWeight)

        val matrixX = BlockRealMatrix(
            DensityStatUtil.createRawMatrix(
                xVector,
                stepsX,
                kernelFun,
                xBandWidth,
                adjust,
                groupWeight
            )
        )
        val matrixY = BlockRealMatrix(
            DensityStatUtil.createRawMatrix(
                yVector,
                stepsY,
                kernelFun,
                yBandWidth,
                adjust,
                groupWeight
            )
        )
        // size: nY * nX (row index walks y steps, column index walks x steps)
        val matrixFinal = matrixY.multiply(matrixX.transpose())

        val cellCount = nX * nY
        val statX = ArrayList<Double>(cellCount)
        val statY = ArrayList<Double>(cellCount)
        val statDensity = ArrayList<Double>(cellCount)

        // Flatten the grid row-major (y outer, x inner); contour computation below relies on this order.
        for (row in 0 until nY) {
            for (col in 0 until nX) {
                statX.add(stepsX[col])
                statY.add(stepsY[row])
                statDensity.add(matrixFinal.getEntry(row, col) / totalWeight)
            }
        }

        if (isContour) {
            val zRange = SeriesUtil.range(statDensity)
            val levels = ContourStatUtil.computeLevels(zRange, binOptions)
                ?: return DataFrame.Builder.emptyFrame()

            val pathListByLevel = ContourStatUtil.computeContours(
                xRange,
                yRange,
                nX,
                nY,
                statDensity,
                levels
            )

            return Contour.getPathDataFrame(levels, pathListByLevel)
        }

        return DataFrame.Builder()
            .putNumeric(Stats.X, statX)
            .putNumeric(Stats.Y, statY)
            .putNumeric(Stats.DENSITY, statDensity)
            .build()
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy