All downloads are free. Search and download functionality uses the official Maven repository.

com.xiaomi.duckling.DuckParser.scala Maven / Gradle / Ivy

The newest version!
package com.xiaomi.duckling
import java.util.{Set => JSet}
import java.util

import scala.collection.JavaConverters._

import com.typesafe.scalalogging.LazyLogging

import com.xiaomi.duckling.dimension.{Dimension, FullDimensions}
import com.xiaomi.duckling.Types._
import com.xiaomi.duckling.constraints.TokenSpan
import com.xiaomi.duckling.engine.Engine.{parse, resolveNode}
import com.xiaomi.duckling.ranking.{NaiveBayesRank, Ranker}
import com.xiaomi.duckling.ranking.Rank.{rank, resolveAheadByRange}
import com.xiaomi.duckling.types.LanguageInfo

/**
 * Parser that extracts resolved tokens (answers) of the configured dimensions from text.
 *
 * @param dimensions    names of the dimensions this parser may produce; converted with
 *                      `FullDimensions.convert`
 * @param modelResource resource name handed to [[NaiveBayesRank]] (presumably the serialized
 *                      ranking model — TODO confirm)
 */
class DuckParser(dimensions: Set[String], modelResource: String = "naive_bayes.kryo") extends LazyLogging {

  /** Java-friendly constructor taking a `java.util.Set` of dimension names. */
  def this(dimensions: JSet[String]) = {
    this(dimensions.asScala.toSet)
  }

  /** Dimensions enabled on this parser; requested targets outside this set are pruned. */
  private val dims: Set[Dimension] = FullDimensions.convert(dimensions)

  /** Naive Bayes scorer, used when `options.rankOptions.ranker` is `Some(Ranker.NaiveBayes)`. */
  val ranker: NaiveBayesRank = new NaiveBayesRank(modelResource)

  /**
   * Returns a curated list of the resolved tokens found in the sentence.
   * When `options.targets` is non-empty, only tokens of those dimensions are returned.
   * The targets are restricted to the dimensions enabled on this parser.
   *
   * @param lang    pre-analyzed input; `lang.sentence` is the raw text
   * @param context parsing context; its locale selects the rule set
   * @param options parsing options (targets, full match, latent, ranking settings)
   * @return the resolved answers, ranked when Naive Bayes ranking is requested
   */
  def analyze(lang: LanguageInfo, context: Context, options: Options): List[Answer] = {
    val input = lang.sentence

    // Requested targets that are not enabled on this parser are pruned, with a warning.
    val requested = options.targets.intersect(dims)
    if (requested.size != options.targets.size) {
      logger.warn(s"targets: ${options.targets} prune to ${requested}")
    }
    // Extra dimensions parsed only to detect overlap conflicts with the requested ones.
    // `nonOverlap` removes their answers again before ranking.
    val helperDims = requested.flatMap(_.nonOverlapDims).diff(requested)
    val targets = requested ++ helperDims

    val rules = Rules.rulesFor(context.locale, targets)
    val nodes = parse(rules, lang, options)
    val doc = Document.fromLang(lang)

    // Drop nodes of non-target dimensions, which rule dependencies may produce.
    // Helper dimensions must survive this step: `nonOverlap` needs them to find conflicts.
    val ofTargets =
      if (targets.isEmpty) nodes
      else nodes.filter(t => targets.contains(t.token.dim))

    // With `options.full`, keep only nodes that span the whole input.
    val fullMatchFiltered =
      if (options.full) ofTargets.filter(_.range match {
        case Range(0, l) => l == input.length
        case _ => false
      })
      else ofTargets

    val resolvedTokens =
      // Should be disabled when overlap combinations matter: range containment does not
      // guarantee the best combination.
      // NOTE(review): with this enabled, helper-dimension nodes also take part in the
      // range-based pruning — confirm that is acceptable.
      if (options.rankOptions.rangeRankAhead) {
        resolveAheadByRange(doc, context, options, fullMatchFiltered)
      } else {
        // Currently the only post-resolution check; extract a helper if more are added.
        fullMatchFiltered.flatMap(resolveNode(doc, context, options)).filter(TokenSpan.isValid(lang, _))
      }

    // Unless latent results were requested, keep only non-latent tokens.
    val latentFiltered =
      if (options.withLatent) resolvedTokens
      else resolvedTokens.filterNot(_.isLatent)

    val answers = latentFiltered.map(Answer(input, _))

    // Drop answers that cross a conflicting dimension, then drop helper-dimension answers.
    val overlapFiltered = nonOverlap(answers.toIndexedSeq, helperDims)

    // Rank: among results with the same range/dim, keep the most probable one.
    options.rankOptions.ranker match {
      case Some(Ranker.NaiveBayes) =>
        rank(ranker.score, targets, overlapFiltered, options.rankOptions)
      case _ => overlapFiltered
    }
  }

  /**
   * Convenience overload for Scala callers: builds the [[LanguageInfo]] from raw text.
   *
   * @param input   raw text to parse
   * @param context parsing context
   * @param options parsing options; `options.enableAnalyzer` is forwarded to
   *                `LanguageInfo.fromText`
   * @return the resolved answers
   */
  def analyze(input: String, context: Context, options: Options): List[Answer] = {
    analyze(LanguageInfo.fromText(input, options.enableAnalyzer), context, options)
  }

  /**
   * Java-friendly variant of `analyze(input, context, options)`.
   * It returns the answers as a `java.util.List` view.
   */
  def analyzeJ(input: String, context: Context, options: Options): util.List[Answer] = {
    analyze(input, context, options).asJava
  }

  /**
   * Removes answers whose span crosses an answer of one of their non-overlap dimensions.
   * For example, 第五十五号 => [五十五, 十五号]: the two spans cross.
   * An answer `x` is dropped when another answer `y` satisfies all of the following:
   *   - `y.dim` is in `x.dim.nonOverlapDims`;
   *   - the spans strictly cross (neither contains the other);
   *   - `x.dim.overlap(y.token.node.token)` holds.
   * Afterwards, answers whose dimension is in `dims` are removed as well.
   *
   * @param a    candidate answers; they are accessed by index, so an IndexedSeq is preferable
   * @param dims dimensions whose answers are removed from the result
   * @return the surviving answers, in their original order
   */
  def nonOverlap(a: Seq[Answer], dims: Set[Dimension]): List[Answer] = {
    // True only for a strict crossing: each range starts before the other and ends inside it.
    // Identical ranges and nested ranges do not count.
    def overlap(r1: Types.Range, r2: Types.Range): Boolean = {
      r1.start < r2.start && r1.end < r2.end && r1.end > r2.start ||
        r2.start < r1.start && r2.end < r1.end && r2.end > r1.start
    }

    val survivors = (for (i <- a.indices) yield {
      val cur = a(i)
      val conflict = a.indices.find { j =>
        cur.dim.nonOverlapDims.contains(a(j).dim) &&
          overlap(cur.token.range, a(j).token.range) &&
          cur.dim.overlap(a(j).token.node.token)
      }

      conflict match {
        case Some(j) =>
          logger.debug(s"overlap found: ${cur.text} & ${a(j).text} => drop ${cur.text}")
          None
        case None => Some(cur)
      }
    }).flatten.toList

    survivors.filter(answer => !dims.contains(answer.dim))
  }
}

/** Companion object of the `DuckParser` class; it currently declares no members. */
object DuckParser




© 2015 - 2024 Weber Informatics LLC | Privacy Policy