All downloads are free. The search and download functionality uses the official Maven repository.

org.pmml4s.model.NaiveBayesModel.scala Maven / Gradle / Ivy

/*
 * Copyright (c) 2017-2024 AutoDeployAI
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.pmml4s.model

import org.pmml4s.common.MiningFunction.MiningFunction
import org.pmml4s.common._
import org.pmml4s.data.{DataVal, Series}
import org.pmml4s.metadata.{Field, MiningSchema, Output, Targets}
import org.pmml4s.transformations.{DerivedField, LocalTransformations}
import org.pmml4s.util.Utils

import scala.collection.immutable

/**
 * Naïve Bayes uses Bayes' Theorem, combined with a ("naive") presumption of conditional independence, to predict the
 * value of a target (output), from evidence given by one or more predictor (input) fields.
 *
 * Naïve Bayes models require the target field to be discretized so that a finite number of values are considered by
 * the model.
 *
 * Scoring is done in log space: the log of each class count is the starting score, the log of every input's
 * per-class probability is added to it, and the scores are finally exponentiated and normalized into probabilities.
 */
class NaiveBayesModel(
  var parent: Model,
  override val attributes: NaiveBayesAttributes,
  override val miningSchema: MiningSchema,
  val bayesInputs: BayesInputs,
  val bayesOutput: BayesOutput,
  override val output: Option[Output] = None,
  override val targets: Option[Targets] = None,
  override val localTransformations: Option[LocalTransformations] = None,
  override val modelStats: Option[ModelStats] = None,
  override val modelExplanation: Option[ModelExplanation] = None,
  override val modelVerification: Option[ModelVerification] = None,
  override val extensions: immutable.Seq[Extension] = immutable.Seq.empty)
  extends Model with HasWrappedNaiveBayesAttributes {

  // Let every input build its per-class lookup data once, at construction time.
  for (input <- bayesInputs.inputs) {
    input.init(targetField, threshold)
  }

  // Natural log of the total count recorded for each class, in the same order as `classes`.
  private val logClassCounts: Array[Double] = classes.map { cls =>
    math.log(bayesOutput.targetValueCounts.countOf(cls))
  }


  /** Model element type. */
  override def modelElement: ModelElement = ModelElement.NaiveBayesModel

  /**
   * Predicts values for a given data series.
   *
   * @param values the input series.
   * @return the null series when `prepare` flags the input as invalid, otherwise the result series built from the
   *         class probabilities.
   */
  override def predict(values: Series): Series = {
    val (series, returnInvalid) = prepare(values)
    if (returnInvalid) {
      nullSeries
    } else {
      // Accumulate in log space; start from a private copy so the cached class counts are never modified.
      val logScores = logClassCounts.clone()
      bayesInputs.inputs.foreach { input =>
        // An input may return an empty array (it does so for a missing field), which then adds nothing.
        val likelihoods = input.eval(series, threshold)
        for (k <- likelihoods.indices) {
          logScores(k) += math.log(likelihoods(k))
        }
      }

      // Log-sum-exp trick: shift by the largest score before exponentiating to avoid underflow and overflow.
      val peak = logScores.max
      val scaled = logScores.map(score => math.exp(score - peak))
      val total = scaled.sum
      val normalized = scaled.map(_ / total)

      val outputs = createOutputs().setProbabilities(classes.zip(normalized).toMap)
      outputs.evalPredictedValueByProbabilities()

      result(series, outputs)
    }
  }

  /** Creates an object of NaiveBayesOutputs that is for writing into an output series. */
  override def createOutputs(): NaiveBayesOutputs = new NaiveBayesOutputs
}

/**
 * Contains several BayesInput elements.
 *
 * @param inputs the BayesInput elements; NaiveBayesModel initializes each of them at construction and evaluates
 *               them one after another when predicting.
 */
class BayesInputs(val inputs: Array[BayesInput]) extends PmmlElement

/**
 * For a discrete field, each BayesInput contains the counts pairing the discrete values of that field with those of the
 * target field.
 * For a continuous field, the BayesInput element lists the distributions obtained for that field with each value of the
 * target field. BayesInput may also be used to define how continuous values are encoded as discrete bins.
 * (Discretization is achieved using DerivedField; only the Discretize mapping for DerivedField may be invoked here).
 *
 * Note that a BayesInput element encompasses either one TargetValueStats element or one or more PairCounts elements.
 * Element DerivedField can only be used in conjunction with PairCounts.
 *
 * `init` must be called before `eval`; it builds the per-class lookup data that `eval` reads.
 *
 * @param fieldName        the predictor field.
 * @param targetValueStats distributions per target value; when defined, the input is scored from these distributions.
 * @param pairCounts       joint counts of input values and target values, used when `targetValueStats` is absent.
 * @param derivedField     optional derived field whose value, rather than the raw field value, is looked up in
 *                         `pairCounts`.
 */
class BayesInput(val fieldName: Field,
                 val targetValueStats: Option[TargetValueStats],
                 val pairCounts: Array[PairCounts],
                 val derivedField: Option[DerivedField] = None) extends PmmlElement {

  // Filled by init(). `distributions` holds one distribution per target value when targetValueStats is defined;
  // otherwise `probabilities` is the table indexed as (input category)(target value).
  private var distributions: Option[Array[ContinuousDistribution]] = None
  private var probabilities: Array[Array[Double]] = Array.empty

  /**
   * Builds the per-class lookup data for this input.
   *
   * @param target    the target field; its valid values define the classes and their order.
   * @param threshold probability stored for every (input category, target value) cell whose count is not positive.
   * @throws NoSuchElementException if `targetValueStats` is defined but has no distribution for some target value.
   */
  def init(target: Field, threshold: Double): Unit = {
    val classes = target.validValues
    val numClasses = classes.length

    targetValueStats match {
      case Some(stats) =>
        // One distribution per class, in class order. Fail with a message that names the missing target value
        // instead of the bare "None.get" an unchecked Option.get would produce.
        distributions = Some(Array.tabulate[ContinuousDistribution](numClasses) { i =>
          stats.getDist(classes(i)).getOrElse(throw new NoSuchElementException(
            s"No TargetValueStat is defined for target value '${classes(i)}' in the BayesInput of field $fieldName"))
        })

      case None =>
        // Categories come from the derived field when one is present, otherwise from the raw field.
        val encoder = derivedField.getOrElse(fieldName)
        val numCategories = derivedField.map(_.numCategories).getOrElse(fieldName.numCategories)
        val totals = Array.ofDim[Double](numClasses)
        val table = Array.ofDim[Double](numCategories, numClasses)

        // First pass: store the raw joint counts and accumulate the per-class totals.
        for (pair <- pairCounts) {
          val row = encoder.encode(pair.value).toInt
          for (c <- pair.targetValueCounts.targetValueCounts) {
            val col = target.encode(c.value).toInt
            table(row)(col) = c.count
            totals(col) += c.count
          }
        }

        // Second pass: turn positive counts into conditional probabilities; everything else gets the threshold.
        for (row <- table; col <- 0 until numClasses) {
          row(col) = if (row(col) > 0) row(col) / totals(col) else threshold
        }

        probabilities = table
    }
  }

  /**
   * Evaluates this input for one record.
   *
   * @param series    the record to score.
   * @param threshold lower bound applied to the probabilities computed from distributions.
   * @return one probability per target value, or an empty array when this input contributes no evidence (the field is
   *         missing, or its encoded value has no row in the probability table).
   */
  def eval(series: Series, threshold: Double): Array[Double] = {
    if (fieldName.isMissing(series)) {
      Array.emptyDoubleArray
    } else {
      distributions match {
        case Some(dists) =>
          // The input value is the same for every class, so convert it once.
          val x = Utils.toDouble(fieldName.get(series))
          dists.map(d => Math.max(threshold, d.probability(x)))

        case None =>
          val code = derivedField.map(x => x.encode(x.eval(series))).getOrElse(fieldName.encode(series))
          // Only codes inside the table have a row. A negative or too large code would index out of bounds, and a NaN
          // code (which fails both comparisons) would become row 0 through toInt and silently score as the first
          // category. Such values are treated as contributing no evidence.
          // NOTE(review): whether encode can actually return such codes is not visible here - confirm.
          if (code >= 0 && code < probabilities.length) {
            probabilities(code.toInt)
          } else {
            Array.emptyDoubleArray
          }
      }
    }
  }
}


/**
 * Serves as the envelope for element TargetValueStat and allows looking up a distribution by target value.
 *
 * @param targetValueStats the enclosed TargetValueStat elements.
 */
class TargetValueStats(val targetValueStats: Array[TargetValueStat]) extends PmmlElement {
  // Lookup table from target value to its distribution, built once at construction.
  private val distributionByValue = targetValueStats.map(stat => stat.value -> stat.distribution).toMap

  /** Returns the distribution recorded for target value `x`, or None when there is no entry for it. */
  def getDist(x: Any): Option[ContinuousDistribution] = distributionByValue.get(x)
}

/**
 * PairCounts lists, for a field Ii's discrete value Iij, the TargetValueCounts that pair the value Iij with each value
 * of the target field.
 *
 * @param value             the discrete input value Iij these counts belong to.
 * @param targetValueCounts the counts of joint occurrences of `value` with each target value.
 */
class PairCounts(val value: DataVal, val targetValueCounts: TargetValueCounts) extends PmmlElement

/**
 * Used for a continuous input field Ii to define statistical measures associated with each value of the target field.
 * As defined in CONTINUOUS-DISTRIBUTION-TYPES, different distribution types can be used to represent such measures.
 * For Bayes models, these are restricted to Gaussian and Poisson distributions.
 *
 * @param value        the target value the distribution belongs to.
 * @param distribution the distribution of the input field for that target value.
 * @throws IllegalArgumentException if `distribution` is neither Gaussian nor Poisson.
 */
class TargetValueStat(val value: Any, val distribution: ContinuousDistribution) extends PmmlElement {
  // Reject unsupported distribution types at construction, with a message that states the actual restriction.
  require(
    distribution.distType == ContinuousDistributionType.GAUSSIAN ||
      distribution.distType == ContinuousDistributionType.POISSON,
    s"Bayes models only support Gaussian and Poisson distributions, but got ${distribution}")
}

/**
 * Contains the counts associated with the values of the target field.
 *
 * @param fieldName         the field this output refers to (presumably the target field; it is not read within this
 *                          file - TODO confirm).
 * @param targetValueCounts the total count of each target value; NaiveBayesModel takes the log of these counts as the
 *                          starting score of each class.
 */
class BayesOutput(val fieldName: Field, val targetValueCounts: TargetValueCounts) extends PmmlElement

/**
 * Lists the counts associated with each value of the target field. However, a TargetValueCount whose count is zero
 * may be omitted.
 * Within BayesOutput, TargetValueCounts lists the total count of occurrences of each target value.
 * Within PairCounts, TargetValueCounts lists, for each target value, the count of the joint occurrences of that target
 * value with a particular discrete input value.
 *
 * @param targetValueCounts the enclosed TargetValueCount elements.
 */
class TargetValueCounts(val targetValueCounts: Array[TargetValueCount]) extends PmmlElement {
  // Lookup table from target value to its count, built once at construction.
  private val countByValue = targetValueCounts.map(entry => entry.value -> entry.count).toMap

  /** Returns the count recorded for `value`, or 0.0 when the value has no entry. */
  def countOf(value: DataVal): Double = countByValue.get(value) match {
    case Some(count) => count
    case None => 0.0
  }
}

/**
 * A single target value together with its count.
 *
 * @param value the target value.
 * @param count the count recorded for that target value.
 */
class TargetValueCount(val value: DataVal, val count: Double) extends PmmlElement

/** Adds the Naive Bayes specific attribute `threshold` to the generic model attributes. */
trait HasNaiveBayesAttributes extends HasModelAttributes {

  /**
   * Specifies a default (usually very small) probability to use in lieu of P(Ij* | Tk) when count[Ij*Ti] is zero.
   * Similarly, since the probability of a continuous distribution can reach the value of 0 as the lower limit, the
   * same threshold parameter is used as the probability of the continuous variable when the calculated probability of
   * the distribution falls below that value.
   */
  def threshold: Double
}

/** Provides the Naive Bayes attributes by delegating to a wrapped NaiveBayesAttributes instance. */
trait HasWrappedNaiveBayesAttributes extends HasWrappedModelAttributes with HasNaiveBayesAttributes {

  /** The wrapped attributes, declared with the Naive Bayes specific type. */
  override def attributes: NaiveBayesAttributes

  /** The threshold of the wrapped attributes. */
  def threshold: Double = {
    attributes.threshold
  }
}

/**
 * Attributes of a Naive Bayes model: the generic model attributes plus `threshold`.
 *
 * @param threshold     probability used in lieu of a zero count, and as the lower limit of the probability computed
 *                      from a continuous distribution.
 * @param functionName  passed through to ModelAttributes.
 * @param modelName     passed through to ModelAttributes; defaults to None.
 * @param algorithmName passed through to ModelAttributes; defaults to None.
 * @param isScorable    passed through to ModelAttributes; defaults to true.
 */
class NaiveBayesAttributes(
                            override val threshold: Double,
                            override val functionName: MiningFunction,
                            override val modelName: Option[String] = None,
                            override val algorithmName: Option[String] = None,
                            override val isScorable: Boolean = true
                          ) extends ModelAttributes(functionName, modelName, algorithmName, isScorable)
  with HasNaiveBayesAttributes

/** Outputs created by NaiveBayesModel; a ClsOutputs that reports the Naive Bayes model element type. */
class NaiveBayesOutputs extends ClsOutputs {

  /** Model element type. */
  override def modelElement: ModelElement = {
    ModelElement.NaiveBayesModel
  }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy