All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.xiaomi.duckling.task.NaiveBayesConsole.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2020, Xiaomi and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.xiaomi.duckling.task

import java.nio.charset.StandardCharsets
import java.time.ZonedDateTime

import scala.util.control.Breaks.break

import org.fusesource.jansi.Ansi
import org.jline.reader.{EndOfFileException, LineReader, LineReaderBuilder}
import org.jline.reader.impl.completer._
import org.jline.reader.impl.DefaultParser
import org.jline.terminal.TerminalBuilder
import org.json4s.jackson.Serialization.write

import com.typesafe.scalalogging.LazyLogging

import com.xiaomi.duckling.Api
import com.xiaomi.duckling.Api.formatToken
import com.xiaomi.duckling.JsonSerde._
import com.xiaomi.duckling.Types._
import com.xiaomi.duckling.dimension.implicits._
import com.xiaomi.duckling.dimension.matcher._
import com.xiaomi.duckling.dimension.FullDimensions
import com.xiaomi.duckling.ranking.Ranker
import com.xiaomi.duckling.ranking.Testing.testContext
import com.xiaomi.duckling.types.Node

/**
 * sbt duckConsole的jline启动有点问题,可以使用sbt console
 *
 * @example 

sbt core/console *

> com.xiaomi.duckling.task.NaiveBayesConsole.run() *

> dimension time duration *

> 今天的天气 *

> option with-latent false *

... */ object NaiveBayesConsole extends LazyLogging { private val context = testContext.copy(referenceTime = ZonedDateTime.now(ZoneCN)) // 方便设置训练捂的断点 var debug = false def buildLineReader(): LineReader = { val terminal = TerminalBuilder .builder() .encoding(StandardCharsets.UTF_8) .name("DUCK") .build(); val dimension = new ArgumentCompleter( new StringsCompleter("dimension"), new StringsCompleter(FullDimensions.namedDimensions.keys.toList: _*) ) val options = new ArgumentCompleter( new StringsCompleter("option"), new StringsCompleter( "winner-only", "with-latent", "full", "inherit-duration-grain", "seasons", "sequence", "fuzzy-on", "before-end-of-interval", "recent-in-future", "always-in-future" ), NullCompleter.INSTANCE ) val completer = new AggregateCompleter(dimension, options) LineReaderBuilder .builder() .appName("duckling - console") .terminal(terminal) .parser(new DefaultParser()) .completer(completer) .build() } def getPrompt(): String = { Ansi .ansi() .eraseScreen() .fg(Ansi.Color.BLUE) .bold() .a("duckling") .fgBright(Ansi.Color.BLACK) .bold() .a(" > ") .reset() .toString } def setOptions(options: Options, line: String): (Boolean, Options) = { val cols = line.split("\\s+") if (cols.length >= 2) { if (line.startsWith("dimension ")) { val targets = FullDimensions.convert(cols.tail) (true, options.copy(targets = targets)) } else if (line.startsWith("option ")) { val opt = if (cols.length >= 3 && Set("true", "false") .contains(cols(2).toLowerCase)) { val value = cols(2).toBoolean val opt = cols(1) match { case "winner-only" => options.rankOptions.setWinnerOnly(value); options case "with-latent" => options.copy(withLatent = value) case "full" => options.copy(full = value) case "inherit-duration-grain" => options.timeOptions.setInheritGrainOfDuration(value); options case "seasons" => options.timeOptions.setParseFourSeasons(value); options case "sequence" => options.timeOptions.setSequence(value); options case "fuzzy-on" => options.timeOptions.setDurationFuzzyOn(value); options case "before-end-of-interval" => options.timeOptions.setBeforeEndOfInterval(value); options case "recent-in-future" => options.timeOptions.setRecentInFuture(value); options case "always-in-future" => options.timeOptions.setAlwaysInFuture(value); options case _ => options } opt } else { logger.info("{}: true/false expected", cols(1)) options } (true, opt) } else (false, options) } else (false, options) } def round(reader: LineReader, options: Options): (Boolean, Options) = { reader.getTerminal.flush() val line = try { reader.readLine(getPrompt()).trim } catch { case ex: EndOfFileException => logger.info("bye!") "" } val (isOpt, _options) = setOptions(options, line) if (!isOpt) { val answers = Api.analyze(line, context, _options) if (answers.isEmpty) println("empty results") else println(s"found ${answers.size} results") answers.foreach { answer: Answer => val entity = formatToken(line, withNode = true)(answer.token) val json = write(answer.token.value) answer.token.value.schema match { case Some(schema) => println("%.5f => %s\n%s".format(answer.score, schema, json)) case None => println("%.5f => %s".format(answer.score, json)) } ptree(line)(entity) } } (line == "", _options) } def run(): Unit = { var options: Options = Options(targets = Set(), withLatent = false) options.rankOptions.setRanker(Ranker.NaiveBayes) options.rankOptions.setWinnerOnly(true) options.rankOptions.setCombinationRank(false) options.timeOptions.setResetTimeOfDay(false) options.timeOptions.setRecentInFuture(true) options.numeralOptions.setAllowZeroLeadingDigits(false) options.numeralOptions.setCnSequenceAsNumber(false) // 初始化分类器 Api.analyze("今天123", context, options) debug = true val reader = buildLineReader() while (true) { round(reader, options) match { case (true, _) => break() } } } def main(args: Array[String]): Unit = { run() } def pnode(sentence: String, depth: Int)(node: Node): Unit = { val name = node.token.dim match { case RegexMatch => "regex" case VarcharMatch => "varchar" case PhraseMatch => "phrase" case MultiCharMatch => "multi-char" case LexiconMatch => "lexicon" case _ => node.rule.get } val body = sentence.substring(node.range.start, node.range.end) val out = "%s%s[\"%s\"]".format("-- " * depth, name, body) println(out) System.out.flush() node.children.foreach(pnode(sentence, depth + 1)) } def ptree(sentence: String)(entity: Entity): Unit = { entity.enode.foreach(pnode(sentence, 0)) } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy