weka.classifiers.bayes.blr.LaplacePriorImpl Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-stable Show documentation
Show all versions of weka-stable Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This is the stable version. Apart from bugfixes, this version
does not receive any other updates.
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* LaplacePriorImpl.java
* Copyright (C) 2008 Illinois Institute of Technology
*
*/
package weka.classifiers.bayes.blr;
import weka.classifiers.bayes.BayesianLogisticRegression;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
/**
 * Implementation of the Laplace prior update function, based on the modified
 * CLG algorithm (CLG-Lasso) with a trust-region update, as used by
 * {@link BayesianLogisticRegression}.
 *
 * @author Navendu Garg([email protected])
 * @version $Revision: 1.2 $
 */
public class LaplacePriorImpl
  extends Prior {

  /** for serialization. */
  private static final long serialVersionUID = 2353576123257012607L;

  /** Training instances passed to the most recent {@link #update} call. */
  Instances m_Instances;

  /** Current value of the coefficient being updated. */
  double Beta;

  /** Prior hyperparameter of the coefficient being updated. */
  double Hyperparameter;

  /** Step computed by the most recent {@link #update} call. */
  double DeltaUpdate;

  /** Per-instance values r_i supplied by the caller, indexed like m_Instances. */
  double[] R;

  /** Width value supplied by the caller (presumably the trust-region width). */
  double Delta;

  /**
   * Computes the CLG-Lasso step for coefficient {@code j} under a Laplace prior.
   * <p>
   * When the coefficient is currently zero, the positive direction is tried
   * first; if it does not yield a positive step, the negative direction is tried;
   * if that does not yield a negative step either, the step is 0. Otherwise the
   * step is computed in the direction of the coefficient's current sign and, if
   * it would carry the coefficient across zero, truncated so that the
   * coefficient lands exactly on zero.
   *
   * @param j index of the attribute/coefficient to update
   * @param instances the training instances
   * @param beta current value of coefficient {@code j}
   * @param hyperparameter prior hyperparameter of coefficient {@code j}
   * @param r per-instance values r_i, one per instance in {@code instances}
   * @param deltaV width value passed through to the step computation
   * @return the change to apply to coefficient {@code j}
   */
  public double update(int j, Instances instances, double beta,
    double hyperparameter, double[] r, double deltaV) {
    DeltaUpdate = 0.0;
    m_Instances = instances;
    Beta = beta;
    Hyperparameter = hyperparameter;
    R = r;
    Delta = deltaV;

    if (Beta == 0) {
      DeltaUpdate = laplaceUpdate(j, 1.0);
      if (DeltaUpdate <= 0.0) { // positive direction failed.
        DeltaUpdate = laplaceUpdate(j, -1.0);
        if (DeltaUpdate >= 0.0) { // negative direction failed as well.
          DeltaUpdate = 0;
        }
      }
    } else {
      double sign = Math.signum(Beta);
      DeltaUpdate = laplaceUpdate(j, sign);
      // The step crosses zero exactly when the updated coefficient has the
      // opposite sign of the current one; in that case stop at zero. (Comparing
      // against the current sign, not against 0, keeps this correct for Beta < 0.)
      if (sign * (Beta + DeltaUpdate) < 0) {
        DeltaUpdate = -Beta;
      }
    }
    return DeltaUpdate;
  }

  /**
   * Computes a single CLG-Lasso step for coefficient {@code j} in the given
   * direction, following A. Genkin, D. D. Lewis and D. Madigan, "Large-scale
   * Bayesian logistic regression for text categorization", DIMACS technical
   * report, 2004 (http://www.stat.rutgers.edu/~madigan/PAPERS/shortFat-v3a.pdf).
   * <p>
   * The returned step is
   * <pre>
   *   (sum_i x_ij * y_i / (1 + exp(r_i)) - sqrt(2 / Hyperparameter) * sign)
   *   / (sum_i x_ij^2 * F(r_i, Delta * x_ij))
   * </pre>
   * where x_ij is the value of attribute {@code j} in instance i, y_i is
   * {@code BayesianLogisticRegression.classSgn} of the instance's class value,
   * F is {@code BayesianLogisticRegression.bigF}, and only instances with a
   * non-zero x_ij contribute. Reads the state stored by {@link #update}.
   *
   * @param j index of the attribute/coefficient to update
   * @param sign direction of the step (+1.0 or -1.0)
   * @return the step, or 0.0 when the denominator is zero
   */
  public double laplaceUpdate(int j, double sign) {
    double numerator = 0.0;
    double denominator = 0.0;
    int numInstances = m_Instances.numInstances();

    for (int i = 0; i < numInstances; i++) {
      Instance instance = m_Instances.instance(i);
      double xij = instance.value(j);
      // Instances with a zero value contribute nothing to either sum.
      if (xij != 0) {
        numerator += (xij * BayesianLogisticRegression.classSgn(instance.classValue())
          * (1.0 / (1.0 + Math.exp(R[i]))));
        // NOTE(review): the paper uses Delta * |x_ij| as the width here; confirm
        // whether bigF expects a non-negative width for negative attribute values.
        denominator += (xij * xij * BayesianLogisticRegression.bigF(R[i], Delta * xij));
      }
    }
    // Subtract the Laplace penalty term lambda * sign, with lambda = sqrt(2 / Hyperparameter).
    numerator -= (Math.sqrt(2.0 / Hyperparameter) * sign);

    return (denominator != 0.0) ? numerator / denominator : 0.0;
  }

  /**
   * Computes the log-likelihood values using the implementation in the Prior class.
   *
   * @param betas the current coefficient values
   * @param instances the training instances
   */
  public void computeLogLikelihood(double[] betas, Instances instances) {
    // Basic implementation done in the prior class.
    super.computelogLikelihood(betas, instances);
  }

  /**
   * Computes the penalty term specific to the Laplace distribution and stores it
   * in {@code penalty}: the negated sum over all coefficients of
   * {@code log(2) - log(lambda_j) + lambda_j * |beta_j|}, where
   * {@code lambda_j = sqrt(hyperparameters[j])}.
   *
   * @param betas the current coefficient values
   * @param hyperparameters prior hyperparameters, one per coefficient
   */
  public void computePenalty(double[] betas, double[] hyperparameters) {
    final double log2 = Math.log(2);
    double total = 0.0;
    for (int j = 0; j < betas.length; j++) {
      // NOTE(review): laplaceUpdate uses lambda = sqrt(2 / hyperparameter) while
      // this uses sqrt(hyperparameter); confirm which parameterization callers pass.
      double lambda = Math.sqrt(hyperparameters[j]);
      total += (log2 - Math.log(lambda) + (lambda * Math.abs(betas[j])));
    }
    penalty = -total;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.2 $");
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy