All downloads are free. Search and download functionality uses the official Maven repository.

edu.berkeley.nlp.math.CachingObjectiveDifferentiableFunction Maven / Gradle / Ivy

Go to download

The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).

The newest version!
package edu.berkeley.nlp.math;

import edu.berkeley.nlp.mapper.AsynchronousMapper;
import edu.berkeley.nlp.mapper.SimpleMapper;
import edu.berkeley.nlp.util.Pair;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

/**
 * User: aria42
 * Date: Mar 10, 2009
 */
public class CachingObjectiveDifferentiableFunction extends CachingDifferentiableFunction {

  private List> itemFns;
  private Regularizer regularizer;
  private Collection items;

  public CachingObjectiveDifferentiableFunction(Collection items,
                                                List> itemFns,
                                                Regularizer regularizer)
  {
    this.itemFns = itemFns;
    this.regularizer = regularizer;
    this.items = items;
  }

    public CachingObjectiveDifferentiableFunction(Collection items,
                                                  ObjectiveItemDifferentiableFunction itemFn,
                                                  Regularizer regularizer)
  {
    this(items, Collections.singletonList(itemFn),regularizer);
  }

  private class Mapper implements SimpleMapper {
    ObjectiveItemDifferentiableFunction itemFn;
    double objVal ;
    double[] localGrad ;
    Mapper(ObjectiveItemDifferentiableFunction itemFn) {
      this.itemFn = itemFn;
      this.objVal = 0.0;
      this.localGrad = new double[itemFn.dimension()];
    }
    public void map(I elem) {
      objVal += itemFn.update(elem,localGrad);  
    }
  }

  private List getMappers() {
    List mappers = new ArrayList();
    for (ObjectiveItemDifferentiableFunction itemFn : itemFns) {
      mappers.add(new Mapper(itemFn));
    }
    return mappers;
  }

  protected Pair calculate(double[] x) {
    for (ObjectiveItemDifferentiableFunction itemFn : itemFns) {
      itemFn.setWeights(x);
    }
    List mappers = getMappers();
    AsynchronousMapper.doMapping(items,mappers);
    double objVal = 0.0;
    double[] grad = new double[dimension()];
    for (Mapper mapper : mappers) {
      objVal += mapper.objVal;
      DoubleArrays.addInPlace(grad,mapper.localGrad);
    }
    if (regularizer != null) {
      objVal += regularizer.update(x,grad,1.0);
    }
    return Pair.newPair(objVal,grad);
  }

  public int dimension() {
    return itemFns.get(0).dimension();
  }
}