// com.github.steveash.jg2p.train.EncoderEval Maven / Gradle / Ivy
// The newest version!
/*
* Copyright 2014 Steve Ash
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.steveash.jg2p.train;
import com.google.common.base.Joiner;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multiset;
import com.google.common.primitives.Ints;
import com.github.steveash.jg2p.PhoneticEncoder;
import com.github.steveash.jg2p.align.InputRecord;
import com.github.steveash.jg2p.util.ListEditDistance;
import com.github.steveash.jg2p.util.Percent;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.Random;
import static com.google.common.collect.Iterables.limit;
import static com.google.common.collect.Multisets.copyHighestCountFirst;
/**
* @author Steve Ash
*/
@Deprecated // use BulkEval instead
public class EncoderEval {
private static final Logger log = LoggerFactory.getLogger(EncoderEval.class);
private static final Joiner spaceJoin = Joiner.on(' ');
private static final Joiner pipeJoin = Joiner.on('|');
private static final int EXAMPLE_COUNT = 100;
private static final int MAX_EXAMPLE_TO_PRINT = 15;
public enum PrintOpts {ALL, SIMPLE}
private final PhoneticEncoder encoder;
private final boolean collectExamples;
private long totalPhones;
private long totalRightPhones;
private long totalWords;
private long totalRightWords;
private long noCodes;
private final Multiset phoneEditHisto = HashMultiset.create();
private final Multiset rightAnswerInTop = HashMultiset.create();
private final ListMultimap>>
examples =
ArrayListMultimap.create();
private final Random rand = new Random(0xFEEDFEED);
public EncoderEval(PhoneticEncoder encoder) {
this(encoder, false);
}
public EncoderEval(PhoneticEncoder encoder, boolean collectExamples) {
this.encoder = encoder;
this.collectExamples = collectExamples;
}
public void mergeFrom(EncoderEval that) {
this.totalPhones += that.totalPhones;
this.totalRightPhones += that.totalRightPhones;
this.totalWords += that.totalWords;
this.totalRightWords += that.totalRightWords;
this.noCodes += that.noCodes;
this.phoneEditHisto.addAll(that.phoneEditHisto);
this.rightAnswerInTop.addAll(that.rightAnswerInTop);
this.examples.putAll(that.examples);
}
public void evalAndPrint(List inputs, PrintOpts opts) {
doWork(inputs, true, opts);
}
public void evalNoPrint(List inputs) {
doWork(inputs, false, null);
}
private void doWork(List inputs, boolean shouldPrint, PrintOpts opts) {
totalPhones = 0;
totalRightPhones = 0;
totalWords = 0;
totalRightWords = 0;
noCodes = 0;
rightAnswerInTop.clear();
examples.clear();
phoneEditHisto.clear();
for (InputRecord input : inputs) {
List encodings = encoder.encode(input.xWord);
if (encodings.isEmpty()) {
noCodes += 1;
continue;
}
totalWords += 1;
PhoneticEncoder.Encoding encoding = encodings.get(0);
List expected = input.yWord.getValue();
int phonesDiff = ListEditDistance.editDistance(encoding.getPhones(), expected, 100);
totalPhones += expected.size();
totalRightPhones += (expected.size() - phonesDiff);
phoneEditHisto.add(phonesDiff);
if (phonesDiff == 0) {
totalRightWords += 1;
rightAnswerInTop.add(0);
}
if (phonesDiff > 0) {
// find out if the right encoding is in the top-k results
for (int i = 1; i < encodings.size(); i++) {
PhoneticEncoder.Encoding attempt = encodings.get(i);
if (attempt.getPhones().equals(input.yWord.getValue())) {
rightAnswerInTop.add(i);
break;
}
}
}
if (collectExamples && phonesDiff > 0) {
Pair> example = Pair.of(input, encodings);
List>> examples = this.examples.get(phonesDiff);
if (examples.size() < EXAMPLE_COUNT) {
examples.add(example);
} else {
int victim = rand.nextInt(Ints.saturatedCast(totalWords));
if (victim < EXAMPLE_COUNT) {
examples.set(victim, example);
}
}
}
if (shouldPrint) {
if (totalWords % 500 == 0 && opts != PrintOpts.SIMPLE) {
log.info("Processed " + totalWords + " ...");
if (totalWords % 10_000 == 0) {
printStats(opts);
}
}
}
}
if (shouldPrint) {
printStats(opts);
}
}
public void printStats(PrintOpts opts) {
if (opts != PrintOpts.SIMPLE) {
if (collectExamples) {
printExamples();
}
log.info("Phone edit distance histo: ");
int total = 0;
for (Multiset.Entry entry : phoneEditHisto.entrySet()) {
total += entry.getCount();
log.info(" " + entry.getElement() + " = " + entry.getCount() + " - " + Percent.print(total, totalWords));
}
log.info("No phones words that were skipped " + noCodes);
log.info("Answer found in top-k answer?");
total = 0;
for (Multiset.Entry entry : copyHighestCountFirst(rightAnswerInTop).entrySet()) {
total += entry.getCount();
log.info(
" In top " + entry.getElement() + " - " + entry.getCount() + " - " + Percent.print(total, totalWords));
}
}
log.info("Total words " + totalWords + ", total right " + totalRightWords + " - " + Percent
.print(totalRightWords, totalWords));
log.info("Total phones " + totalPhones + ", total right " + totalRightPhones + " - " +
Percent.print(totalRightPhones, totalPhones));
}
private void printExamples() {
for (Integer phoneEdit : examples.keySet()) {
log.info(" ---- Examples with edit distance " + phoneEdit + " ----");
Iterable>> toPrint =
limit(examples.get(phoneEdit), MAX_EXAMPLE_TO_PRINT);
for (Pair> example : toPrint) {
String got = "";
if (example.getRight().size() > 0) {
got = example.getRight().get(0).toString();
}
log.info(" Got: " + got + " expected: " + example.getLeft().getRight().getAsSpaceString());
}
}
}
public long getTotalPhones() {
return totalPhones;
}
public long getTotalRightPhones() {
return totalRightPhones;
}
public long getTotalWords() {
return totalWords;
}
public long getTotalRightWords() {
return totalRightWords;
}
public long getNoCodes() {
return noCodes;
}
public ListMultimap>> getExamples() {
return examples;
}
public Multiset getPhoneEditHisto() {
return phoneEditHisto;
}
}