com.johnsnowlabs.ml.crf.DatasetEncoder.scala Maven / Gradle / Ivy
/*
* Copyright 2017-2022 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.ml.crf
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
/** Incrementally encodes labelled CRF tokens into sparse feature vectors while collecting the
  * dictionaries and statistics needed to build a [[DatasetMetadata]].
  *
  * Each call to [[getFeatures]] registers any unseen label, attribute, attribute feature
  * (attribute/label pair) and transition. Ids are assigned densely, in order of first
  * appearance. The call also updates the corresponding counters. [[getMetadata]] then
  * snapshots the collected state.
  *
  * NOTE(review): all state lives in unsynchronized mutable collections, so presumably one
  * encoder is meant to be driven from a single thread.
  *
  * @param startLabel pseudo-label pre-registered with id 0. It is the default previous label
  *                   in [[getFeatures]] (presumably for the first token of a sequence).
  */
class DatasetEncoder(val startLabel: String = "@#Start") {

  /** Attribute name -> attribute id. */
  val attr2Id: mutable.Map[String, Int] = mutable.Map[String, Int]()

  /** Label name -> label id; `startLabel` is registered up front with id 0. */
  val label2Id: mutable.Map[String, Int] = mutable.Map[String, Int](startLabel -> 0)

  /** Attribute id -> [[Attr]]; the buffer index equals the attribute id. */
  val attributes: mutable.ArrayBuffer[Attr] = mutable.ArrayBuffer[Attr]()

  /** (attribute id, label id) -> attribute feature id. */
  val attrFeatures2Id: mutable.Map[(Int, Int), Int] = mutable.Map[(Int, Int), Int]()

  /** Feature id -> [[AttrFeature]]; the buffer index equals the feature id. */
  val attrFeatures: ArrayBuffer[AttrFeature] = ArrayBuffer[AttrFeature]()

  /** Feature id -> number of times the feature has been observed. */
  val attrFeaturesFreq: ArrayBuffer[Int] = ArrayBuffer[Int]()

  /** Feature id -> sum of the attribute values observed for the feature. */
  val attrFeaturesSum: ArrayBuffer[Float] = ArrayBuffer[Float]()

  /** Transition -> number of times it has been observed. */
  val transFeaturesFreq: mutable.Map[Transition, Int] = mutable.Map[Transition, Int]()

  /** All transitions observed so far, in the iteration order of [[transFeaturesFreq]]. */
  def transitions: Seq[Transition] = transFeaturesFreq.keys.toSeq

  /** Records one observation of attribute `attr` with `value` under label `label`.
    * The (attr, label) feature is registered on first sight.
    */
  private def addAttrFeature(label: Int, attr: Int, value: Float): Unit = {
    // The default is only evaluated for an absent key, before insertion,
    // so a new feature receives the next dense id.
    val featureId = attrFeatures2Id.getOrElseUpdate((attr, label), attrFeatures2Id.size)
    if (featureId >= attrFeatures.size) {
      // First sighting: grow the parallel buffers so they can be indexed by featureId.
      attrFeatures.append(AttrFeature(featureId, attr, label))
      attrFeaturesFreq.append(0)
      attrFeaturesSum.append(0f)
    }
    attrFeaturesFreq(featureId) += 1
    attrFeaturesSum(featureId) += value
  }

  /** Increments the observation count of the transition `fromId -> toId`. */
  private def addTransFeature(fromId: Int, toId: Int): Unit = {
    val transition = Transition(fromId, toId)
    transFeaturesFreq(transition) = transFeaturesFreq.getOrElse(transition, 0) + 1
  }

  /** Returns the id of `label`, assigning the next dense id if it is new. */
  private def getLabel(label: String): Int =
    label2Id.getOrElseUpdate(label, label2Id.size)

  /** Returns the id of attribute `attr`, registering it if it is new.
    * `isNumerical` is only recorded when the attribute is first registered.
    */
  private def getAttr(attr: String, isNumerical: Boolean): Int = {
    val attrId = attr2Id.getOrElseUpdate(attr, attr2Id.size)
    if (attrId >= attributes.size) {
      attributes.append(Attr(attrId, attr, isNumerical))
    }
    attrId
  }

  /** Encodes a single token and updates the dataset statistics.
    *
    * @param prevLabel   label of the preceding token; defaults to `startLabel`
    * @param label       label of the current token
    * @param binaryAttrs names of the binary attributes present on the token (encoded with value 1)
    * @param numAttrs    numerical attribute values; the i-th value is stored under attribute `"num" + i`
    * @return the id of `label`, and the token's (attribute id, value) pairs. The pairs are
    *         sorted by attribute id, with exact duplicates removed.
    */
  def getFeatures(
      prevLabel: String = startLabel,
      label: String,
      binaryAttrs: Seq[String],
      numAttrs: Seq[Float]): (Int, SparseArray) = {
    // Registration order (current label before previous label) determines id assignment.
    val labelId = getLabel(label)

    val binFeatures = binaryAttrs.map { attr =>
      val attrId = getAttr(attr, isNumerical = false)
      addAttrFeature(labelId, attrId, 1f)
      (attrId, 1f)
    }

    val numFeatures = numAttrs.zipWithIndex.map { case (value, idx) =>
      // NOTE(review): a binary attribute literally named "num<idx>" would share this id.
      val attrId = getAttr("num" + idx, isNumerical = true)
      addAttrFeature(labelId, attrId, value)
      (attrId, value)
    }

    val fromId = getLabel(prevLabel)
    addTransFeature(fromId, labelId)

    // NOTE(review): `distinct` collapses repeated binary attributes in the returned vector.
    // However, addAttrFeature above has already counted every occurrence in the
    // frequency/sum statistics.
    val features = (binFeatures ++ numFeatures)
      .sortBy(_._1)
      .distinct
      .toArray

    (labelId, new SparseArray(features))
  }

  /** Builds a [[DatasetMetadata]] snapshot of everything collected so far.
    *
    * The snapshot contains:
    *  - labels ordered by id;
    *  - fresh copies of the attributes and attribute features;
    *  - statistics for every attribute feature (in feature-id order), followed by
    *    statistics for every transition, in the same order as the returned transitions array.
    */
  def getMetadata: DatasetMetadata = {
    val labels = label2Id.toSeq.sortBy(_._2).map(_._1).toArray

    // Take the transitions once, so the transitions array and their statistics come from
    // the same traversal and line up index by index.
    val transitionsSnapshot = transitions
    val transitionsStat = transitionsSnapshot.map { transition =>
      val freq = transFeaturesFreq(transition)
      new AttrStat(freq, freq)
    }

    val attrsStat = attrFeaturesFreq
      .zip(attrFeaturesSum)
      .map { case (freq, sum) => new AttrStat(freq, sum) }

    new DatasetMetadata(
      labels,
      copy(attributes),
      copy(attrFeatures),
      transitionsSnapshot.toArray,
      (attrsStat ++ transitionsStat).toArray)
  }

  /** Returns a fresh array holding the elements of `source`. Later appends to `source`
    * do not affect the returned array.
    */
  private def copy[T: ClassTag](source: IndexedSeq[T]): Array[T] = source.toArray
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy