All Downloads are FREE. Search and download functionalities are using the official Maven repository.

opennlp.tools.ml.maxent.quasinewton.QNModel Maven / Gradle / Ivy

There is a newer version: 2.5.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package opennlp.tools.ml.maxent.quasinewton;

import java.util.Arrays;

import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;

/**
 * Model whose type is set to {@code ModelType.MaxentQn} and whose outcome
 * distribution is obtained by summing per-predicate parameters into one score
 * per outcome and then normalizing those scores.
 */
public class QNModel extends AbstractModel {

  /**
   * Creates a model from already-estimated parameters. All three arguments are
   * handed to the superclass unchanged.
   *
   * @param params
   *          Parameters of the model, one {@link Context} per predicate.
   * @param predLabels
   *          Names of the predicates.
   * @param outcomeNames
   *          Names of the outcomes.
   */
  public QNModel(Context[] params, String[] predLabels, String[] outcomeNames) {
    super(params, predLabels, outcomeNames);
    this.modelType = ModelType.MaxentQn;
  }

  /**
   * @return The number of outcome names held by this model.
   */
  public int getNumOutcomes() {
    return this.outcomeNames.length;
  }

  /**
   * Looks up the index of a predicate.
   *
   * @param predicate
   *          Name of the predicate.
   * @return The index stored for the predicate, or {@code null} when the
   *         predicate map has no entry for it.
   */
  private Integer getPredIndex(String predicate) {
    return pmap.get(predicate);
  }

  /**
   * Evaluates a context, giving every predicate a weight of 1.0 and storing
   * the result in a newly allocated array.
   *
   * @param context
   *          The predicates which have been observed at the present
   *          decision point.
   * @return Normalized probabilities for the outcomes given the context.
   */
  public double[] eval(String[] context) {
    return eval(context, new double[evalParams.getNumOutcomes()]);
  }

  /**
   * Evaluates a context, giving every predicate a weight of 1.0 and storing
   * the result in the supplied array.
   *
   * @param context
   *          The predicates which have been observed at the present
   *          decision point.
   * @param probs
   *          Output buffer for the outcome probabilities. Its previous
   *          contents are discarded, so the same array may be reused across
   *          calls.
   * @return {@code probs}, holding the normalized probabilities.
   */
  public double[] eval(String[] context, double[] probs) {
    return eval(context, null, probs);
  }

  /**
   * Evaluates a weighted context, storing the result in a newly allocated
   * array.
   *
   * @param context
   *          The predicates which have been observed at the present
   *          decision point.
   * @param values
   *          Weights of the predicates, read at the same index as the
   *          predicate in {@code context}; {@code null} means a weight of 1.0
   *          for every predicate.
   * @return Normalized probabilities for the outcomes given the context.
   */
  public double[] eval(String[] context, float[] values) {
    return eval(context, values, new double[evalParams.getNumOutcomes()]);
  }

  /**
   * Model evaluation which should be used during inference.
   * @param context
   *          The predicates which have been observed at the present
   *          decision point. Predicates without an index in the predicate
   *          map are skipped.
   * @param values
   *          Weights of the predicates which have been observed at
   *          the present decision point, or {@code null} to weight every
   *          predicate with 1.0.
   * @param probs
   *          Output buffer for the outcome probabilities. It is reset to
   *          zero before use.
   * @return Normalized probabilities for the outcomes given the context.
   */
  private double[] eval(String[] context, float[] values, double[] probs) {
    // The scores below are accumulated with "+=". Without this reset a buffer
    // reused by the caller would leak the previous call's probabilities into
    // the new scores.
    Arrays.fill(probs, 0.0);

    Context[] params = evalParams.getParams();

    for (int ci = 0; ci < context.length; ci++) {
      Integer predIdx = getPredIndex(context[ci]);

      if (predIdx != null) {
        double predValue = values != null ? values[ci] : 1.0;

        // parameters[i] belongs to the outcome whose index is outcomes[i].
        double[] parameters = params[predIdx].getParameters();
        int[] outcomes = params[predIdx].getOutcomes();
        for (int i = 0; i < outcomes.length; i++) {
          int oi = outcomes[i];
          probs[oi] += predValue * parameters[i];
        }
      }
    }

    // Turn the accumulated scores into probabilities by subtracting the
    // log-sum-exp of all scores before exponentiating.
    double logSumExp = ArrayMath.logSumOfExps(probs);
    for (int oi = 0; oi < outcomeNames.length; oi++) {
      probs[oi] = Math.exp(probs[oi] - logSumExp);
    }
    return probs;
  }

  /**
   * Model evaluation which should be used during training to report model accuracy.
   * @param context
   *          Indices of the predicates which have been observed at the present
   *          decision point.
   * @param values
   *          Weights of the predicates which have been observed at
   *          the present decision point, or {@code null} to weight every
   *          predicate with 1.0.
   * @param probs
   *          Output buffer for the outcome probabilities. It is reset to
   *          zero before use.
   * @param nOutcomes
   *          Number of outcomes
   * @param nPredLabels
   *          Number of unique predicates
   * @param parameters
   *          Model parameters, read as a flat array in which the parameter
   *          for outcome {@code oi} and predicate {@code predIdx} is stored
   *          at {@code oi * nPredLabels + predIdx}.
   * @return Normalized probabilities for the outcomes given the context.
   * @deprecated The visibility of this method will be reduced in 1.8.1.
   */
  @Deprecated // visibility will be reduced in 1.8.1
  public static double[] eval(int[] context, float[] values, double[] probs,
      int nOutcomes, int nPredLabels, double[] parameters) {

    // Same reasoning as in the inference path: start from a clean buffer so
    // that reusing the array between calls cannot change the result.
    Arrays.fill(probs, 0.0);

    for (int i = 0; i < context.length; i++) {
      int predIdx = context[i];
      double predValue = values != null ? values[i] : 1.0;
      for (int oi = 0; oi < nOutcomes; oi++) {
        probs[oi] += predValue * parameters[oi * nPredLabels + predIdx];
      }
    }

    double logSumExp = ArrayMath.logSumOfExps(probs);

    for (int oi = 0; oi < nOutcomes; oi++) {
      probs[oi] = Math.exp(probs[oi] - logSumExp);
    }

    return probs;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy