All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.xiaomi.duckling.ranking.OverlapResolver.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2020, Xiaomi and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.xiaomi.duckling.ranking

import com.typesafe.scalalogging.LazyLogging
import scalaz.Scalaz._

import scala.collection.mutable

import com.xiaomi.duckling.Types.{Answer, Rule, Token}
import com.xiaomi.duckling.dimension.Dimension
import com.xiaomi.duckling.dimension.implicits._
import com.xiaomi.duckling.dimension.numeral.{Numeral, NumeralValue}
import com.xiaomi.duckling.dimension.time.enums.Grain
import com.xiaomi.duckling.dimension.time.{Time, TimeData}
import com.xiaomi.duckling.types.Node

object OverlapResolver extends LazyLogging {

  /**
    * 解决[九月三月] => 九月三/三月 有重叠的问题
    */
  def solveConflictOfSameDimension(winners: List[Answer],
                                   answers: Map[Dimension, List[Answer]]): List[Answer] = {
    if (winners.size >= 2) {
      winners
        .groupBy(_.dim)
        .flatMap {
          case (dim, values) =>
            if (values.size >= 2) {
              val buf = mutable.Buffer[Answer]()
              values.sliding(2).foreach {
                case p :: q :: _ =>
                  val a = cross(dim, p, buf.lastOption, q, answers)
                  // 五百六五百六 => 五百六/六五/五百六 => 五百六/五/五百六
                  // 五应该被去掉
                  val rangeCovered = buf.lastOption.exists(_.range.include(a.range)) ||
                    q.range.include(a.range)
                  if (!rangeCovered || p.range == q.range) buf += a
                case _ =>
              }
              // TODO 现在只做了sliding判断左右交叉时,左应该如何处理,最后一个应该与入列的结果的最后一个进行比较
              buf += cross(dim, values.last, buf.lastOption, None, answers)
              buf.toList
            } else values
        }
        .toList
        .sortBy(a => a.range)
    } else winners
  }

  def isCross(a: Answer, b: Answer): Boolean = {
    a.range.end > b.range.start
  }

  def cross(dim: Dimension,
            current: Answer,
            before: Option[Answer],
            after: Option[Answer],
            answers: Map[Dimension, List[Answer]]): Answer = {
    val p1 =
      if (before.nonEmpty && before.get.range != current.range && isCross(before.get, current)) {
        dim match {
          case Numeral => numeralCrossBackward(current, before.get, answers(dim))
          case _       => current
        }
      } else current
    if (after.nonEmpty && after.get.range != p1.range && isCross(p1, after.get)) {
      dim match {
        case Time    => timeCrossForward(p1, after.get, answers(dim))
        case Numeral => numeralCrossForward(p1, after.get, answers(dim))
        case _       => p1
      }
    } else p1
  }

  def rangeMaximum(answers: List[Answer], start: Int, end: Int): Option[Answer] = {
    answers
      .filter(t => t.range.start == start && t.range.end == end)
      .maximumBy(_.score)
  }

  def crossBackward(answers: List[Answer],
                    current: Answer,
                    before: Answer,
                    start: Int,
                    end: Int): Answer = {
    rangeMaximum(answers, start, end) match {
      case Some(max) =>
        logger.info(
          s"numeral cross: ${current.sentence} => ${before.text}/${current.text} => ${before.text}/${max.text}"
        )
        max
      case None =>
        logger.error(s"[${current.dim.name}] empty range result found for: ${current.sentence}")
        current
    }
  }

  def crossForward(answers: List[Answer],
                   current: Answer,
                   after: Answer,
                   start: Int,
                   end: Int): Answer = {
    rangeMaximum(answers, start, end) match {
      case Some(max) =>
        logger.info(
          s"numeral cross: ${current.sentence} => ${current.text}/${after.text} => ${max.text}/${after.text}"
        )
        max
      case None =>
        logger.error(s"[${current.dim.name}] empty range result found for: ${current.sentence}")
        current
    }
  }

  def isRule(a: Answer, rule: Rule): Boolean = {
    a.token.node.rule.contains(rule.name)
  }

  def timeCrossForward(a: Answer, b: Answer, answers: List[Answer]): Answer = {
    // 问题: 十月三月 => 十月三/三月
    namedMonthWithoutDayUnit(a.token.node, "date:  ") match {
      case Some(node) =>
        crossForward(answers, a, b, a.range.start, node.children(0).range.end)
      case _ => a
    }
  }

  /**
    * 求应用规则的最右Node
    */
  def namedMonthWithoutDayUnit(node: Node, rule: String): Option[Node] = {
    if (node.rule.contains(rule)) Some(node)
    else {
      node.children.flatMap(n => namedMonthWithoutDayUnit(n, rule)).maximumBy(_.range.end)
    }
  }

  def grain(a: Answer): Option[Grain] = {
    a.token.node.token match {
      case Token(Time, td: TimeData) => Some(td.timeGrain)
      case _                         => None
    }
  }

  def numeralCrossForward(current: Answer, after: Answer, answers: List[Answer]): Answer = {
    // 三十四百 => 三十四/四百 => 三十/四百
    val case1 = isRule(current, Numeral.ruleNumeralsIntersectConsecutiveUnit) &&
      isRule(after, Numeral.ruleMultiplys)
    // 三十三百四十 => 三十三/三百四十 => 三十/三百四十
    val case2 = isRule(current, Numeral.ruleNumeralsIntersectConsecutiveUnit) &&
      isRule(after, Numeral.ruleNumeralsIntersectConsecutiveUnit)
    // 五三百六 => 五三/三百六 => 五/三百六
    val case3 = isRule(current, Numeral.ruleCnSequence)
    // 十一十二 保持前面的不变
    if (case2 && numeralValue(current) < 100 && numeralValue(after) < 100) current
    else if (case1 || case2 || case3) {
      crossForward(answers, current, after, current.range.start, after.range.start)
    } else current
  }

  def numeralCrossBackward(current: Answer, before: Answer, answers: List[Answer]): Answer = {
    // 十二十三 => 十二/二十三 => 十二/十三
    val case1 = isRule(before, Numeral.ruleNumeralsIntersectConsecutiveUnit) &&
      isRule(current, Numeral.ruleNumeralsIntersectConsecutiveUnit)
    // 二十六点二十 => 二十六点二/二十 => 二十六点二/十
    val case2 = isRule(before, Numeral.ruleDecimalCharNumeral) &&
      isRule(current, Numeral.ruleMultiplys)
    // 二百五三
    val case3 = isRule(current, Numeral.ruleCnSequence)
    if (case1 || case2 || case3) {
      crossBackward(answers, current, before, before.range.end, current.range.end)
    } else current
  }

  def numeralValue(a: Answer): Double = {
    a.token.value.asInstanceOf[NumeralValue].n
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy