All downloads are FREE. Search and download functionality uses the official Maven repository.

cc.mallet.grmm.learning.PiecewiseACRFTrainer Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.learning;


import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.Variable;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.SparseVector;
import cc.mallet.util.MalletLogger;
import cc.mallet.grmm.util.CachingOptimizable;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.BitSet;
import java.util.Iterator;
import java.util.logging.Logger;

/**
 * Created: Mar 15, 2005
 *
 * @author  "+assn.getObject (var)
          +"  ("+assn.get (var)+")");
      }
    }


    /**
     * Checks whether a weight is a finite, well-defined number, logging a
     * warning when it is not.
     *
     * @param w     the weight to check
     * @param cnum  clique number; used only in the warning message
     * @param j     assignment index; used only in the warning message
     * @return {@code false} if {@code w} is infinite or NaN, {@code true} otherwise
     */
    private boolean weightValid (double w, int cnum, int j)
    {
      if (Double.isInfinite (w)) {
        // Separator added: the old message ran the clique number into "assignment".
        logger.warning ("Weight is infinite for clique "+cnum+", assignment "+j);
        return false;
      }
      if (Double.isNaN (w)) {
        logger.warning ("Weight is NaN for clique "+cnum+", assignment "+j);
        return false;
      }
      return true;
    }


    /**
     * Number of calls made to {@link #computeValueAndGradient(int)}; it is
     * incremented once per call and read by getCachedGradient.
     * NOTE(review): nothing in the visible code resets this counter
     * (resetValueGradient does not touch it) — confirm where it is cleared.
     */
    int numInBatch = 0;

    /**
     * Processes one training instance. It counts the instance toward the
     * current batch, collects the instance's constraints, and then returns
     * the instance's value plus a 1/N share of the prior, where N is the
     * size of {@code trainData}.
     *
     * @param instance which training instance to process (presumably an
     *                 index into {@code trainData} — confirm against callers)
     * @return {@code computeValueForInstance (instance)} plus
     *         {@code computePrior () / trainData.size ()}
     */
    public double computeValueAndGradient (int instance)
    {
      numInBatch += 1;
      collectConstraintsForInstance (trainData, instance);
      // Calls stay in the original order: constraints, instance value, prior.
      final double instanceValue = computeValueForInstance (instance);
      final double priorShare = computePrior () / trainData.size ();
      return instanceValue + priorShare;
    }

    /**
     * Reports the size of the training set.
     *
     * @return the number of instances in {@code trainData}
     */
    public int getNumInstances ()
    {
      final int count = trainData.size ();
      return count;
    }


    /**
     * Delegates to computeValueGradient. It passes {@code grad} through,
     * together with the fraction of the training set counted in the current
     * batch ({@code numInBatch / trainData.size ()}, in floating point).
     *
     * @param grad array handed to computeValueGradient (presumably filled
     *             with the gradient — confirm against that method)
     */
    public void getCachedGradient (double[] grad)
    {
      double batchFraction = numInBatch;
      batchFraction /= trainData.size ();
      computeValueGradient (grad, batchFraction);
    }


    /**
     * Resets per-batch state by calling resetExpectations () and then
     * resetConstraints (), in that order (presumably clearing accumulated
     * expectations and constraints — confirm against those methods).
     * NOTE(review): numInBatch is not reset here; confirm that is intended.
     */
    public void resetValueGradient ()
    {
      resetExpectations ();
      resetConstraints ();
    }



  } // OptimizableACRF

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy