All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.ml.stat.test.StreamingTestMethod.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tencent.angel.sona.ml.stat.test

import java.io.Serializable

import scala.language.implicitConversions
import scala.math.pow

import com.twitter.chill.MeatLocker
import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues
import org.apache.commons.math3.stat.inference.TTest

import org.apache.spark.internal.Logging
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.StatCounter

/**
 * Significance testing methods for [[StreamingTest]]. New 2-sample statistical significance tests
 * should extend [[StreamingTestMethod]] and introduce a new entry in
 * [[StreamingTestMethod.TEST_NAME_TO_OBJECT]]
 */
private[stat] sealed trait StreamingTestMethod extends Serializable {

  /** Display name of the test; the implementations in this file pass it to each result. */
  val methodName: String

  /** Statement of the null hypothesis; the implementations pass it to each result. */
  val nullHypothesis: String

  /** Stream whose elements pair the summary statistics of the two samples being compared. */
  protected type SummaryPairStream =
    DStream[(StatCounter, StatCounter)]

  /**
   * Perform streaming 2-sample statistical significance testing.
   *
   * @param sampleSummaries stream pairs of summary statistics for the 2 samples
   * @return stream of test results
   */
  def doTest(sampleSummaries: SummaryPairStream): DStream[StreamingTestResult]

  /**
   * Implicit adapter to convert between streaming summary statistics type and the type required by
   * the t-testing libraries.
   *
   * The variance handed over is the bias-corrected sample variance (`sampleVariance`, which
   * divides by n - 1). The t-test formulas estimate the standard error from sample variances;
   * `StatCounter.variance` is the population variance (divides by n) and would understate it.
   *
   * @param summaryStats running statistics of one sample
   * @return the same statistics as a Commons Math summary; the sum is reconstructed as
   *         `mean * count`
   */
  protected implicit def toApacheCommonsStats(
      summaryStats: StatCounter): StatisticalSummaryValues = {
    new StatisticalSummaryValues(
      summaryStats.mean,
      summaryStats.sampleVariance,
      summaryStats.count,
      summaryStats.max,
      summaryStats.min,
      summaryStats.mean * summaryStats.count
    )
  }
}

/**
 * Performs Welch's 2-sample t-test. The null hypothesis is that the two data sets have equal mean.
 * This test does not assume equal variance between the two samples and does not assume equal
 * sample size.
 *
 * @see <a href="https://en.wikipedia.org/wiki/Welch%27s_t-test">Welch's t-test (Wikipedia)</a>
 */
private[stat] object WelchTTest extends StreamingTestMethod with Logging {

  override final val methodName = "Welch's 2-sample t-test"
  override final val nullHypothesis = "Both groups have same mean"

  // The Commons Math tester is held inside a chill MeatLocker and unwrapped with `.get` on use.
  // NOTE(review): presumably this keeps the Serializable object shippable to executors - confirm.
  private final val tTester = MeatLocker(new TTest())

  /**
   * Runs Welch's t-test on every pair of sample summaries in the stream.
   *
   * @param data stream of (sample A, sample B) summary statistics
   * @return stream with one test result per input pair
   */
  override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] =
    data.map[StreamingTestResult]((test _).tupled)

  /**
   * Builds the result for one pair of samples: the value of `tTest`, the Welch degrees of
   * freedom and the value of `t`, followed by the method name and null hypothesis.
   * `statsA`/`statsB` are converted to Commons Math summaries by the inherited implicit adapter.
   */
  private def test(
      statsA: StatCounter,
      statsB: StatCounter): StreamingTestResult = {
    /**
     * Welch-Satterthwaite approximation of the degrees of freedom:
     * (v1/n1 + v2/n2)^2 / ((v1/n1)^2 / (n1 - 1) + (v2/n2)^2 / (n2 - 1)),
     * where v1 and v2 are the variances of the two samples.
     */
    def welchDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = {
      val n1 = sample1.getN
      val n2 = sample2.getN

      // getVariance already yields a variance (s^2); squaring it again, as the previous code
      // did, produced s^4 / n and therefore a wrong degrees-of-freedom value.
      val a = sample1.getVariance / n1
      val b = sample2.getVariance / n2

      pow(a + b, 2) / ((pow(a, 2) / (n1 - 1)) + (pow(b, 2) / (n2 - 1)))
    }

    new StreamingTestResult(
      tTester.get.tTest(statsA, statsB),
      welchDF(statsA, statsB),
      tTester.get.t(statsA, statsB),
      methodName,
      nullHypothesis
    )
  }
}

/**
 * Performs Student's 2-sample t-test. The null hypothesis is that the two data sets have equal
 * mean. This test assumes equal variance between the two samples and does not assume equal sample
 * size. For unequal variances, Welch's t-test should be used instead.
 *
 * @see <a href="https://en.wikipedia.org/wiki/Student%27s_t-test">Student's t-test (Wikipedia)</a>
 */
private[stat] object StudentTTest extends StreamingTestMethod with Logging {

  override final val methodName = "Student's 2-sample t-test"
  override final val nullHypothesis = "Both groups have same mean"

  // Commons Math tester kept inside a chill MeatLocker; `.get` unwraps it when a test runs.
  private final val tTester = MeatLocker(new TTest())

  /**
   * Applies the equal-variance (homoscedastic) t-test to each pair of summaries in the stream.
   *
   * @param data stream of (sample A, sample B) summary statistics
   * @return stream with one test result per input pair
   */
  override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] =
    data.map[StreamingTestResult] { case (groupA, groupB) => test(groupA, groupB) }

  /**
   * Computes the result for a single pair of samples: the value of `homoscedasticTTest`, the
   * pooled degrees of freedom (n1 + n2 - 2) and the value of `homoscedasticT`.
   */
  private def test(
      statsA: StatCounter,
      statsB: StatCounter): StreamingTestResult = {
    // Convert once through the inherited implicit adapter and reuse the summaries below.
    val summaryA: StatisticalSummaryValues = statsA
    val summaryB: StatisticalSummaryValues = statsB
    val tester = tTester.get

    val significance = tester.homoscedasticTTest(summaryA, summaryB)
    val degreesOfFreedom = (summaryA.getN + summaryB.getN - 2).toDouble
    val statistic = tester.homoscedasticT(summaryA, summaryB)

    new StreamingTestResult(
      significance,
      degreesOfFreedom,
      statistic,
      methodName,
      nullHypothesis
    )
  }
}

/**
 * Companion object holding supported [[StreamingTestMethod]] names and handles conversion between
 * strings used in [[StreamingTest]] configuration and actual method implementation.
 *
 * Currently supported tests: `welch`, `student`.
 */
private[stat] object StreamingTestMethod {
  // Registry of configuration names. Add an entry here whenever a new `StreamingTestMethod`
  // implementation is introduced.
  private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] =
    Map(
      "welch" -> WelchTTest,
      "student" -> StudentTTest
    )

  /**
   * Resolves a configured test name to its implementation.
   *
   * @param method name of the test; must be one of the keys of the registry above
   * @return the matching [[StreamingTestMethod]]
   * @throws IllegalArgumentException if `method` is not a registered name; the message lists
   *                                  the supported names
   */
  def getTestMethodFromName(method: String): StreamingTestMethod =
    TEST_NAME_TO_OBJECT.getOrElse(method, {
      val supported = TEST_NAME_TO_OBJECT.keys.mkString(", ")
      throw new IllegalArgumentException(
        "Unrecognized method name. Supported streaming test methods: " + supported)
    })
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy