com.github.steveash.kylm.main.CrossEntropy Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of kylm Show documentation
Show all versions of kylm Show documentation
KYLM language modelling for java (forked from official repo to make production ready)
/*
$Rev$
The Kyoto Language Modeling Toolkit.
Copyright (C) 2009 Kylm Development Team
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
package com.github.steveash.kylm.main;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import com.github.steveash.kylm.model.LanguageModel;
import com.github.steveash.kylm.model.ngram.reader.ArpaNgramReader;
import com.github.steveash.kylm.model.ngram.reader.SerializedNgramReader;
import com.github.steveash.kylm.reader.TextStreamSentenceReader;
import com.github.steveash.kylm.util.KylmConfigUtils;
import com.github.steveash.kylm.util.KylmMathUtils;
import com.github.steveash.kylm.util.KylmTextUtils;
import com.github.steveash.kylm.util.SymbolSet;
/**
 * A program to calculate cross entropy and perplexity of one or more models
 * over a test set read from a file (first positional argument) or, when no
 * file is given, from standard input.
 * @author neubig
 */
public class CrossEntropy {

    /**
     * Formats an entropy value together with its optional breakdown.
     * <p>
     * The result always starts with {@code all}. When {@code simp} differs from
     * {@code all}, a parenthesised breakdown follows: {@code (s=<simp>}, then
     * {@code ,c=<cls>} if {@code cls} is non-zero, then {@code ,u[<unkSym>]=<unk>}
     * if {@code unk} is non-zero (the bracketed symbol is omitted when
     * {@code unkSym} is {@code null}), and finally the closing parenthesis.
     *
     * @param all    the total entropy
     * @param simp   the simple entropy component
     * @param cls    the class entropy component
     * @param unk    the unknown-word entropy component
     * @param unkSym the symbol to show for the unknown component, or {@code null}
     * @return the formatted string
     */
    public static String makeEnt(float all, float simp, float cls, float unk, String unkSym) {
        StringBuilder sb = new StringBuilder();
        sb.append(all);
        if (simp != all) {
            sb.append("(s=").append(simp);
            if (cls != 0) {
                sb.append(",c=").append(cls);
            }
            if (unk != 0) {
                sb.append(",u");
                if (unkSym != null) {
                    sb.append('[').append(unkSym).append(']');
                }
                sb.append('=').append(unk);
            }
            sb.append(')');
        }
        return sb.toString();
    }

    /**
     * Loads the models named by the {@code -arpa} and {@code -bin} options,
     * scores every sentence of the test set with each of them, and prints the
     * per-model entropy, perplexity and model report.
     *
     * @param args command-line arguments; after option parsing, the first
     *             remaining argument (if any) is the test file to read
     * @throws Exception if a model or the test set cannot be read
     */
    public static void main(String[] args) throws Exception {
        final String br = System.getProperty("line.separator");
        KylmConfigUtils config = new KylmConfigUtils(
                "CrossEntropy" + br +
                "A program to find the cross-entropy of one or more language models over a test set" + br +
                "Example: java -cp kylm.jar com.github.steveash.kylm.main.CrossEntropy -arpa model1.arpa:model2.arpa test.txt");

        // Input format options
        config.addEntry("arpa", KylmConfigUtils.STRING_ARRAY_TYPE, null, false, "models in arpa format (model1.arpa:model2.arpa)");
        config.addEntry("bin", KylmConfigUtils.STRING_ARRAY_TYPE, null, false, "models in binary format (model3.bin:model4.bin)");

        // Debugging options
        config.addEntry("debug", KylmConfigUtils.INT_TYPE, 0, false, "the level of debugging information to print");

        // parse the arguments; what is returned is used below as the positional arguments
        args = config.parseArguments(args);
        int debug = config.getInt("debug");

        // the models to evaluate (typed so that get() yields a LanguageModel)
        List<LanguageModel> models = new ArrayList<>();

        // load the arpa files
        String[] arpaFiles = config.getStringArray("arpa");
        if (arpaFiles != null) {
            ArpaNgramReader anr = new ArpaNgramReader();
            for (String arpa : arpaFiles) {
                LanguageModel next = anr.read(arpa);
                // fall back to the file name when the model carries no name
                if (next.getName() == null) {
                    next.setName(arpa);
                }
                models.add(next);
            }
        }

        // load the binary files
        String[] binFiles = config.getStringArray("bin");
        if (binFiles != null) {
            SerializedNgramReader snr = new SerializedNgramReader();
            for (String bin : binFiles) {
                LanguageModel next = snr.read(bin);
                if (next.getName() == null) {
                    next.setName(bin);
                }
                models.add(next);
            }
        }

        // check to make sure at least one language model has been loaded
        if (models.isEmpty()) {
            System.err.println("At least one language model must be specified." + br);
            config.exitOnUsage(1);
        }

        final int modelCount = models.size();

        // running totals over the whole test set, one slot per model
        float[] words = new float[modelCount], simples = new float[modelCount],
                unknowns = new float[modelCount], classes = new float[modelCount];
        // totals for the current sentence, one slot per model
        float[] wordSents = new float[modelCount], simpleSents = new float[modelCount],
                unkSents = new float[modelCount], classSents = new float[modelCount];
        // per-position values for the current sentence, one row per model
        float[][] wordEnts = new float[modelCount][], simpleEnts = new float[modelCount][],
                unkEnts = new float[modelCount][], classEnts = new float[modelCount][];
        String[][] unkSyms = new String[modelCount][];
        int wordCount = 0, sentenceCount = 0;

        // get the input stream to load the input
        InputStream is = (args.length == 0 ? System.in : new FileInputStream(args[0]));
        try {
            TextStreamSentenceReader tssl = new TextStreamSentenceReader(is);
            for (String[] sent : tssl) {
                wordCount += sent.length;
                sentenceCount++;

                // calculate the entropies of this sentence under every model
                for (int i = 0; i < modelCount; i++) {
                    LanguageModel mod = models.get(i);
                    wordEnts[i] = mod.getWordEntropies(sent);
                    words[i] += (wordSents[i] = KylmMathUtils.sum(wordEnts[i]));
                    simpleEnts[i] = mod.getSimpleEntropies();
                    simples[i] += (simpleSents[i] = KylmMathUtils.sum(simpleEnts[i]));
                    classEnts[i] = mod.getClassEntropies();
                    classes[i] += (classSents[i] = KylmMathUtils.sum(classEnts[i]));
                    unkEnts[i] = mod.getUnknownEntropies();
                    unknowns[i] += (unkSents[i] = KylmMathUtils.sum(unkEnts[i]));
                    unkSyms[i] = new String[unkEnts[i].length];
                    SymbolSet vocab = mod.getVocab();
                    for (int j = 0; j < unkEnts[i].length; j++) {
                        // The debug output below treats positions >= sent.length as the
                        // terminal symbol, so the entropy arrays may be longer than the
                        // sentence; only look up positions that have a real word.
                        if (unkEnts[i][j] != 0 && j < sent.length) {
                            unkSyms[i][j] = vocab.getSymbol(mod.findUnknownId(sent[j]));
                        }
                    }
                }

                if (debug > 0) {
                    System.out.println(KylmTextUtils.join(" ", sent));
                    for (int i = 0; i < modelCount; i++) {
                        System.out.println(models.get(i).getName() + ": " +
                                makeEnt(wordSents[i], simpleSents[i], classSents[i], unkSents[i], null));
                    }
                    if (debug > 1) {
                        // one line per position, listing every model's value for it
                        for (int j = 0; j < wordEnts[0].length; j++) {
                            System.out.print(" " + (j < sent.length ? sent[j] : models.get(0)
                                    .getTerminalSymbol()) + "\tent: ");
                            for (int i = 0; i < modelCount; i++) {
                                if (i != 0) {
                                    System.out.print(", ");
                                }
                                System.out.print(makeEnt(wordEnts[i][j], simpleEnts[i][j], classEnts[i][j], unkEnts[i][j], unkSyms[i][j]));
                            }
                            System.out.println();
                        }
                    }
                    System.out.println();
                }
            }
        } finally {
            // close the test file, but leave standard input open
            if (is != System.in) {
                is.close();
            }
        }

        // change from log10: dividing by -(wordCount * log10(2)) turns each summed
        // value into a per-word figure in base 2 (presumably positive bits, given
        // that the sign is flipped -- confirm against LanguageModel)
        final float log2 = (float) Math.log10(2);
        final float normalizer = wordCount * log2 * -1;
        for (int i = 0; i < modelCount; i++) {
            LanguageModel mod = models.get(i);
            System.out.println("Found entropy over " + wordCount + " words, " + sentenceCount + " sentences");
            words[i] /= normalizer;
            simples[i] /= normalizer;
            unknowns[i] /= normalizer;
            classes[i] /= normalizer;
            System.out.print(mod.getName() + ": entropy=" +
                    makeEnt(words[i], simples[i], classes[i], unknowns[i], null));
            System.out.println(", perplexity=" + Math.pow(2, words[i]));
            System.out.println(mod.printReport());
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy