All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.salmonbrain.ruleofthumb.MannWhitneyTest.scala Maven / Gradle / Ivy

package ai.salmonbrain.ruleofthumb

import org.apache.commons.math3.distribution.BinomialDistribution
import org.apache.commons.math3.stat.descriptive.moment.Variance
import org.apache.commons.math3.stat.descriptive.rank.Median
import org.apache.commons.math3.stat.inference.MannWhitneyUTest

/**
 * Median-based A/B comparison built on the Mann–Whitney U test from commons-math3.
 *
 * Besides the U statistic and p-value, it reports both group medians and their variances,
 * plus a confidence interval (via `CI`) and a sample-size estimate.
 */
object MannWhitneyTest extends BaseStatTest {
  // NOTE(review): these instances are shared by every call; commons-math does not document
  // Median/Variance as thread-safe — confirm this object is never used concurrently.
  val median: Median = new Median()
  val variance: Variance = new Variance()

  /**
   * Runs a Mann–Whitney U test of `control` vs `treatment` and summarises the medians.
   *
   * @param control   observations of the control group
   * @param treatment observations of the treatment group
   * @param alpha     significance level; must be < 1 (checked with `assert`)
   * @param beta      type-II error rate used for the sample-size estimate; must be < 1
   * @param useLinearApproximationForVariance
   *                  when true, each group's variance is the Price–Bonett estimate of the
   *                  variance of its sample median ([[medianVariance]]); otherwise it is the
   *                  ordinary sample variance of the observations (commons-math `Variance`)
   * @return a `StatResult`. If either variance is below `EPS`, the statistic, p-value and CI
   *         bounds are `NaN`, the sample size is -1, and `isZeroVariance = true`.
   */
  def mannWhitneyTest(
      control: Array[Double],
      treatment: Array[Double],
      alpha: Double,
      beta: Double,
      useLinearApproximationForVariance: Boolean
  ): StatResult = {
    assert(alpha < 1 && beta < 1)
    val controlMedian = median.evaluate(control)
    val treatmentMedian = median.evaluate(treatment)

    val (treatmentMedianVariance, controlMedianVariance) =
      if (useLinearApproximationForVariance)
        (medianVariance(treatment), medianVariance(control))
      else (variance.evaluate(treatment), variance.evaluate(control))

    if (treatmentMedianVariance < EPS || controlMedianVariance < EPS) {
      // Degenerate (constant) sample: no meaningful test or interval can be produced.
      StatResult(
        Double.NaN,
        Double.NaN,
        -1,
        controlMedian,
        treatmentMedian,
        controlMedianVariance,
        treatmentMedianVariance,
        Double.NaN,
        Double.NaN,
        CentralTendency.MEDIAN.toString,
        isZeroVariance = true
      )
    } else {
      val mannWhitneyUTest = new MannWhitneyUTest()
      val uStatistic = mannWhitneyUTest.mannWhitneyU(control, treatment)
      val pValue = mannWhitneyUTest.mannWhitneyUTest(control, treatment)

      // NOTE(review): with useLinearApproximationForVariance = true the variances come from
      // medianVariance, which already estimates the variance of the sample *median*; dividing
      // by the sample sizes again here may understate `std`. Confirm against how CI uses it.
      val std = math.sqrt(
        controlMedianVariance / control.length + treatmentMedianVariance / treatment.length
      )
      val size = math.max(control.length, treatment.length)

      val ci = CI(
        controlMedian,
        controlMedianVariance,
        treatmentMedian,
        treatmentMedianVariance,
        std,
        normalDistribution.inverseCumulativeProbability(alpha / 2),
        normalDistribution.inverseCumulativeProbability(1 - alpha / 2),
        size
      )

      val sampleSize = sampleSizeEstimation(
        alpha,
        beta,
        treatmentMedian,
        controlMedian,
        treatment.length,
        control.length
      )

      StatResult(
        uStatistic,
        pValue,
        sampleSize,
        controlMedian,
        treatmentMedian,
        controlMedianVariance,
        treatmentMedianVariance,
        ci.lowerPercent,
        ci.upperPercent,
        CentralTendency.MEDIAN.toString,
        isZeroVariance = false
      )
    }
  }

  /**
   * Estimates the variance of the sample median of `values` (Price & Bonett):
   *
   * {{{
   *   var(median) = ((Y(n - c + 1) - Y(c)) / (2 * z))^2
   * }}}
   *
   * `Y(k)` is the k-th order statistic (1-based), `c` is [[orderStatisticRank]] and `z` is
   * [[zeta]].
   *
   * Small samples follow the commons-math `Variance` conventions:
   *   - an empty array yields `Double.NaN`;
   *   - a single observation yields `0.0`.
   *
   * For 2 or 3 observations, `c` is clamped to 1, i.e. the sample minimum and maximum are used.
   * The unclamped formula would index outside the array for those sizes.
   *
   * https://www.researchgate.net/publication/11148358_Statistical_inference_for_a_linear_function_of_medians_Confidence_intervals_hypothesis_testing_and_sample_size_requirements
   */
  def medianVariance(values: Array[Double]): Double =
    values.length match {
      case 0 => Double.NaN
      // One point has no spread; the formula would otherwise compute 0 / 0 since zeta(1) == 0.
      case 1 => 0.0
      case n =>
        val sorted = values.sorted
        val c = orderStatisticRank(n)
        val upper = sorted(n - c) // Y(n - c + 1) in 1-based notation
        val lower = sorted(c - 1) // Y(c)
        square(upper - lower) / (4 * square(zeta(n)))
    }

  /**
   * 1-based order-statistic rank `c = round((n + 1) / 2 - sqrt(n))`, clamped to at least 1.
   * The clamp only matters for n < 4; the unclamped value is >= 1 from n = 4 upward.
   */
  private def orderStatisticRank(length: Int): Int =
    math.max(1, math.round((length + 1.0) / 2 - math.sqrt(length)).toInt)

  /** Two-sided tail probability `2 * P(B <= c - 1)` with `B ~ Binomial(length, 0.5)`. */
  private def aBinomial(length: Int): Double =
    new BinomialDistribution(length, 0.5).cumulativeProbability(orderStatisticRank(length) - 1) * 2

  /**
   * Quantile of `normalDistribution` at `1 - aBinomial(length) / 2`.
   * `normalDistribution` is presumably the standard normal from BaseStatTest (not visible here).
   */
  private def zeta(length: Int): Double =
    normalDistribution.inverseCumulativeProbability(1 - aBinomial(length) / 2)

  /**
   * Sample-size estimate for detecting the observed median difference.
   *
   * {{{
   *   ceil((z(1 - alpha/2) + z(1 - beta))^2) * (treatmentSize + controlSize)
   *     / (medianTreatment - medianControl)^2
   * }}}
   *
   * The result is truncated to `Long`. When the medians are equal the denominator is 0, so a
   * positive numerator yields +Infinity, which `.toLong` saturates to `Long.MaxValue`.
   *
   * https://www.researchgate.net/publication/11148358_Statistical_inference_for_a_linear_function_of_medians_Confidence_intervals_hypothesis_testing_and_sample_size_requirements
   */
  def sampleSizeEstimation(
      alpha: Double,
      beta: Double,
      medianTreatment: Double,
      medianControl: Double,
      treatmentSize: Int,
      controlSize: Int
  ): Long = {
    // NOTE(review): the ceiling is applied to this numerator only, and the final quotient is
    // truncated by `.toLong`. Sample-size formulas usually round the final result up — confirm.
    val nominator = math.ceil(
      square(
        normalDistribution.inverseCumulativeProbability(1 - alpha / 2) + normalDistribution
          .inverseCumulativeProbability(1 - beta)
      )
    )

    val denominator = square(medianTreatment - medianControl) / (treatmentSize + controlSize)

    (nominator / denominator).toLong
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy