
cc.mallet.grmm.test.TestPottsFactor Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
The newest version!
/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.test;
import cc.mallet.grmm.types.*;
import junit.framework.*;
/**
 * JUnit tests for {@link PottsTableFactor}: slicing the factor at a fixed
 * value of its continuous parameter variable, and evaluating
 * {@code sumGradLog} against small table factors.
 *
 * $Id: TestPottsFactor.java,v 1.1 2007/10/22 21:37:41 mccallum Exp $
 */
public class TestPottsFactor extends TestCase {

  /** Factor under test; rebuilt by {@link #setUp()} before every test. */
  private PottsTableFactor factor;

  /** Continuous variable handed to the factor as its parameter. */
  private Variable alpha;

  /** The two binary variables the factor is defined over. */
  private VarSet vars;

  /**
   * Creates a test case that runs the test method with the given name.
   *
   * @param name name of the test method to run
   */
  public TestPottsFactor (String name)
  {
    super (name);
  }

  /**
   * @return a TestSuite built from this class
   */
  public static TestSuite suite ()
  {
    return new TestSuite (TestPottsFactor.class);
  }

  /**
   * Builds the fixture: one continuous variable, two 2-outcome variables,
   * and a PottsTableFactor over the pair parameterized by the continuous one.
   */
  protected void setUp () throws Exception
  {
    alpha = new Variable (Variable.CONTINUOUS);

    Variable first = new Variable (2);
    Variable second = new Variable (2);
    vars = new HashVarSet (new Variable[] { first, second });

    factor = new PottsTableFactor (vars, alpha);
  }

  /** Returns an assignment that sets only {@code alpha}, to the given value. */
  private Assignment alphaAt (double value)
  {
    return new Assignment (alpha, value);
  }

  /** Returns a fresh copy of the four table values shared by the gradient tests. */
  private static double[] jointValues ()
  {
    return new double[] { 0.4, 0.1, 0.3, 0.2 };
  }

  /**
   * Slices the factor at alpha = 1.0 and checks that the result is a table
   * factor over {@code vars} that almost-equals {1, e^-1, e^-1, 1}.
   */
  public void testSlice ()
  {
    Factor sliced = factor.slice (alphaAt (1.0));

    assertTrue (sliced instanceof AbstractTableFactor);
    assertTrue (sliced.varSet ().equals (vars));

    double offDiagonal = Math.exp (-1);
    TableFactor expected =
        new TableFactor (vars, new double[] { 1.0, offDiagonal, offDiagonal, 1.0 });
    assertTrue (sliced.almostEquals (expected));
  }

  /**
   * Calls sumGradLog at alpha = 1.0 with a table factor over {@code vars}
   * and expects -0.4 (tolerance 1e-5).
   */
  public void testSumGradLog ()
  {
    Factor q = new TableFactor (vars, jointValues ());

    double gradient = factor.sumGradLog (q, alpha, alphaAt (1.0));

    assertEquals (-0.4, gradient, 1e-5);
  }

  /**
   * Same check as {@link #testSumGradLog()}, but the factor passed in is the
   * product of that table with a factor {0.7, 0.3} over a new 2-outcome
   * variable; the expected value is still -0.4 (tolerance 1e-5).
   */
  public void testSumGradLog2 ()
  {
    Factor overVars = new TableFactor (vars, jointValues ());
    Factor overExtra = new TableFactor (new Variable (2), new double[] { 0.7, 0.3 });
    Factor q = overVars.multiply (overExtra);

    double gradient = factor.sumGradLog (q, alpha, alphaAt (1.0));

    assertEquals (-0.4, gradient, 1e-5);
  }

  /**
   * Command-line entry point. With no arguments the full suite is run;
   * otherwise each argument is taken as the name of a test method to run.
   *
   * @param args names of individual test methods, or empty for all tests
   */
  public static void main (String[] args)
  {
    TestSuite selected;
    if (args.length == 0) {
      selected = suite ();
    } else {
      selected = new TestSuite ();
      for (String testName : args) {
        selected.addTest (new TestPottsFactor (testName));
      }
    }
    junit.textui.TestRunner.run (selected);
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy