com.mayabot.nlp.perceptron.FeatureSet.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mynlp Show documentation
Show all versions of mynlp Show documentation
Maya Nlp subproject :mynlp
/*
* Copyright 2018 mayabot.com authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mayabot.nlp.perceptron
import com.mayabot.nlp.collection.dat.DoubleArrayTrie
import java.io.DataInputStream
import java.io.DataOutputStream
import java.io.File
import java.io.InputStream
/**
* 特征集。主要功能是查询feature对应的Id.
* 逻辑上类似
* ["f1","inputSeq2FeatureSet","f3]
* 那么 f1对应的Id是0,f2对应的Id=1
* 每个特征都对应数组下标。
* FeatureSet中每个特征都有一个唯一的Int编号。
* 所谓特征就是字符串构成的特征,具体特征是什么有应用层自己去定义。
* 在mynlp中为了计算性能,需要在第一遍扫描语料时,构建FeatureSet是要求对所有的feature进行排序,
* 然后采用压缩性能更好的DAT结构来存储这些信息。
* @author jimichan
*/
class FeatureSet(
private val dat: DoubleArrayTrie,
/**
* keys存在的话有个好处,模型可以被压缩
*/
var keys: List?
) {
//预留的多一点,就不会有hash冲突
var extMap = HashMap(1000)
var nextId = dat.size()
// 要搞两种模式,一个是训练的时候,就不动了一个是工作模式
/**
* 返回一个特征对应的ID
* @return -1表示特征不存在
*/
fun featureId(feature: String): Int {
val id = dat.indexOf(feature)
if (id >= 0) {
return id
}
return if (extMap.isNotEmpty()) {
extMap[feature] ?: -1
} else {
id
}
}
fun featureId(feature: CharSequence): Int {
val id = dat.indexOf(feature)
if (id >= 0) {
return id
}
return if (extMap.isNotEmpty()) {
extMap[feature.toString()] ?: -1
} else {
id
}
}
fun newExtId(feature: String): Int {
if (dat.indexOf(feature) < 0 && !extMap.containsKey(feature)) {
extMap[feature] = nextId++
return extMap[feature]!!
}
return featureId(feature)
}
/**
* 特征大小
* @return 特征集合的大小
*/
fun size() = dat.size() + extMap.size
/**
* 保存到文件
*/
fun save(datFile: File, textFile: File?) {
datFile.outputStream().buffered().use {
val out = DataOutputStream(it)
dat.write(out)
out.flush()
}
val keys = this.keys
if (keys != null && textFile != null) {
textFile.bufferedWriter().use { writer ->
keys.forEach { line ->
writer.write(line + "\n")
}
}
}
}
companion object {
/**
* 只读取DAT文件
*/
@JvmStatic
fun read(datInput: InputStream): FeatureSet {
return datInput.use {
val datArray = DoubleArrayTrie(DataInputStream(it))
FeatureSet(datArray, null)
}
}
@JvmStatic
fun read(datInput: InputStream, textInput: InputStream): FeatureSet {
return datInput.use {
val datArray = DoubleArrayTrie(DataInputStream(it))
val lines = textInput.use { x ->
x.bufferedReader().readLines()
}
FeatureSet(datArray, lines)
}
}
@JvmStatic
fun readFromText(textInput: InputStream): FeatureSet {
return textInput.use {
val reader = textInput.bufferedReader()
val list = reader.readLines()
val datArray = DoubleArrayTrie(list)
FeatureSet(datArray, list)
}
}
@JvmStatic
fun readFromTextButNotSave(textInput: InputStream): FeatureSet {
return textInput.use {
val reader = textInput.bufferedReader()
val list = reader.readLines()
val datArray = DoubleArrayTrie(list)
FeatureSet(datArray, null)
}
}
}
}
/**
* DAT的特征集合构建器
*/
class DATFeatureSetBuilder(labelCount: Int) {
private val keys = HashSet()
init {
// Hanlp需要从 0=< <= labelCount 上站位 占用labelCount+1个位置
// BL= 要保证这个排在前面
for (i in 0..labelCount) {
keys.add("\u0000\u0001BL=$i")
}
}
fun put(feature: String) {
keys.add(feature)
}
fun build(): FeatureSet {
val list = keys.sorted()
return FeatureSet(DoubleArrayTrie(list), list)
}
}