![JAR search and dependency download from the Maven repository](/logo.png)
com.johnsnowlabs.ml.ai.util.Generation.Search.BeamHypotheses.scala Maven / Gradle / Ivy
/*
* Copyright 2017 - 2023 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.ml.ai.util.Generation.Search
/** Bounded collection of scored hypotheses kept for one sentence during beam search.
  *
  * At most `numBeams` hypotheses are retained. Each one is stored as a tuple of
  * (length-normalised score, token ids, beam indices). The lowest retained score is tracked in
  * `worstScore` so that candidates which cannot improve the collection are rejected cheaply.
  *
  * @param lengthPenalty
  *   Exponent applied to the hypothesis length when normalising the sum of log probabilities
  * @param numBeams
  *   Maximum number of hypotheses to keep
  * @param earlyStopping
  *   When true, [[isDone]] reports completion as soon as `numBeams` hypotheses are stored
  * @param maxLength
  *   Maximum generation length. Not read by any method of this class; kept so that existing
  *   constructor calls remain valid
  */
class BeamHypotheses(
    var lengthPenalty: Double,
    var numBeams: Int,
    var earlyStopping: Boolean = false,
    var maxLength: Int) {

  // Retained hypotheses in insertion order: (score, hypothesis, beam indices).
  private var beams: Seq[(Double, Array[Int], Array[Int])] = Seq()

  // Lowest score among the retained hypotheses; starts at a large sentinel so that the first
  // insertion always lowers it.
  private var worstScore: Double = 1e9

  /** Number of hypotheses currently stored. */
  def length(): Int = {
    beams.length
  }

  /** Stored hypotheses as (score, hypothesis, beam indices) tuples, in insertion order. */
  def getBeams: Seq[(Double, Array[Int], Array[Int])] = {
    this.beams
  }

  /** Add a new hypothesis to the list.
    *
    * The hypothesis is scored as `sumLogProbs / hypotheses.length ^ lengthPenalty`. It is stored
    * only while there is still room, or when its score beats the current worst score; in the
    * latter case the lowest-scoring stored hypothesis is evicted.
    *
    * @param hypotheses
    *   Hypothesis
    * @param sumLogProbs
    *   Sum of Log Probabilities
    * @param beamIndices
    *   Beam Indices
    */
  def add(hypotheses: Array[Int], sumLogProbs: Double, beamIndices: Array[Int]): Unit = {
    val score = sumLogProbs / Math.pow(hypotheses.length, this.lengthPenalty)
    if (this.beams.length < this.numBeams || score > this.worstScore) {
      this.beams = this.beams :+ ((score, hypotheses, beamIndices))
      if (this.beams.length > this.numBeams) {
        // Sort once (stable, ascending by score): the head is the entry to evict and the second
        // element carries the new worst score.
        val ascending = this.beams.zipWithIndex.sortBy(_._1._1)
        val evictIndex = ascending.head._2
        // Drop exactly the evicted position so the survivors keep their insertion order.
        this.beams = this.beams.patch(evictIndex, Nil, 1)
        this.worstScore = ascending(1)._1._1
      } else {
        this.worstScore = Math.min(score, this.worstScore)
      }
    }
  }

  /** If there are enough hypotheses and none of the hypotheses being generated can become
    * better than the worst one stored, then we are done with this sentence.
    *
    * With `earlyStopping` enabled, having `numBeams` stored hypotheses is sufficient. Otherwise
    * the best running candidate, normalised by `currentLength`, is compared with the worst
    * stored score.
    *
    * @param bestSumLogProbs
    *   Best Sum of Log Probabilities
    * @param currentLength
    *   Current Length
    * @return
    *   Status of the sentence
    */
  def isDone(bestSumLogProbs: Double, currentLength: Int): Boolean = {
    if (this.beams.length < this.numBeams) {
      false
    } else if (this.earlyStopping) {
      true
    } else {
      // earlyStopping is a Boolean, so this is the only remaining case.
      val highestAttainableScore =
        bestSumLogProbs / Math.pow(currentLength, this.lengthPenalty)
      this.worstScore >= highestAttainableScore
    }
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy