All downloads are free. Search and download functionality uses the official Maven repository.

com.github.steveash.jg2p.eval.BulkEval.groovy Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2015 Steve Ash
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.steveash.jg2p.eval

import com.github.steveash.jg2p.DuplicateStrippingEncoder
import com.github.steveash.jg2p.Encoder
import com.github.steveash.jg2p.EncodingResult
import com.github.steveash.jg2p.align.InputRecord
import com.github.steveash.jg2p.util.Percent
import com.google.common.base.Preconditions
import com.google.common.collect.Lists
import com.google.common.util.concurrent.RateLimiter
import groovyx.gpars.GParsPool
import org.slf4j.Logger
import org.slf4j.LoggerFactory

/**
 * Evaluates an encoder against a bulk set of inputs, optionally grouping the inputs first so
 * that a word with multiple acceptable pronunciations is scored as a single test case.
 * <p>
 * Evaluation runs inside a GPars pool and accumulates everything into one shared
 * {@link EvalStats} instance.
 * NOTE(review): the stats object is mutated from the pool's worker threads; this assumes
 * EvalStats is safe for concurrent use -- confirm.
 *
 * @author Steve Ash
 */
class BulkEval {

  private static final Logger log = LoggerFactory.getLogger(BulkEval.class)

  // how many of the top-ranked encodings are inspected when updating the IR statistics
  private final int considerTopK = 4
  private final Encoder encoder
  // 0.5 permits per second: throttles the progress log line to at most one every two seconds
  private final RateLimiter limiter = RateLimiter.create(1.0 / 2.0)

  /**
   * @param encoder the encoder under evaluation; it is passed through
   *                {@code DuplicateStrippingEncoder.decorateIfNotAlready} before use
   */
  public BulkEval(Encoder encoder) {
    this.encoder = DuplicateStrippingEncoder.decorateIfNotAlready(encoder)
  }

  /**
   * Groups the raw records by their input word and then evaluates the resulting groups.
   *
   * @param ungrouped records exposing {@code left} (the input word) and {@code right}
   *                  (an acceptable answer); presumably InputRecord instances -- confirm
   * @return the statistics accumulated over all groups
   */
  public EvalStats groupAndEval(Iterable ungrouped) {
    // group the inputs into Candidate instances that have all acceptable answers
    // (declared as List: groupInputRecords only promises a List, not an ArrayList)
    List inputGroups = groupInputRecords(ungrouped)

    return eval(inputGroups)
  }

  /**
   * Collapses records that share the same {@code left} value into one
   * {@link InputRecordGroup} holding the set of all their {@code right} values.
   *
   * @param ungrouped records exposing {@code left} and {@code right} properties
   * @return one group per distinct {@code left} value; empty when the input is empty
   */
  public static List groupInputRecords(Iterable ungrouped) {
    def groupedInput = ungrouped.groupBy { it.left }
    def inputGroups = Lists.newArrayListWithCapacity(groupedInput.size())
    // the map key is the shared left value of every record in that group
    groupedInput.each { word, records ->
      inputGroups << new InputRecordGroup(word, records.collect { it.right }.toSet())
    }
    inputGroups
  }

  /**
   * Encodes the test word of every group in parallel and records how the results compare
   * with the group's acceptable answers.
   *
   * @param groups the {@link InputRecordGroup}s to evaluate
   * @return a fresh {@link EvalStats} populated from all groups
   * @throws NullPointerException if the encoder returns {@code null} for a word
   */
  public EvalStats eval(Collection groups) {
    def totalGroups = groups.size()
    log.info("Evaluating {} input word groups", totalGroups)
    def stats = new EvalStats()
    GParsPool.withPool {
      // everyParallel is used only to visit each group; the closure always answers true
      groups.everyParallel { InputRecordGroup group ->
        stats.wordOptionsHisto.add(group.acceptableYWords.size())
        def results = encoder.encode(group.getTestWord())
        Preconditions.checkNotNull(results, "should always get results")
        stats.resultsSizeHisto.add(results.size())
        int currentTotal
        if (results.isEmpty()) {
          // nothing was produced: report a miss with no top result and rank -1
          currentTotal = stats.onNewResult(group, null, -1)
        } else {
          int rank = updateTopK(results, group, stats)
          currentTotal = stats.onNewResult(group, results[0], rank)
          updateIrs(results, group, stats)
        }

        // progress is only logged when the rate limiter hands out a permit
        if (limiter.tryAcquire()) {
          log.info("Completed {} of {}  {}", currentTotal, totalGroups, Percent.print(currentTotal, totalGroups))
        }
        return true
      }
    }
    log.info("Finished evaluating all {}", totalGroups)
    return stats
  }

  /**
   * Updates the IR stats for every cutoff from top-1 up to top-{@code considerTopK}.
   * Each cutoff feeds the {@code IR_ALL_TOPn} entry, and additionally the
   * {@code IR_MULTI_TOPn} entry when the group has at least two acceptable answers.
   *
   * @throws IllegalStateException if more matches are counted within a cutoff than are possible
   */
  private void updateIrs(List encodings, InputRecordGroup group, EvalStats stats) {
    def good = group.acceptableYWords.size()
    def ranks = calcRanks(encodings, group)
    for (int i = 0; i < this.considerTopK; i++) {
      int topN = i + 1
      int goodInResults = countRanks(ranks, Revelant.Good, 0, i)
      int badInResults = countRanks(ranks, Revelant.Bad, 0, i)
      // Missing slots are neither good nor bad, so they do not count as returned results
      int totalResults = goodInResults + badInResults
      // cannot find more matches than there are acceptable answers or slots in the cutoff
      int possible = Math.min(good, topN)

      if (goodInResults > possible) {
        throw new IllegalStateException("Got $possible possible but counted $goodInResults good from $ranks " +
                                        "based on good words ${group.acceptableYWords} from encodings $encodings")
      }

      stats.irConfigSetup.get("IR_ALL_TOP" + topN).onNewQuery(goodInResults, totalResults, good, possible)
      if (good >= 2) {
        stats.irConfigSetup.get("IR_MULTI_TOP" + topN).onNewQuery(goodInResults, totalResults, good, possible)
      }
    }
  }

  /**
   * Counts the entries of {@code ranks} equal to {@code statusToCount} between the two
   * inclusive indexes; indexes past the end of the array are ignored.
   */
  private int countRanks(Revelant[] ranks, Revelant statusToCount, int fromIncl, int toIncl) {
    int count = 0
    for (int i = fromIncl; i <= toIncl && i < ranks.length; i++) {
      if (ranks[i] == statusToCount) {
        count += 1
      }
    }
    return count
  }

  /**
   * Classifies each of the first {@code considerTopK} result slots: Good when the encoding's
   * phones match the group, Bad when they do not, Missing when there is no encoding at that
   * position.
   *
   * @return an array of exactly {@code considerTopK} entries, none of them null
   */
  private Revelant[] calcRanks(List encodings, InputRecordGroup group) {
    Revelant[] ranks = new Revelant[this.considerTopK]
    for (int i = 0; i < this.considerTopK; i++) {
      if (i >= encodings.size()) {
        ranks[i] = Revelant.Missing
        continue
      }
      ranks[i] = group.isMatching(encodings.get(i).phones) ? Revelant.Good : Revelant.Bad
    }
    return ranks
  }

  /**
   * Finds the first result whose phones match the group and bumps the matching counter:
   * {@code RIGHT_TOP_nn}, where nn is the zero-based index of the match, or
   * {@code RIGHT_TOP_NONE} when no result matches.
   *
   * @return the zero-based index of the first matching result, or -1 if there is none
   */
  private int updateTopK(List results, InputRecordGroup group, EvalStats stats) {
    for (int i = 0; i < results.size(); i++) {
      if (group.isMatching(results[i].phones)) {
        stats.counters.add(String.format("RIGHT_TOP_%02d", i))
        return i
      }
    }
    stats.counters.add("RIGHT_TOP_NONE")
    return -1
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy