/*
 * Copyright (C) 2010 by
 *
 * Cam-Tu Nguyen
 * [email protected] or [email protected]
 *
 * Xuan-Hieu Phan
 * [email protected]
 *
 * College of Technology, Vietnamese University, Hanoi
 * Graduate School of Information Sciences, Tohoku University
 *
 * JVnTextPro-v.2.0 is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published
 * by the Free Software Foundation; either version 2 of the License,
 * or (at your option) any later version.
 *
 * JVnTextPro-v.2.0 is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with JVnTextPro-v.2.0; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
 */
package jmaxent;
import java.io.*;
import org.riso.numerical.*;
/**
 * Trains a maximum entropy (log-linear) classification model by maximizing the
 * Gaussian-prior-penalized log-likelihood of the training data with the L-BFGS
 * quasi-Newton optimizer.
 */
public class Train {
    // the model object
    /** The maximum entropy model being trained. */
    public Model model = null;

    /** The number of class labels. */
    public int numLabels = 0;

    /** The number of features. */
    public int numFeatures = 0;

    /** The feature weights (lambda) of the model. */
    double[] lambda = null;

    /** A copy of the best weights seen so far (used when saving the best model). */
    double[] tempLambda = null;

    // for L-BFGS
    /** The gradient of the penalized log-likelihood. */
    double[] gradLogLi = null;

    /** The diagonal array handed to L-BFGS (not user-supplied, since diagco is false). */
    double[] diag = null;

    /** Per-label accumulator of lambda[i] * f_i(x, label) for the current observation. */
    double[] temp = null;

    /** The L-BFGS working storage. */
    double[] ws = null;

    /** The L-BFGS verbosity settings. */
    int[] iprint = null;

    /** The L-BFGS status flag (0 once convergence is reached). */
    int[] iflag = null;
    /**
     * Instantiates a new trainer. The {@link #model} field must be set before training.
     */
    public Train() {
        // do nothing
    }
    /**
     * Initializes the trainer: reads the numbers of labels and features from the model
     * and allocates the weight, gradient, and L-BFGS working arrays.
     */
    public void init() {
        numLabels = model.data.numLabels();
        numFeatures = model.feaGen.numFeatures();

        if (numLabels <= 0 || numFeatures <= 0) {
            System.out.println("Invalid number of labels or features");
            return;
        }

        lambda = model.lambda;
        tempLambda = new double[numFeatures];
        gradLogLi = new double[numFeatures];
        diag = new double[numFeatures];
        temp = new double[numLabels];
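        // Workspace sized per the classic L-BFGS requirement of n*(2m+1) + 2m doubles
        // (n = number of features, m = number of corrections kept by the optimizer).
        // Note: ws is not actually passed to the LBFGS call below; the riso Java port
        // appears to manage its own workspace internally.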
        int wsSize = numFeatures * (2 * model.option.mForHessian + 1) +
                2 * model.option.mForHessian;
        ws = new double[wsSize];
        iprint = new int[2];
        iflag = new int[1];
    }
    /**
     * Computes the Euclidean (L2) norm of a vector.
     *
     * @param vect the vector
     * @return the L2 norm of vect
     */
    public static double norm(double[] vect) {
        double res = 0.0;
        for (int i = 0; i < vect.length; i++) {
            res += vect[i] * vect[i];
        }
        return Math.sqrt(res);
    }
    /**
     * Trains the model: repeatedly evaluates the penalized log-likelihood and its
     * gradient and updates the feature weights with L-BFGS until convergence or the
     * maximum number of iterations is reached.
     *
     * @param fout the writer for the training log (used when logging is enabled)
     */
    public void doTrain(PrintWriter fout) {
        long start_train, end_train, elapsed_train;
        long start_iter, end_iter, elapsed_iter;

        // initialization
        init();

        double f = 0.0;
        //double old_f;
        double xtol = 1.0e-16;
        int numIter = 0;

        // for L-BFGS
        iprint[0] = model.option.debugLevel - 2;
        iprint[1] = model.option.debugLevel - 1;
        iflag[0] = 0;

        // counter
        int i;

        // get initial values for lambda
        for (i = 0; i < numFeatures; i++) {
            lambda[i] = model.option.initLambdaVal;
        }

        System.out.println("Start to train ...");
        if (model.option.isLogging) {
            model.option.writeOptions(fout);
            fout.println("Start to train ...");
        }

        // starting time of the training process
        start_train = System.currentTimeMillis();

        double maxAccuracy = 0.0;
        int maxAccuracyIter = -1;

        // the training loop
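        // Each pass computes the (negated) log-likelihood and gradient, hands them to
        // L-BFGS for one update of lambda, and repeats until iflag[0] drops to 0
        // (converged) or model.option.numIterations is reached.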
        do {
            // starting time of iteration
            start_iter = System.currentTimeMillis();

            // call this to compute two things:
            // 1. log-likelihood value
            // 2. the gradient vector of log-likelihood function
            f = computeLogLiGradient(lambda, gradLogLi, numIter + 1, fout);

            // negate f and its gradient because L-BFGS minimizes the objective function
            // while we would like to maximize it
            f *= -1;
            for (i = 0; i < numFeatures; i++) {
                gradLogLi[i] *= -1;
            }

            // calling L-BFGS
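            // Arguments follow the standard Nocedal L-BFGS interface: number of variables,
            // number of corrections (m), the variable vector (updated in place), the current
            // objective value and gradient, diagco=false (no user-supplied Hessian diagonal),
            // diag, print controls, the convergence tolerance, machine precision (xtol),
            // and the status flag iflag.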
            try {
                new LBFGS().lbfgs(numFeatures, model.option.mForHessian, lambda, f, gradLogLi,
                        false, diag, iprint, model.option.epsForConvergence, xtol, iflag);
            } catch (LBFGS.ExceptionWithIflag e) {
                System.out.println("L-BFGS failed!");
                if (model.option.isLogging) {
                    fout.println("L-BFGS failed!");
                }
                break;
            }

            numIter++;

            // get the end time of the current iteration
            end_iter = System.currentTimeMillis();
            elapsed_iter = end_iter - start_iter;
            System.out.println("\tIteration elapsed: " +
                    Double.toString((double)elapsed_iter / 1000) + " seconds");
            if (model.option.isLogging) {
                fout.println("\tIteration elapsed: " +
                        Double.toString((double)elapsed_iter / 1000) + " seconds");
            }

            // evaluate during training
            if (model.option.evaluateDuringTraining) {
                // inference on testing data
                model.doInference(model.data.tstData);

                // evaluation
                double accuracy = model.evaluation.evaluate(fout);
                if (accuracy > maxAccuracy) {
                    maxAccuracy = accuracy;
                    maxAccuracyIter = numIter;

                    // save the best model towards testing evaluation
                    if (model.option.saveBestModel) {
                        for (i = 0; i < numFeatures; i++) {
                            tempLambda[i] = lambda[i];
                        }
                    }
                }

                System.out.println("\tCurrent max accuracy: " +
                        Double.toString(maxAccuracy) + " (at iteration " +
                        Integer.toString(maxAccuracyIter) + ")");
                if (model.option.isLogging) {
                    fout.println("\tCurrent max accuracy: " +
                            Double.toString(maxAccuracy) + " (at iteration " +
                            Integer.toString(maxAccuracyIter) + ")");
                }

                // get the end time of the current iteration
                end_iter = System.currentTimeMillis();
                elapsed_iter = end_iter - start_iter;
                System.out.println("\tIteration elapsed (including testing & evaluation): " +
                        Double.toString((double)elapsed_iter / 1000) + " seconds");
                if (model.option.isLogging) {
                    fout.println("\tIteration elapsed (including testing & evaluation): " +
                            Double.toString((double)elapsed_iter / 1000) + " seconds");
                    fout.flush();
                }
            }
        } while (iflag[0] != 0 && numIter < model.option.numIterations);
        // get the end time of the training process
        end_train = System.currentTimeMillis();
        elapsed_train = end_train - start_train;
        System.out.println("\tThe training process elapsed: " +
                Double.toString((double)elapsed_train / 1000) + " seconds");
        if (model.option.isLogging) {
            fout.println("\tThe training process elapsed: " +
                    Double.toString((double)elapsed_train / 1000) + " seconds");
        }

        if (model.option.evaluateDuringTraining && model.option.saveBestModel) {
            for (i = 0; i < numFeatures; i++) {
                lambda[i] = tempLambda[i];
            }
        }
    }
    /**
     * Computes the Gaussian-prior-penalized log-likelihood of the training data and
     * fills in its gradient with respect to the feature weights.
     *
     * @param lambda the current feature weights
     * @param gradLogLi the array that receives the gradient of the log-likelihood
     * @param numIter the current iteration number (for reporting only)
     * @param fout the writer for the training log (used when logging is enabled)
     * @return the penalized log-likelihood value
     */
    public double computeLogLiGradient(double[] lambda, double[] gradLogLi,
            int numIter, PrintWriter fout) {
        double logLi = 0.0;
        int ii, i; //, j, k;
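        // Gaussian prior (L2 regularization): each weight contributes
        // -lambda[i]^2 / (2 * sigma^2) to the log-likelihood and
        // -lambda[i] / sigma^2 to its gradient.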
        for (i = 0; i < numFeatures; i++) {
            gradLogLi[i] = -1 * lambda[i] / model.option.sigmaSquare;
            logLi -= (lambda[i] * lambda[i]) / (2 * model.option.sigmaSquare);
        }

        // go through all training data examples/observations
        for (ii = 0; ii < model.data.trnData.size(); ii++) {
            Observation obsr = (Observation)model.data.trnData.get(ii);

            for (i = 0; i < numLabels; i++) {
                temp[i] = 0.0;
            }

            // log-likelihood value of the current data observation
            double obsrLogLi = 0.0;

            // start to scan all features at the current obsr
            model.feaGen.startScanFeatures(obsr);
            while (model.feaGen.hasNextFeature()) {
                Feature f = model.feaGen.nextFeature();
                if (f.label == obsr.humanLabel) {
                    gradLogLi[f.idx] += f.val;
                    obsrLogLi += lambda[f.idx] * f.val;
                }
                temp[f.label] += lambda[f.idx] * f.val;
            }
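            // Partition function Z(x) = sum over labels y of exp(sum_i lambda[i] * f_i(x, y)),
            // so the model probability is p(y|x) = exp(temp[y]) / Z(x).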
            double Zx = 0.0;
            for (i = 0; i < numLabels; i++) {
                Zx += Math.exp(temp[i]);
            }
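            // Second scan: subtract the model expectation of each feature,
            // E_model[f] = f.val * p(f.label | x), from the gradient.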
            model.feaGen.scanReset();
            while (model.feaGen.hasNextFeature()) {
                Feature f = model.feaGen.nextFeature();
                gradLogLi[f.idx] -= f.val * Math.exp(temp[f.label]) / Zx;
            }

            obsrLogLi -= Math.log(Zx);
            logLi += obsrLogLi;
        } // end of the main loop

        System.out.println();
        System.out.println("Iteration: " + Integer.toString(numIter));
        System.out.println("\tLog-likelihood = " + Double.toString(logLi));
        double gradLogLiNorm = Train.norm(gradLogLi);
        System.out.println("\tNorm (log-likelihood gradient) = " + Double.toString(gradLogLiNorm));
        double lambdaNorm = Train.norm(lambda);
        System.out.println("\tNorm (lambda) = " + Double.toString(lambdaNorm));

        if (model.option.isLogging) {
            fout.println();
            fout.println("Iteration: " + Integer.toString(numIter));
            fout.println("\tLog-likelihood = " + Double.toString(logLi));
            fout.println("\tNorm (log-likelihood gradient) = " + Double.toString(gradLogLiNorm));
            fout.println("\tNorm (lambda) = " + Double.toString(lambdaNorm));
        }

        return logLi;
    }
} // end of class Train
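
/*
 * Minimal usage sketch (an assumption, not part of this file): the surrounding jmaxent
 * Model class is expected to wire up data, feaGen, option, and lambda before training,
 * along the lines of
 *
 *     Train trainer = new Train();
 *     trainer.model = model;          // a fully initialized jmaxent.Model
 *     PrintWriter log = new PrintWriter(new FileWriter("training.log"), true);
 *     trainer.doTrain(log);           // runs L-BFGS until convergence
 *     log.close();
 *
 * Names other than Train.model and Train.doTrain(PrintWriter), including the log
 * file name, are illustrative only.
 */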