hex.naivebayes.NaiveBayes Maven / Gradle / Ivy
package hex.naivebayes;
import hex.*;
import hex.naivebayes.NaiveBayesModel.NaiveBayesOutput;
import hex.naivebayes.NaiveBayesModel.NaiveBayesParameters;
import water.*;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Chunk;
import water.util.ArrayUtils;
import water.util.PrettyPrint;
import water.util.TwoDimTable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* Naive Bayes
* This is an algorithm for computing the conditional a-posterior probabilities of a categorical
* response from independent predictors using Bayes rule.
* Naive Bayes on Wikipedia
* Lecture Notes by Andrew Ng
* @author anqi_fu
*
*/
public class NaiveBayes extends ModelBuilder {
public boolean isSupervised(){return true;}
@Override protected NaiveBayesDriver trainModelImpl() { return new NaiveBayesDriver(); }
@Override public ModelCategory[] can_build() { return new ModelCategory[]{ ModelCategory.Unknown }; }
@Override public boolean havePojo() { return true; }
@Override public boolean haveMojo() { return false; }
@Override
protected void checkMemoryFootPrint_impl() {
// compute memory usage for pcond matrix
long mem_usage = (_train.numCols() - 1) * _train.lastVec().cardinality();
String[][] domains = _train.domains();
long count = 0;
for (int i = 0; i < _train.numCols() - 1; i++) {
count += domains[i] == null ? 2 : domains[i].length;
}
mem_usage *= count;
mem_usage *= 8; //doubles
long max_mem = H2O.SELF._heartbeat.get_free_mem();
if (mem_usage > max_mem) {
String msg = "Conditional probabilities won't fit in the driver node's memory ("
+ PrettyPrint.bytes(mem_usage) + " > " + PrettyPrint.bytes(max_mem)
+ ") - try reducing the number of columns, the number of response classes or the number of categorical factors of the predictors.";
error("_train", msg);
}
}
// Called from an http request
public NaiveBayes(NaiveBayesModel.NaiveBayesParameters parms) { super(parms); init(false); }
public NaiveBayes(boolean startup_once) { super(new NaiveBayesParameters(),startup_once); }
@Override
public void init(boolean expensive) {
super.init(expensive);
if (_response != null) {
if (!_response.isCategorical()) error("_response", "Response must be a categorical column");
else if (_response.isConst()) error("_response", "Response must have at least two unique categorical levels");
}
if (_parms._laplace < 0) error("_laplace", "Laplace smoothing must be a number >= 0");
if (_parms._min_sdev < 1e-10) error("_min_sdev", "Min. standard deviation must be at least 1e-10");
if (_parms._eps_sdev < 0) error("_eps_sdev", "Threshold for standard deviation must be positive");
if (_parms._min_prob < 1e-10) error("_min_prob", "Min. probability must be at least 1e-10");
if (_parms._eps_prob < 0) error("_eps_prob", "Threshold for probability must be positive");
hide("_balance_classes", "Balance classes is not applicable to NaiveBayes.");
hide("_class_sampling_factors", "Class sampling factors is not applicable to NaiveBayes.");
hide("_max_after_balance_size", "Max after balance size is not applicable to NaiveBayes.");
if (expensive && error_count() == 0) checkMemoryFootPrint();
}
class NaiveBayesDriver extends Driver {
public boolean computeStatsFillModel(NaiveBayesModel model, DataInfo dinfo, NBTask tsk) {
model._output._levels = _response.domain();
model._output._rescnt = tsk._rescnt;
model._output._ncats = dinfo._cats;
if(stop_requested() && !timeout()) return false;
_job.update(1, "Initializing arrays for model statistics");
// String[][] domains = dinfo._adaptedFrame.domains();
String[][] domains = model._output._domains;
double[] apriori = new double[tsk._nrescat];
double[][][] pcond = new double[tsk._npreds][][];
for(int i = 0; i < pcond.length; i++) {
int ncnt = domains[i] == null ? 2 : domains[i].length;
pcond[i] = new double[tsk._nrescat][ncnt];
}
if(stop_requested() && !timeout()) return false;
_job.update(1, "Computing probabilities for categorical cols");
// A-priori probability of response y
for(int i = 0; i < apriori.length; i++)
apriori[i] = ((double)tsk._rescnt[i] + _parms._laplace)/(tsk._nobs + tsk._nrescat * _parms._laplace);
// apriori[i] = tsk._rescnt[i]/tsk._nobs; // Note: R doesn't apply laplace smoothing to priors, even though this is textbook definition
// Probability of categorical predictor x_j conditional on response y
for(int col = 0; col < dinfo._cats; col++) {
assert pcond[col].length == tsk._nrescat;
for(int i = 0; i < pcond[col].length; i++) {
for(int j = 0; j < pcond[col][i].length; j++)
pcond[col][i][j] = ((double)tsk._jntcnt[col][i][j] + _parms._laplace)/((double)tsk._rescnt[i] + domains[col].length * _parms._laplace);
}
}
if(stop_requested() && !timeout()) return false;
_job.update(1, "Computing mean and standard deviation for numeric cols");
// Mean and standard deviation of numeric predictor x_j for every level of response y
for(int col = 0; col < dinfo._nums; col++) {
for(int i = 0; i < pcond[0].length; i++) {
int cidx = dinfo._cats + col;
double num = tsk._rescnt[i];
double pmean = tsk._jntsum[col][i][0]/num;
pcond[cidx][i][0] = pmean;
// double pvar = tsk._jntsum[col][i][1]/num - pmean * pmean;
double pvar = tsk._jntsum[col][i][1]/(num - 1) - pmean * pmean * num/(num - 1);
pcond[cidx][i][1] = Math.sqrt(pvar);
}
}
model._output._apriori_raw = apriori;
model._output._pcond_raw = pcond;
// Create table of conditional probabilities for every predictor
model._output._pcond = new TwoDimTable[pcond.length];
String[] rowNames = _response.domain();
for(int col = 0; col < dinfo._cats; col++) {
String[] colNames = _train.vec(col).domain();
String[] colTypes = new String[colNames.length];
String[] colFormats = new String[colNames.length];
Arrays.fill(colTypes, "double");
Arrays.fill(colFormats, "%5f");
model._output._pcond[col] = new TwoDimTable(_train.name(col), null, rowNames, colNames, colTypes, colFormats,
"Y_by_" + _train.name(col), new String[rowNames.length][], pcond[col]);
}
for(int col = 0; col < dinfo._nums; col++) {
int cidx = dinfo._cats + col;
model._output._pcond[cidx] = new TwoDimTable(_train.name(cidx), null, rowNames, new String[] {"Mean", "Std_Dev"},
new String[] {"double", "double"}, new String[] {"%5f", "%5f"}, "Y_by_" + _train.name(cidx),
new String[rowNames.length][], pcond[cidx]);
}
// Create table of a-priori probabilities for the response
String[] colTypes = new String[_response.cardinality()];
String[] colFormats = new String[_response.cardinality()];
Arrays.fill(colTypes, "double");
Arrays.fill(colFormats, "%5f");
model._output._apriori = new TwoDimTable("A Priori Response Probabilities", null, new String[1], _response.domain(), colTypes, colFormats, "",
new String[1][], new double[][] {apriori});
model._output._model_summary = createModelSummaryTable(model._output);
if(stop_requested() && !timeout()) return false;
_job.update(1, "Scoring and computing metrics on training data");
if (_parms._compute_metrics) {
model.score(_parms.train()).delete(); // This scores on the training data and appends a ModelMetrics
model._output._training_metrics = ModelMetrics.getFromDKV(model,_parms.train());
}
// At the end: validation scoring (no need to gather scoring history)
if(stop_requested() && !timeout()) return false;
_job.update(1, "Scoring and computing metrics on validation data");
if (_valid != null) {
model.score(_parms.valid()).delete(); //this appends a ModelMetrics on the validation set
model._output._validation_metrics = ModelMetrics.getFromDKV(model,_parms.valid());
}
return true;
}
@Override
public void computeImpl() {
NaiveBayesModel model = null;
DataInfo dinfo = null;
try {
init(true); // Initialize parameters
if (error_count() > 0) throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(NaiveBayes.this);
dinfo = new DataInfo(_train, _valid, 1, false, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, true, false, false, _weights!=null, false, _fold!=null);
// The model to be built
model = new NaiveBayesModel(dest(), _parms, new NaiveBayesOutput(NaiveBayes.this));
model.delete_and_lock(_job);
_job.update(1, "Begin distributed Naive Bayes calculation");
NBTask tsk = new NBTask(_job._key, dinfo, _response.cardinality()).doAll(dinfo._adaptedFrame);
if (computeStatsFillModel(model, dinfo, tsk))
model.update(_job);
} finally {
if (model != null) model.unlock(_job);
if (dinfo != null) dinfo.remove();
}
}
}
private TwoDimTable createModelSummaryTable(NaiveBayesOutput output) {
List colHeaders = new ArrayList<>();
List colTypes = new ArrayList<>();
List colFormat = new ArrayList<>();
colHeaders.add("Number of Response Levels"); colTypes.add("long"); colFormat.add("%d");
colHeaders.add("Min Apriori Probability"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Max Apriori Probability"); colTypes.add("double"); colFormat.add("%.5f");
double apriori_min = output._apriori_raw[0];
double apriori_max = output._apriori_raw[0];
for(int i = 1; i < output._apriori_raw.length; i++) {
if(output._apriori_raw[i] < apriori_min) apriori_min = output._apriori_raw[i];
else if(output._apriori_raw[i] > apriori_max) apriori_max = output._apriori_raw[i];
}
final int rows = 1;
TwoDimTable table = new TwoDimTable(
"Model Summary", null,
new String[rows],
colHeaders.toArray(new String[0]),
colTypes.toArray(new String[0]),
colFormat.toArray(new String[0]),
"");
int row = 0;
int col = 0;
table.set(row, col++, output._apriori_raw.length);
table.set(row, col++, apriori_min);
table.set(row, col , apriori_max);
return table;
}
// Note: NA handling differs from R for efficiency purposes
// R's method: For each predictor x_j, skip counting that row for p(x_j|y) calculation if x_j = NA.
// If response y = NA, skip counting row entirely in all calculations
// H2O's method: Just skip all rows where any x_j = NA or y = NA. Should be more memory-efficient, but results incomparable with R.
private static class NBTask extends MRTask {
final protected Key _jobKey;
final DataInfo _dinfo;
final String[][] _domains; // Domains of the training frame
final int _nrescat; // Number of levels for the response y
final int _npreds; // Number of predictors in the training frame
public int _nobs; // Number of rows counted in calculation
public int[/*nrescat*/] _rescnt; // Count of each level in the response
public int[/*npreds*/][/*nrescat*/][] _jntcnt; // For each categorical predictor, joint count of response and predictor levels
public double[/*npreds*/][/*nrescat*/][] _jntsum; // For each numeric predictor, sum and squared sum of entries for every response level
public NBTask(Key jobKey, DataInfo dinfo, int nres) {
_jobKey = jobKey;
_dinfo = dinfo;
_nrescat = nres;
_domains = dinfo._adaptedFrame.domains();
_npreds = dinfo._cats + dinfo._nums;
}
@Override public void map(Chunk[] chks) {
if( _jobKey.get().stop_requested() ) return;
_nobs = 0;
_rescnt = new int[_nrescat];
if(_dinfo._cats > 0) {
_jntcnt = new int[_dinfo._cats][][];
for (int i = 0; i < _dinfo._cats; i++) {
_jntcnt[i] = new int[_nrescat][_domains[i].length];
}
}
if(_dinfo._nums > 0) {
_jntsum = new double[_dinfo._nums][][];
for (int i = 0; i < _dinfo._nums; i++) {
_jntsum[i] = new double[_nrescat][2];
}
}
Chunk res = chks[_dinfo.responseChunkId(0)]; //response
OUTER:
for(int row = 0; row < chks[0]._len; row++) {
if (_dinfo._weights && chks[_dinfo.weightChunkId()].atd(row)==0) continue OUTER;
if (_dinfo._weights && chks[_dinfo.weightChunkId()].atd(row)!=1) throw new IllegalArgumentException("Weights must be either 0 or 1 for Naive Bayes.");
// Skip row if any entries in it are NA
for( Chunk chk : chks ) {
if(Double.isNaN(chk.atd(row))) continue OUTER;
}
// Record joint counts of categorical predictors and response
int rlevel = (int)res.atd(row);
for(int col = 0; col < _dinfo._cats; col++) {
int plevel = (int)chks[col].atd(row);
_jntcnt[col][rlevel][plevel]++;
}
// Record sum for each pair of numerical predictors and response
for(int col = 0; col < _dinfo._nums; col++) {
int cidx = _dinfo._cats + col;
double x = chks[cidx].atd(row);
_jntsum[col][rlevel][0] += x;
_jntsum[col][rlevel][1] += x*x;
}
_rescnt[rlevel]++;
_nobs++;
}
}
@Override public void reduce(NBTask nt) {
_nobs += nt._nobs;
ArrayUtils.add(_rescnt, nt._rescnt);
if(null != _jntcnt) {
for (int col = 0; col < _jntcnt.length; col++)
ArrayUtils.add(_jntcnt[col], nt._jntcnt[col]);
}
if(null != _jntsum) {
for (int col = 0; col < _jntsum.length; col++)
ArrayUtils.add(_jntsum[col], nt._jntsum[col]);
}
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy