All downloads are free. Search and download functionality uses the official Maven repository.

cc.mallet.grmm.test.TestPottsFactor Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.grmm.test;

import cc.mallet.grmm.types.*;
import junit.framework.*;

/**
 * Unit tests for {@link PottsTableFactor}.
 * <p>
 * The fixture is a Potts factor over two binary variables, parameterized by the
 * continuous variable {@code alpha}. As {@link #testSlice()} asserts, slicing at
 * {@code alpha = 1} yields weight 1 at table indices 0 and 3 and weight
 * {@code exp(-1)} at indices 1 and 2. The gradient tests use a marginal whose
 * indices 1 and 2 hold 0.1 + 0.3 = 0.4 of the mass, and expect -0.4.
 *
 * $Id: TestPottsFactor.java,v 1.1 2007/10/22 21:37:41 mccallum Exp $
 */
public class TestPottsFactor extends TestCase {

  /** Absolute tolerance used when comparing gradient values. */
  private static final double TOLERANCE = 1e-5;

  /**
   * Gradient that both sumGradLog tests expect at alpha = 1 for the marginal
   * built by {@link #pairMarginal()}; numerically this is the negated mass
   * at indices 1 and 2, -(0.1 + 0.3).
   */
  private static final double EXPECTED_GRAD = -0.4;

  /** Factor under test; rebuilt by {@link #setUp()} before every test. */
  private PottsTableFactor factor;

  /** Continuous parameter variable of {@link #factor}. */
  private Variable alpha;

  /** The two binary variables that {@link #factor} is defined over. */
  private VarSet vars;


  /**
   * Creates a test case that runs the test method named {@code name}.
   *
   * @param name name of the test method to run
   */
  public TestPottsFactor (String name)
  {
    super (name);
  }

  /**
   * @return a TestSuite containing every public test method of this class
   */
  public static TestSuite suite ()
  {
    return new TestSuite (TestPottsFactor.class);
  }

  /**
   * Builds a fresh fixture: a continuous {@code alpha}, two binary variables,
   * and a {@link PottsTableFactor} over them parameterized by {@code alpha}.
   */
  protected void setUp () throws Exception
  {
    alpha = new Variable (Variable.CONTINUOUS);
    Variable v1 = new Variable (2);
    Variable v2 = new Variable (2);
    vars = new HashVarSet (new Variable[] { v1,v2 });
    factor = new PottsTableFactor (vars, alpha);
  }

  /**
   * Slicing at alpha = 1 must produce a table factor over the same variables
   * with weights {1, e^-1, e^-1, 1}.
   */
  public void testSlice ()
  {
    Assignment assn = new Assignment (alpha, 1.0);
    Factor sliced = factor.slice (assn);

    // String concatenation is null-safe, so a null slice still yields a failure, not an NPE.
    assertTrue ("slice should return an AbstractTableFactor, got: " + sliced,
                sliced instanceof AbstractTableFactor);

    VarSet slicedVars = sliced.varSet ();
    assertTrue ("sliced factor should be over " + vars + " but is over " + slicedVars,
                slicedVars.equals (vars));

    TableFactor expected = new TableFactor (vars, new double[] { 1.0, Math.exp(-1), Math.exp(-1), 1.0 });
    assertTrue ("expected " + expected + " but slice produced " + sliced,
                sliced.almostEquals (expected));
  }

  /** Gradient of the expected log-factor w.r.t. alpha, with q over exactly {@link #vars}. */
  public void testSumGradLog ()
  {
    double grad = gradAtAlphaOne (pairMarginal ());
    assertEquals ("sumGradLog over vars at alpha = 1", EXPECTED_GRAD, grad, TOLERANCE);
  }

  /**
   * Same gradient, but q also ranges over an extra variable not in the factor.
   * The extra variable's weights sum to 1, and the test expects the gradient
   * to be unchanged.
   */
  public void testSumGradLog2 ()
  {
    Variable extra = new Variable (2);
    Factor q2 = new TableFactor (extra, new double[] { 0.7, 0.3 });
    Factor q = pairMarginal ().multiply (q2);

    double grad = gradAtAlphaOne (q);
    assertEquals ("sumGradLog with an extra variable at alpha = 1", EXPECTED_GRAD, grad, TOLERANCE);
  }

  /**
   * Builds a new table factor over {@link #vars} with weights
   * {0.4, 0.1, 0.3, 0.2}. A fresh array is allocated on every call so that
   * tests never share mutable state.
   */
  private Factor pairMarginal ()
  {
    return new TableFactor (vars, new double[] { 0.4, 0.1, 0.3, 0.2 });
  }

  /**
   * Evaluates {@code factor.sumGradLog} for {@code q} with respect to
   * {@link #alpha}, at the assignment alpha = 1.
   *
   * @param q factor passed as the first argument of sumGradLog
   * @return the value returned by sumGradLog
   */
  private double gradAtAlphaOne (Factor q)
  {
    Assignment alphaAssn = new Assignment (alpha, 1.0);
    return factor.sumGradLog (q, alpha, alphaAssn);
  }

  /**
   * Runs the named test methods, or the whole suite when no names are given.
   *
   * @param args optional test method names
   */
  public static void main (String[] args)
  {
    TestSuite theSuite;
    if (args.length > 0) {
      theSuite = new TestSuite ();
      for (int i = 0; i < args.length; i++) {
        theSuite.addTest (new TestPottsFactor (args[i]));
      }
    } else {
      theSuite = suite ();
    }

    junit.textui.TestRunner.run (theSuite);
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy