// hex.naivebayes.NaiveBayesModel (Maven / Gradle / Ivy artifact source listing)
package hex.naivebayes;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.genmodel.GenModel;
import hex.schemas.NaiveBayesModelV3;
import water.H2O;
import water.Key;
import water.api.schemas3.ModelSchemaV3;
import water.codegen.CodeGenerator;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.JCodeSB;
import water.util.JCodeGen;
import water.util.SBPrintStream;
import water.util.TwoDimTable;
public class NaiveBayesModel extends Model {
/**
 * Tuning parameters of the Naive Bayes algorithm.
 */
public static class NaiveBayesParameters extends Model.Parameters {
  /** Laplace smoothing parameter. */
  public double _laplace = 0.0;

  /** Cutoff at or below which a standard deviation is replaced with {@link #_min_sdev}. */
  public double _eps_sdev = 0.0;

  /** Minimum standard deviation to use for observations without enough data. */
  public double _min_sdev = 0.001;

  /** Cutoff at or below which a probability is replaced with {@link #_min_prob}. */
  public double _eps_prob = 0.0;

  /** Minimum conditional probability to use for observations without enough data. */
  public double _min_prob = 0.001;

  /** Should a second pass be made through the data to compute metrics? */
  public boolean _compute_metrics = true;

  /** @return the short algorithm name, {@code "NaiveBayes"} */
  public String algoName() {
    return "NaiveBayes";
  }

  /** @return the human-readable algorithm name, {@code "Naive Bayes"} */
  public String fullName() {
    return "Naive Bayes";
  }

  /** @return the fully qualified class name of {@link NaiveBayesModel} */
  public String javaName() {
    return NaiveBayesModel.class.getName();
  }

  /** @return the constant number of progress units, 6 */
  @Override
  public long progressUnits() {
    return 6L;
  }
}
/**
 * Result of a Naive Bayes model build: the a-priori class distribution and the
 * per-predictor conditional distributions, both as tables and as raw arrays.
 */
public static class NaiveBayesOutput extends Model.Output {
  /** Class distribution of the response. */
  public TwoDimTable _apriori;

  /** Raw class distribution of the response; indexed by response level. */
  public double[] _apriori_raw;

  /**
   * For every predictor, a table providing, for each attribute level, the
   * conditional probabilities given the target class. Indexed by predictor.
   */
  public TwoDimTable[] _pcond;

  /**
   * Raw conditional distributions, indexed as [predictor][response level][predictor level].
   * For the first {@link #_ncats} (categorical) predictors the innermost entries are
   * read as conditional probabilities; for the remaining (numeric) predictors scoring
   * reads entry 0 as the mean and entry 1 as the standard deviation.
   */
  public double[][][] _pcond_raw;

  /** Count of response levels. */
  public int[] _rescnt;

  /** Domain of the response. */
  public String[] _levels;

  /** Number of categorical predictors. */
  public int _ncats;

  /**
   * @param b the builder this output belongs to; passed through to the superclass
   */
  public NaiveBayesOutput(NaiveBayes b) {
    super(b);
  }
}
/**
 * Creates a Naive Bayes model; all arguments are forwarded to the superclass.
 *
 * @param selfKey key of this model
 * @param parms   parameters the model is built with
 * @param output  output object holding the fitted distributions
 */
public NaiveBayesModel(Key selfKey,
                       NaiveBayesParameters parms,
                       NaiveBayesOutput output) {
  super(selfKey, parms, output);
}
/**
 * @return a freshly created {@link NaiveBayesModelV3} schema instance
 */
public ModelSchemaV3 schema() {
  final ModelSchemaV3 v3Schema = new NaiveBayesModelV3();
  return v3Schema;
}
// TODO: Constant response shouldn't be regression. Need to override getModelCategory()
/**
 * Picks the metric builder that matches the model category of the output.
 *
 * @param domain response domain handed to the metric builder
 * @return a binomial or multinomial metric builder
 * @throws RuntimeException whatever {@code H2O.unimpl()} produces, for any other model category
 */
@Override
public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
  final ModelMetrics.MetricBuilder builder;
  switch (_output.getModelCategory()) {
    case Binomial:
      builder = new ModelMetricsBinomial.MetricBuilderBinomial(domain);
      break;
    case Multinomial:
      builder = new ModelMetricsMultinomial.MetricBuilderMultinomial(domain.length, domain);
      break;
    default:
      throw H2O.unimpl();
  }
  return builder;
}
// Note: For small probabilities, product may end up zero due to underflow error. Can circumvent by taking logs.
/**
 * Scores a single row.
 *
 * <p>For every response level the log of the joint probability
 * {@code log p(x,y) = log p(y) + sum_j log p(x_j|y)} is accumulated; working in log
 * space avoids underflow of the product of many small probabilities. The first
 * {@code _ncats} columns of {@code data} are treated as categorical level indices,
 * the remaining columns as numeric values with a Gaussian density. Missing
 * ({@code NaN}) predictors are left out of the joint probability.
 *
 * @param data  predictor values of the row, categorical columns first
 * @param preds output array; slot 0 receives the predicted class, slots 1..n the
 *              class probabilities
 * @return {@code preds}
 */
@Override
protected double[] score0(double[] data, double[] preds) {
  final int nlevels = _output._levels.length;
  final double[] logJoint = new double[nlevels]; // log(p(x,y)) per response level
  assert preds.length >= (nlevels + 1); // slot 0 of preds is the predicted response class

  for (int cls = 0; cls < nlevels; cls++) {
    // Start from the log prior log(p(y))
    double acc = Math.log(_output._apriori_raw[cls]);

    // Categorical predictors: look up p(x_j|y) in the conditional probability table
    for (int c = 0; c < _output._ncats; c++) {
      final double raw = data[c];
      if (!Double.isNaN(raw)) {
        final int lvl = (int) raw;
        final double[] table = _output._pcond_raw[c][cls];
        double p;
        if (lvl < table.length) {
          p = table[lvl];
        } else {
          // Level was not observed in the training set: fall back to Laplace smoothing
          p = _parms._laplace
              / ((double) _output._rescnt[cls] + _parms._laplace * _output._domains[c].length);
        }
        if (p <= _parms._eps_prob) {
          p = _parms._min_prob;
        }
        acc += Math.log(p);
      }
    }

    // Numeric predictors: Gaussian density with the mean and standard deviation from the model
    for (int c = _output._ncats; c < data.length; c++) {
      final double x = data[c];
      if (Double.isNaN(x)) {
        continue;
      }
      final double[] stats = _output._pcond_raw[c][cls];
      final double mean = Double.isNaN(stats[0]) ? 0 : stats[0];
      final double sd;
      if (Double.isNaN(stats[1])) {
        sd = 1.0;
      } else if (stats[1] <= _parms._eps_sdev) {
        sd = _parms._min_sdev;
      } else {
        sd = stats[1];
      }
      final double dev = x - mean;
      double density = Math.exp(-(dev * dev) / (2. * sd * sd)) / (sd * Math.sqrt(2. * Math.PI));
      if (density <= _parms._eps_prob) {
        density = _parms._min_prob;
      }
      acc += Math.log(density);
    }

    logJoint[cls] = acc;
  }

  // Normalise without exponentiating the raw log joint probabilities:
  //   p(y|x) = p(x,y) / p(x) = 1 / sum_r exp( log p(x,y=r) - log p(x,y) )
  for (int i = 0; i < nlevels; i++) {
    double denom = 0;
    for (double other : logJoint) {
      denom += Math.exp(other - logJoint[i]);
    }
    preds[i + 1] = 1 / denom;
  }

  // Select the predicted class from the class probabilities
  preds[0] = GenModel.getPrediction(preds, _output._priorClassDist, data, defaultThreshold());
  return preds;
}
/**
 * Emits the generated model's header methods and registers a generator for the
 * static data classes ({@code _RESCNT}, {@code _APRIORI}, {@code _PCOND},
 * {@code _DOMLEN}) that the generated scoring code reads.
 *
 * @param sb      stream the class initialisation code is written to
 * @param fileCtx pipeline of generators emitting file-level helper classes
 * @return the stream returned by the superclass implementation
 */
@Override
protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
  final SBPrintStream init = super.toJavaInit(sb, fileCtx);
  init.ip("public boolean isSupervised() { return " + isSupervised() + "; }").nl();
  init.ip("public int nfeatures() { return " + _output.nfeatures() + "; }").nl();
  init.ip("public int nclasses() { return " + _output.nclasses() + "; }").nl();

  // Java identifier derived from the model key; prefixes every generated data class
  final String mname = JCodeGen.toJavaId(_key.toString());

  fileCtx.add(new CodeGenerator() {
    @Override
    public void generate(JCodeSB code) {
      JCodeGen.toClassWithArray(code, null, mname + "_RESCNT", _output._rescnt,
          "Count of categorical levels in response.");
      JCodeGen.toClassWithArray(code, null, mname + "_APRIORI", _output._apriori_raw,
          "Apriori class distribution of the response.");
      JCodeGen.toClassWithArray(code, null, mname + "_PCOND", _output._pcond_raw,
          "Conditional probability of predictors.");

      // Domain cardinality of each categorical predictor; stays null when there are none
      final int ncats = _output._ncats;
      final double[] domainSizes = ncats > 0 ? new double[ncats] : null;
      for (int c = 0; c < ncats; c++) {
        domainSizes[c] = _output._domains[c].length;
      }
      JCodeGen.toClassWithArray(code, null, mname + "_DOMLEN", domainSizes,
          "Number of unique levels for each categorical predictor.");
    }
  });
  return init;
}
/**
 * Emits the body of the generated scoring method. The emitted code mirrors
 * {@code score0}: for every response level it accumulates the log prior plus the
 * log conditional probability of each non-missing predictor (table lookup for the
 * first {@code _ncats} categorical columns, Gaussian density for the remaining
 * numeric columns), normalises the result into class probabilities in
 * {@code preds[1..]} and stores the predicted class in {@code preds[0]}.
 *
 * <p>The generated code reads the data classes {@code <mname>_APRIORI},
 * {@code <mname>_PCOND}, {@code <mname>_RESCNT} and {@code <mname>_DOMLEN}, where
 * {@code <mname>} is the Java identifier derived from the model key.
 *
 * @param bodySb      stream receiving the generated method body
 * @param classCtx    class-level generator pipeline (not used here)
 * @param fileCtx     file-level generator pipeline (not used here)
 * @param verboseCode verbosity flag (not used here)
 */
@Override protected void toJavaPredictBody(SBPrintStream bodySb,
                                           CodeGeneratorPipeline classCtx,
                                           CodeGeneratorPipeline fileCtx,
                                           final boolean verboseCode) {
  // This is model name
  final String mname = JCodeGen.toJavaId(_key.toString());

  // Expression used when a categorical level has no entry in the probability table:
  // Laplace smoothing, or the constant 0 when smoothing is switched off.
  final String unseenLevelProb = _parms._laplace == 0
      ? "0"
      : _parms._laplace + "/(" + mname + "_RESCNT.VALUES[i] + " + _parms._laplace
          + "*" + mname + "_DOMLEN.VALUES[j])";

  bodySb.i().p("java.util.Arrays.fill(preds,0);").nl();
  bodySb.i().p("double mean, sdev, prob;").nl();
  bodySb.i().p("double[] nums = new double[" + _output._levels.length + "];").nl();
  bodySb.i().p("for(int i = 0; i < " + _output._levels.length + "; i++) {").nl();
  bodySb.i(1).p("nums[i] = Math.log(").pj(mname+"_APRIORI", "VALUES").p("[i]);").nl();

  // Categorical predictors
  bodySb.i(1).p("for(int j = 0; j < " + _output._ncats + "; j++) {").nl();
  bodySb.i(2).p("if(Double.isNaN(data[j])) continue;").nl();
  bodySb.i(2).p("int level = (int)data[j];").nl();
  // The level must be bounded by the size of the probability table of predictor j and
  // response level i (as score0 does), not by the number of predictors.
  bodySb.i(2).p("prob = level < " + mname + "_PCOND.VALUES[j][i].length ? " + mname + "_PCOND.VALUES[j][i][level] : ")
      .p(unseenLevelProb).p(";").nl();
  bodySb.i(2).p("nums[i] += Math.log(prob <= " + _parms._eps_prob + " ? " + _parms._min_prob + " : prob);").nl();
  bodySb.i(1).p("}").nl();

  // Numeric predictors: Gaussian density from the stored mean ([0]) and standard deviation ([1])
  bodySb.i(1).p("for(int j = " + _output._ncats + "; j < data.length; j++) {").nl();
  bodySb.i(2).p("if(Double.isNaN(data[j])) continue;").nl();
  bodySb.i(2).p("mean = Double.isNaN("+mname+"_PCOND.VALUES[j][i][0]) ? 0 : "+mname+"_PCOND.VALUES[j][i][0];").nl();
  bodySb.i(2).p("sdev = Double.isNaN("+mname+"_PCOND.VALUES[j][i][1]) ? 1 : ("+mname+"_PCOND.VALUES[j][i][1] <= " + _parms._eps_sdev + " ? "
      + _parms._min_sdev + " : "+mname+"_PCOND.VALUES[j][i][1]);").nl();
  bodySb.i(2).p("prob = Math.exp(-((data[j]-mean)*(data[j]-mean))/(2.*sdev*sdev)) / (sdev*Math.sqrt(2.*Math.PI));").nl();
  bodySb.i(2).p("nums[i] += Math.log(prob <= " + _parms._eps_prob + " ? " + _parms._min_prob + " : prob);").nl();
  bodySb.i(1).p("}").nl();
  bodySb.i().p("}").nl();

  // Normalise: p(y|x) = 1 / sum_r exp( log p(x,y=r) - log p(x,y) )
  bodySb.i().p("double sum;").nl();
  bodySb.i().p("for(int i = 0; i < nums.length; i++) {").nl();
  bodySb.i(1).p("sum = 0;").nl();
  bodySb.i(1).p("for(int j = 0; j < nums.length; j++) {").nl();
  bodySb.i(2).p("sum += Math.exp(nums[j]-nums[i]);").nl();
  bodySb.i(1).p("}").nl();
  bodySb.i(1).p("preds[i+1] = 1/sum;").nl();
  bodySb.i().p("}").nl();
  bodySb.i().p("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + defaultThreshold()+");").nl();
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy