All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.xiaomi.duckling.ranking.NaiveBayesLearning.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2020, Xiaomi and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.xiaomi.duckling.ranking

import java.time.Duration
import java.util.{Map => JMap}

import scala.collection.mutable
import scala.collection.JavaConverters._

import org.json4s.jackson.Serialization.writePretty

import com.google.common.collect.Maps
import com.typesafe.scalalogging.LazyLogging

import com.xiaomi.duckling.Types._
import com.xiaomi.duckling.engine.Engine._
import com.xiaomi.duckling.ranking.Bayes.Classifier
import com.xiaomi.duckling.ranking.CorpusSets.{dimExamples, Corpus, Example}
import com.xiaomi.duckling.ranking.Types._
import com.xiaomi.duckling.types.Node
import com.xiaomi.duckling.{JsonSerde, Rules}
import com.xiaomi.duckling.JsonSerde._
import com.xiaomi.duckling.ranking.NaiveBayesRank.extractFeatures

object NaiveBayesLearning extends LazyLogging {
  type Classifiers = JMap[String, Classifier]
  type Dataset = mutable.Map[String, List[Datum]]

  /**
    * Probabilistic layer
    * -- Naive Bayes classifier with Laplace smoothing
    * -- Train one classifier per rule, based on the test corpus.
    *
    * @param rules
    * @param corpus
    * @return
    */
  def makeClassifiers(rules: List[Rule], corpus: Corpus): Classifiers = {
    logger.info("training classifiers from corpus")
    val start = System.currentTimeMillis()
    val classifiers = makeDataset(rules, corpus).map {
      case (rule, listDatum) =>
        logger.debug(s"training for $rule")
        (rule, Bayes.train(listDatum))
    }
    val end = System.currentTimeMillis()
    logger.info(s"training done: elapse ${Duration.ofMillis(end - start)}")
    if (rules.length > classifiers.size) {
      val duplicateRules = rules.groupBy(_.name).filter(_._2.length > 1).keys.toList
      if (duplicateRules.nonEmpty) {
        throw new RuntimeException(s"duplicate name of rules detected, fix it $duplicateRules")
      }
    }
    Maps.newHashMap(classifiers.toMap.asJava)
  }

  /**
    * Build a dataset (rule name -> datums)
    * Dataset = Map[Text, List[Datum] ]
    * Datum = (BagOfFeatures, Class)
    * BagOfFeatures = Map[Feature, Int]
    * Feature = Text
    *
    * @param rules
    * @param corpus
    * @return Map[Text, List[Datum] ]
    */
  def makeDataset(rules: List[Rule], corpus: Corpus): Dataset = {
    val (context, options, examples) = corpus
    val ds = examples.foldLeft(mutable.Map(): Dataset)(makeDataset1(rules, context, options))
    ds
  }

  /**
    * Augment the dataset with one example.
    * -- | Add all the nodes contributing to the resolutions of the input sentence.
    * -- | Classes:
    * -- | 1) True (node contributed to a token passing test predicate)
    * -- | 2) False (node didn't contribute to any passing tokens)
    *
    * @param rules
    * @param context
    * @param options
    * @param dataset Map[rule, List[(Map[Feature, Int], Class)] ]
    * @param example
    * @return
    */
  def makeDataset1(rules: List[Rule],
                   context: Context,
                   options: Options)(dataset: Dataset, example: Example): Dataset = {
    val (doc, rv) = example
    val tokens = parseAndResolve(rules, doc, context, options)
    val (ok, ko) = tokens.partition(JsonSerde.simpleCheck(doc, _, rv))

    val nodesOK: Set[Node] = nodes(ok)
    val nodesKO = nodes(ko).diff(nodesOK)

    val ds = updateDataset(klass = false, nodesKO, updateDataset(klass = true, nodesOK, dataset))
    ds
  }

  def subnodes(node: Node): Set[Node] = {
    if (node.children.isEmpty) Set()
    else node.children.map(subnodes).reduce(_ ++ _) + node
  }

  def nodes(rTokens: List[ResolvedToken]): Set[Node] = {
    val nodes = rTokens.map(r => subnodes(r.node))
    if (nodes.nonEmpty) nodes.reduce(_ ++ _)
    else Set()
  }

  def updateDataset(klass: Class, nodes: Set[Node], dataset: Dataset): Dataset = {
    nodes.foldLeft(dataset) {
      case (ds: Dataset, node @ Node(_, _, _, rule, _, _)) =>
        rule match {
          case Some(r) =>
            val features = extractFeatures(node)
            val extracted = List((features, klass))
            if (!ds.contains(r)) ds(r) = extracted
            else ds(r) = ds(r) ++ extracted
            ds
          case None => ds
        }
    }
  }

  def main(args: Array[String]): Unit = {
    if (args.length != 0) {
      val corpus = dimExamples.values.map(d => d.corpus).reduce((a, b) => (a._1, a._2, a._3 ++ b._3))
      val rules = Rules.rulesFor(null, dimExamples.keySet)
      val classifiers = makeClassifiers(rules, corpus)

      val file = args(0)
      val origin = writePretty(classifiers)
      KryoSerde.makeSerializedFile(classifiers, file)
      val out = KryoSerde.loadSerializedFile(file, classOf[Classifiers])

      val after = writePretty(out)
      assert(origin == after)
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy