All downloads are free. Search and download functionality uses the official Maven repository.

com.github.steveash.jg2p.train.EncoderEval Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2014 Steve Ash
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.steveash.jg2p.train;

import com.google.common.base.Joiner;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multiset;
import com.google.common.primitives.Ints;

import com.github.steveash.jg2p.PhoneticEncoder;
import com.github.steveash.jg2p.align.InputRecord;
import com.github.steveash.jg2p.util.ListEditDistance;
import com.github.steveash.jg2p.util.Percent;

import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.Random;

import static com.google.common.collect.Iterables.limit;
import static com.google.common.collect.Multisets.copyHighestCountFirst;

/**
 * @author Steve Ash
 */
@Deprecated // use BulkEval instead
public class EncoderEval {

  private static final Logger log = LoggerFactory.getLogger(EncoderEval.class);
  private static final Joiner spaceJoin = Joiner.on(' ');
  private static final Joiner pipeJoin = Joiner.on('|');

  private static final int EXAMPLE_COUNT = 100;
  private static final int MAX_EXAMPLE_TO_PRINT = 15;

  public enum PrintOpts {ALL, SIMPLE}

  private final PhoneticEncoder encoder;
  private final boolean collectExamples;
  private long totalPhones;
  private long totalRightPhones;
  private long totalWords;
  private long totalRightWords;
  private long noCodes;
  private final Multiset phoneEditHisto = HashMultiset.create();
  private final Multiset rightAnswerInTop = HashMultiset.create();
  private final ListMultimap>>
      examples =
      ArrayListMultimap.create();
  private final Random rand = new Random(0xFEEDFEED);

  public EncoderEval(PhoneticEncoder encoder) {
    this(encoder, false);
  }

  public EncoderEval(PhoneticEncoder encoder, boolean collectExamples) {
    this.encoder = encoder;
    this.collectExamples = collectExamples;
  }

  public void mergeFrom(EncoderEval that) {
    this.totalPhones += that.totalPhones;
    this.totalRightPhones += that.totalRightPhones;
    this.totalWords += that.totalWords;
    this.totalRightWords += that.totalRightWords;
    this.noCodes += that.noCodes;
    this.phoneEditHisto.addAll(that.phoneEditHisto);
    this.rightAnswerInTop.addAll(that.rightAnswerInTop);
    this.examples.putAll(that.examples);
  }

  public void evalAndPrint(List inputs, PrintOpts opts) {
    doWork(inputs, true, opts);
  }

  public void evalNoPrint(List inputs) {
    doWork(inputs, false, null);
  }

  private void doWork(List inputs, boolean shouldPrint, PrintOpts opts) {
    totalPhones = 0;
    totalRightPhones = 0;
    totalWords = 0;
    totalRightWords = 0;
    noCodes = 0;
    rightAnswerInTop.clear();
    examples.clear();
    phoneEditHisto.clear();

    for (InputRecord input : inputs) {

      List encodings = encoder.encode(input.xWord);
      if (encodings.isEmpty()) {
        noCodes += 1;
        continue;
      }

      totalWords += 1;
      PhoneticEncoder.Encoding encoding = encodings.get(0);

      List expected = input.yWord.getValue();
      int phonesDiff = ListEditDistance.editDistance(encoding.getPhones(), expected, 100);
      totalPhones += expected.size();
      totalRightPhones += (expected.size() - phonesDiff);
      phoneEditHisto.add(phonesDiff);
      if (phonesDiff == 0) {
        totalRightWords += 1;
        rightAnswerInTop.add(0);
      }
      if (phonesDiff > 0) {
        // find out if the right encoding is in the top-k results
        for (int i = 1; i < encodings.size(); i++) {
          PhoneticEncoder.Encoding attempt = encodings.get(i);
          if (attempt.getPhones().equals(input.yWord.getValue())) {
            rightAnswerInTop.add(i);
            break;
          }
        }
      }
      if (collectExamples && phonesDiff > 0) {
        Pair> example = Pair.of(input, encodings);
        List>> examples = this.examples.get(phonesDiff);
        if (examples.size() < EXAMPLE_COUNT) {
          examples.add(example);
        } else {
          int victim = rand.nextInt(Ints.saturatedCast(totalWords));
          if (victim < EXAMPLE_COUNT) {
            examples.set(victim, example);
          }
        }
      }

      if (shouldPrint) {
        if (totalWords % 500 == 0 && opts != PrintOpts.SIMPLE) {
          log.info("Processed " + totalWords + " ...");
          if (totalWords % 10_000 == 0) {
            printStats(opts);
          }
        }
      }
    }
    if (shouldPrint) {
      printStats(opts);
    }
  }

  public void printStats(PrintOpts opts) {
    if (opts != PrintOpts.SIMPLE) {
      if (collectExamples) {
        printExamples();
      }
      log.info("Phone edit distance histo: ");
      int total = 0;
      for (Multiset.Entry entry : phoneEditHisto.entrySet()) {
        total += entry.getCount();
        log.info("  " + entry.getElement() + " = " + entry.getCount() + " - " + Percent.print(total, totalWords));
      }
      log.info("No phones words that were skipped " + noCodes);
      log.info("Answer found in top-k answer?");
      total = 0;
      for (Multiset.Entry entry : copyHighestCountFirst(rightAnswerInTop).entrySet()) {
        total += entry.getCount();
        log.info(
            "  In top " + entry.getElement() + " - " + entry.getCount() + " - " + Percent.print(total, totalWords));
      }
    }

    log.info("Total words " + totalWords + ", total right " + totalRightWords + " - " + Percent
        .print(totalRightWords, totalWords));
    log.info("Total phones " + totalPhones + ", total right " + totalRightPhones + " - " +
             Percent.print(totalRightPhones, totalPhones));

  }

  private void printExamples() {
    for (Integer phoneEdit : examples.keySet()) {
      log.info(" ---- Examples with edit distance " + phoneEdit + " ----");
      Iterable>> toPrint =
          limit(examples.get(phoneEdit), MAX_EXAMPLE_TO_PRINT);

      for (Pair> example : toPrint) {
        String got = "";
        if (example.getRight().size() > 0) {
          got = example.getRight().get(0).toString();
        }
        log.info(" Got: " + got + " expected: " + example.getLeft().getRight().getAsSpaceString());
      }
    }
  }

  public long getTotalPhones() {
    return totalPhones;
  }

  public long getTotalRightPhones() {
    return totalRightPhones;
  }

  public long getTotalWords() {
    return totalWords;
  }

  public long getTotalRightWords() {
    return totalRightWords;
  }

  public long getNoCodes() {
    return noCodes;
  }

  public ListMultimap>> getExamples() {
    return examples;
  }

  public Multiset getPhoneEditHisto() {
    return phoneEditHisto;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy