All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.joshua.metrics.MinimumChangeBLEU Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.joshua.metrics;

import java.util.HashMap;
import java.util.Map;

import org.apache.joshua.util.Algorithms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * BLEU variant that discourages candidates which barely differ from a
 * designated "source" sentence.
 *
 * <p>One of the references (index {@code sourceReferenceIndex}) is treated as
 * the source of the paraphrasing run. That reference is skipped when the
 * clipped n-gram counts and the reference lengths are collected, and it is
 * used instead to compute a word-level edit distance to the candidate. The
 * final score is the smoothed BLEU value multiplied by a penalty that is
 * below 1 whenever (edit distance / candidate length) is below
 * {@code thresholdWER}.
 *
 * <p>Sufficient statistics layout ({@code suffStatsCount = 2 * maxGram + 3}):
 * <ul>
 *   <li>{@code [2*(n-1)]}, {@code [2*(n-1)+1]}: the two precision statistics
 *       for n-gram order {@code n}, read by {@link #score(int[])} as correct
 *       count and total count;</li>
 *   <li>{@code [suffStatsCount-3]}: candidate length in words;</li>
 *   <li>{@code [suffStatsCount-2]}: effective reference length;</li>
 *   <li>{@code [suffStatsCount-1]}: edit distance between candidate and
 *       source reference.</li>
 * </ul>
 */
public class MinimumChangeBLEU extends BLEU {

  private static final Logger LOG = LoggerFactory.getLogger(MinimumChangeBLEU.class);

  // we assume that the source for the paraphrasing run is
  // part of the set of references
  private final int sourceReferenceIndex;

  // WER value at or above which no penalty is applied
  private final double thresholdWER;


  /**
   * Creates the metric with the parent's default settings, source reference
   * index 0 and a WER threshold of 0.3.
   */
  public MinimumChangeBLEU() {
    super();
    this.sourceReferenceIndex = 0;
    this.thresholdWER = 0.3;
    initialize();
  }


  /**
   * Creates the metric from an option array.
   *
   * @param options passed unchanged to the parent constructor;
   *        {@code options[2]} is parsed as the source reference index and
   *        {@code options[3]} as the WER threshold
   */
  public MinimumChangeBLEU(String[] options) {
    super(options);
    this.sourceReferenceIndex = Integer.parseInt(options[2]);
    this.thresholdWER = Double.parseDouble(options[3]);
    initialize();
  }


  /**
   * Sets the metric name, optimization direction and sufficient statistics
   * count, then (re)builds the weights and the reference n-gram counts.
   */
  protected void initialize() {
    metricName = "MC_BLEU";
    toBeMinimized = false;
    // adding 1 to the sufficient stats for regular BLEU
    suffStatsCount = 2 * getMaxGramLength() + 3;

    set_weightsArray();
    set_maxNgramCounts();
  }


  /**
   * Builds, for every sentence, the per-n-gram maximum count over all
   * references except the source reference, and caches the word count of
   * each of those references in {@code refWordCount}.
   */
  protected void set_maxNgramCounts() {
    @SuppressWarnings("unchecked")
    HashMap<String, Integer>[] temp_HMA = new HashMap[numSentences];
    maxNgramCounts = temp_HMA;

    for (int i = 0; i < numSentences; ++i) {
      // update counts as necessary from the reference translations
      for (int r = 0; r < refsPerSen; ++r) {
        // skip source reference
        if (r == this.sourceReferenceIndex) {
          continue;
        }

        HashMap<String, Integer> nextNgramCounts = getNgramCountsAll(refSentences[i][r]);
        if (maxNgramCounts[i] == null) {
          // first non-source reference: its counts become the initial maxima
          maxNgramCounts[i] = nextNgramCounts;
          continue;
        }

        for (Map.Entry<String, Integer> entry : nextNgramCounts.entrySet()) {
          String gram = entry.getKey();
          int nextCount = entry.getValue();

          Integer oldCount = maxNgramCounts[i].get(gram);
          // keep the larger count; unseen n-grams are simply added
          if (oldCount == null || nextCount > oldCount) {
            maxNgramCounts[i].put(gram, nextCount);
          }
        }
      } // for (r)
    } // for (i)

    // for efficiency, calculate the reference lengths, which will be used
    // in effLength...
    // (the slot of the source reference is left at its default of 0)
    refWordCount = new int[numSentences][refsPerSen];
    for (int i = 0; i < numSentences; ++i) {
      for (int r = 0; r < refsPerSen; ++r) {
        if (r == this.sourceReferenceIndex) {
          continue;
        }
        refWordCount[i][r] = wordCount(refSentences[i][r]);
      }
    }
  }


  /**
   * Computes the sufficient statistics of one candidate.
   *
   * @param cand_str candidate translation, whitespace-separated tokens; a
   *        trailing {@code "_OOV"} marker on a token is removed
   * @param i index of the sentence the candidate belongs to
   * @return array of length {@code suffStatsCount}; the precision slots are
   *         filled by {@code set_prec_suffStats}, the last three slots hold
   *         candidate length, effective reference length and the edit
   *         distance to the source reference
   */
  public int[] suffStats(String cand_str, int i) {
    int[] stats = new int[suffStatsCount];

    String[] candidate_words;
    if (!cand_str.equals("")) {
      candidate_words = cand_str.split("\\s+");
    } else {
      candidate_words = new String[0];
    }

    // dropping "_OOV" marker
    for (int j = 0; j < candidate_words.length; j++) {
      if (candidate_words[j].endsWith("_OOV")) {
        candidate_words[j] = candidate_words[j].substring(0, candidate_words[j].length() - 4);
      }
    }

    set_prec_suffStats(stats, candidate_words, i);
    String[] source_words = refSentences[i][sourceReferenceIndex].split("\\s+");
    stats[suffStatsCount - 1] = Algorithms.levenshtein(candidate_words, source_words);
    stats[suffStatsCount - 2] = effLength(candidate_words.length, i);
    stats[suffStatsCount - 3] = candidate_words.length;

    return stats;
  }


  /**
   * Returns the effective reference length for a candidate, ignoring the
   * source reference.
   *
   * <p>With {@code CLOSEST}, the length of the reference whose length is
   * nearest to {@code candLength} is returned (the shorter one on ties);
   * with {@code SHORTEST}, the minimum reference length. If there is no
   * reference other than the source, the sentinel {@code Integer.MIN_VALUE}
   * (CLOSEST) or {@code Integer.MAX_VALUE} (SHORTEST) is returned.
   *
   * @param candLength candidate length in words
   * @param i sentence index
   * @return the effective reference length, or {@code candLength} for any
   *         other effective length method
   */
  public int effLength(int candLength, int i) {
    if (getEffLengthMethod() == EffectiveLengthMethod.CLOSEST) {
      int closestRefLength = Integer.MIN_VALUE;
      // Start from the largest possible difference. Deriving it from
      // Math.abs(candLength - Integer.MIN_VALUE) overflows and, for an empty
      // candidate, produces a negative value that no reference can beat.
      int minDiff = Integer.MAX_VALUE;

      for (int r = 0; r < refsPerSen; ++r) {
        if (r == this.sourceReferenceIndex) {
          continue;
        }
        int nextRefLength = refWordCount[i][r];
        int nextDiff = Math.abs(candLength - nextRefLength);

        if (nextDiff < minDiff) {
          closestRefLength = nextRefLength;
          minDiff = nextDiff;
        } else if (nextDiff == minDiff && nextRefLength < closestRefLength) {
          // equally close: prefer the shorter reference
          closestRefLength = nextRefLength;
        }
      }
      return closestRefLength;
    } else if (getEffLengthMethod() == EffectiveLengthMethod.SHORTEST) {
      int shortestRefLength = Integer.MAX_VALUE;

      for (int r = 0; r < refsPerSen; ++r) {
        if (r == this.sourceReferenceIndex) {
          continue;
        }

        int nextRefLength = refWordCount[i][r];
        if (nextRefLength < shortestRefLength) {
          shortestRefLength = nextRefLength;
        }
      }
      return shortestRefLength;
    }

    return candLength; // should never get here anyway
  }


  /**
   * Scores a sufficient statistics array: smoothed geometric n-gram
   * precision, times the brevity penalty, times the WER penalty.
   *
   * @param stats statistics in the layout described on the class
   * @return the MC_BLEU score
   * @throws RuntimeException if {@code stats.length != suffStatsCount}
   */
  public double score(int[] stats) {
    if (stats.length != suffStatsCount) {
      String msg = "Mismatch between stats.length and " + "suffStatsCount (" + stats.length
          + " vs. " + suffStatsCount + ") in MinimumChangeBLEU.score(int[])";
      LOG.error(msg);
      throw new RuntimeException(msg);
    }

    double accuracy = 0.0;
    double smooth_addition = 1.0; // following bleu-1.04.pl
    double c_len = stats[suffStatsCount - 3];
    double r_len = stats[suffStatsCount - 2];

    double wer_penalty = werPenalty(stats);

    double correctGramCount, totalGramCount;

    for (int n = 1; n <= getMaxGramLength(); ++n) {
      correctGramCount = stats[2 * (n - 1)];
      totalGramCount = stats[2 * (n - 1) + 1];

      double prec_n;
      if (totalGramCount > 0) {
        prec_n = correctGramCount / totalGramCount;
      } else {
        prec_n = 1; // following bleu-1.04.pl ???????
      }

      if (prec_n == 0) {
        // halve the smoothing mass each time an order has no matches
        smooth_addition *= 0.5;
        prec_n = smooth_addition / (c_len - n + 1);
        // isn't c_len-n+1 just totalGramCount ???????
      }
      accuracy += weights[n] * Math.log(prec_n);
    }
    double brevity_penalty = 1.0;
    if (c_len < r_len) {
      brevity_penalty = Math.exp(1 - (r_len / c_len));
    }

    return wer_penalty * brevity_penalty * Math.exp(accuracy);
  }


  /**
   * Prints the WER penalty and the MC_BLEU score for the given statistics
   * to standard output.
   *
   * @param stats statistics in the layout described on the class
   * @param oneLiner not used by this implementation
   */
  public void printDetailedScore_fromStats(int[] stats, boolean oneLiner) {
    // Same floating-point computation as score(); integer division here would
    // truncate the WER and throw on a zero-length candidate.
    double wer_penalty = werPenalty(stats);

    System.out.println("WER_penalty = " + wer_penalty);
    System.out.println("MC_BLEU= " + score(stats));
  }


  /**
   * Computes the minimum-change penalty: 1.0 when the word error rate
   * (edit distance / candidate length) reaches {@code thresholdWER},
   * otherwise the rate divided by the threshold.
   */
  private double werPenalty(int[] stats) {
    double c_len = stats[suffStatsCount - 3];
    double wer = stats[suffStatsCount - 1] / c_len;
    return (wer >= thresholdWER) ? 1.0 : (wer / thresholdWER);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy