All Downloads are FREE. Search and download functionalities are using the official Maven repository.

opennlp.tools.ml.ArrayMath Maven / Gradle / Ivy

There is a newer version: 2024.9.17689.20240905T073330Z-240800
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package opennlp.tools.ml;

import java.util.List;

import opennlp.tools.ml.model.Context;

/**
 * Utility class for simple vector arithmetic.
 */
public class ArrayMath {

  public static double innerProduct(double[] vecA, double[] vecB) {
    if (vecA == null || vecB == null || vecA.length != vecB.length)
      return Double.NaN;

    double product = 0.0;
    for (int i = 0; i < vecA.length; i++) {
      product += vecA[i] * vecB[i];
    }
    return product;
  }

  /**
   * L1-norm
   */
  public static double l1norm(double[] v) {
    double norm = 0;
    for (int i = 0; i < v.length; i++)
      norm += StrictMath.abs(v[i]);
    return norm;
  }

  /**
   * L2-norm
   */
  public static double l2norm(double[] v) {
    return StrictMath.sqrt(innerProduct(v, v));
  }

  /**
   * Inverse L2-norm
   */
  public static double invL2norm(double[] v) {
    return 1 / l2norm(v);
  }

  /**
   * Computes \log(\sum_{i=1}^n e^{x_i}) using a maximum-element trick
   * to avoid arithmetic overflow.
   *
   * @param x input vector
   * @return log-sum of exponentials of vector elements
   */
  public static double logSumOfExps(double[] x) {
    double max = max(x);
    double sum = 0.0;
    for (int i = 0; i < x.length; i++) {
      if (x[i] != Double.NEGATIVE_INFINITY)
        sum += StrictMath.exp(x[i] - max);
    }
    return max + StrictMath.log(sum);
  }

  public static double max(double[] x) {
    int maxIdx = argmax(x);
    return x[maxIdx];
  }

  /**
   * Find index of maximum element in the vector x
   * @param x input vector
   * @return index of the maximum element. Index of the first
   *     maximum element is returned if multiple maximums are found.
   */
  public static int argmax(double[] x) {
    if (x == null || x.length == 0) {
      throw new IllegalArgumentException("Vector x is null or empty");
    }

    int maxIdx = 0;
    for (int i = 1; i < x.length; i++) {
      if (x[maxIdx] < x[i])
        maxIdx = i;
    }
    return maxIdx;
  }

  public static void sumFeatures(Context[] context, float[] values, double[] prior) {
    for (int ci = 0; ci < context.length; ci++) {
      if (context[ci] != null) {
        Context predParams = context[ci];
        int[] activeOutcomes = predParams.getOutcomes();
        double[] activeParameters = predParams.getParameters();
        double value = 1;
        if (values != null) {
          value = values[ci];
        }
        for (int ai = 0; ai < activeOutcomes.length; ai++) {
          int oid = activeOutcomes[ai];
          prior[oid] += activeParameters[ai] * value;
        }
      }
    }
  }

  // === Not really related to math ===
  /**
   * Convert a list of Double objects into an array of primitive doubles
   */
  public static double[] toDoubleArray(List list) {
    double[] arr = new double[list.size()];
    for (int i = 0; i < arr.length; i++) {
      arr[i] = list.get(i);
    }
    return arr;
  }

  /**
   *  Convert a list of Integer objects into an array of primitive integers
   */
  public static int[] toIntArray(List list) {
    int[] arr = new int[list.size()];
    for (int i = 0; i < arr.length; i++) {
      arr[i] = list.get(i);
    }
    return arr;
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy