package hex.naivebayes;

import hex.*;
import hex.naivebayes.NaiveBayesModel.NaiveBayesOutput;
import hex.naivebayes.NaiveBayesModel.NaiveBayesParameters;
import water.*;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Chunk;
import water.util.ArrayUtils;
import water.util.PrettyPrint;
import water.util.TwoDimTable;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Naive Bayes
 * This is an algorithm for computing the conditional a posteriori probabilities of a categorical
 * response from independent predictors using Bayes' rule.
 * References: the "Naive Bayes classifier" article on Wikipedia and lecture notes by Andrew Ng.
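 *
 * Rough sketch of what the fitted model computes (the actual scoring code lives in NaiveBayesModel):
 * for a row (x_1, ..., x_p) and each response level y, the unnormalized posterior is
 *   P(y | x_1, ..., x_p)  proportional to  P(y) * prod_j P(x_j | y),
 * where P(x_j | y) is a Laplace-smoothed level frequency for categorical predictors and a Gaussian
 * density N(mean, sd) for numeric predictors; this builder estimates P(y) (_apriori_raw) and the
 * per-predictor parameters of P(x_j | y) (_pcond_raw).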
 * @author anqi_fu
 *
 */
public class NaiveBayes extends ModelBuilder<NaiveBayesModel, NaiveBayesModel.NaiveBayesParameters, NaiveBayesModel.NaiveBayesOutput> {
  @Override public boolean isSupervised() { return true; }
  @Override protected NaiveBayesDriver trainModelImpl() { return new NaiveBayesDriver(); }
  @Override public ModelCategory[] can_build() { return new ModelCategory[]{ ModelCategory.Unknown }; }

  @Override public boolean havePojo() { return true; }
  @Override public boolean haveMojo() { return false; }

  @Override
  protected void checkMemoryFootPrint_impl() {
    // compute memory usage for pcond matrix
    long mem_usage = (_train.numCols() - 1) * _train.lastVec().cardinality();
    String[][] domains = _train.domains();
    long count = 0;
    for (int i = 0; i < _train.numCols() - 1; i++) {
      count += domains[i] == null ? 2 : domains[i].length;
    }
    mem_usage *= count;
    mem_usage *= 8; //doubles
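    // Roughly: mem_usage ≈ (#predictors) * (#response levels) * (total predictor levels, numeric
    // columns counted as 2) * 8 bytes - a conservative over-estimate of the pcond array allocated
    // in computeStatsFillModel below.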
    long max_mem = H2O.SELF._heartbeat.get_free_mem();
    if (mem_usage > max_mem) {
      String msg = "Conditional probabilities won't fit in the driver node's memory ("
              + PrettyPrint.bytes(mem_usage) + " > " + PrettyPrint.bytes(max_mem)
              + ") - try reducing the number of columns, the number of response classes or the number of categorical factors of the predictors.";
      error("_train", msg);
    }
  }

  // Called from an http request
  public NaiveBayes(NaiveBayesModel.NaiveBayesParameters parms) { super(parms); init(false); }
  public NaiveBayes(boolean startup_once) { super(new NaiveBayesParameters(),startup_once); }

  @Override
  public void init(boolean expensive) {
    super.init(expensive);
    if (_response != null) {
      if (!_response.isCategorical()) error("_response", "Response must be a categorical column");
      else if (_response.isConst()) error("_response", "Response must have at least two unique categorical levels");
    }
    if (_parms._laplace < 0) error("_laplace", "Laplace smoothing must be a number >= 0");
    if (_parms._min_sdev < 1e-10) error("_min_sdev", "Min. standard deviation must be at least 1e-10");
    if (_parms._eps_sdev < 0) error("_eps_sdev", "Threshold for standard deviation must be positive");
    if (_parms._min_prob < 1e-10) error("_min_prob", "Min. probability must be at least 1e-10");
    if (_parms._eps_prob < 0) error("_eps_prob", "Threshold for probability must be positive");
    hide("_balance_classes", "Balance classes is not applicable to NaiveBayes.");
    hide("_class_sampling_factors", "Class sampling factors is not applicable to NaiveBayes.");
    hide("_max_after_balance_size", "Max after balance size is not applicable to NaiveBayes.");
    if (expensive && error_count() == 0) checkMemoryFootPrint();
  }
  class NaiveBayesDriver extends Driver {

    public boolean computeStatsFillModel(NaiveBayesModel model, DataInfo dinfo, NBTask tsk) {
      model._output._levels = _response.domain();
      model._output._rescnt = tsk._rescnt;
      model._output._ncats = dinfo._cats;

      if(stop_requested() && !timeout()) return false;
      _job.update(1, "Initializing arrays for model statistics");
      // String[][] domains = dinfo._adaptedFrame.domains();
      String[][] domains = model._output._domains;
      double[] apriori = new double[tsk._nrescat];
      double[][][] pcond = new double[tsk._npreds][][];
      for(int i = 0; i < pcond.length; i++) {
        int ncnt = domains[i] == null ? 2 : domains[i].length;
        pcond[i] = new double[tsk._nrescat][ncnt];
      }

      if(stop_requested() && !timeout()) return false;
      _job.update(1, "Computing probabilities for categorical cols");
      // A-priori probability of response y
      for(int i = 0; i < apriori.length; i++)
        apriori[i] = ((double)tsk._rescnt[i] + _parms._laplace)/(tsk._nobs + tsk._nrescat * _parms._laplace);
        // apriori[i] = tsk._rescnt[i]/tsk._nobs;     // Note: R doesn't apply laplace smoothing to priors, even though this is textbook definition
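        // Worked example with hypothetical counts: _rescnt = {30, 70}, _nobs = 100, _nrescat = 2,
        // laplace = 1  =>  apriori = {31/102, 71/102} ≈ {0.304, 0.696}.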

      // Probability of categorical predictor x_j conditional on response y
      for(int col = 0; col < dinfo._cats; col++) {
        assert pcond[col].length == tsk._nrescat;
        for(int i = 0; i < pcond[col].length; i++) {
          for(int j = 0; j < pcond[col][i].length; j++)
            pcond[col][i][j] = ((double)tsk._jntcnt[col][i][j] + _parms._laplace)/((double)tsk._rescnt[i] + domains[col].length * _parms._laplace);
        }
      }
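      // i.e. P(x_j = level | y = i) = (joint count + laplace) / (count of y = i + k_j * laplace),
      // where k_j is the number of levels of predictor j. Hypothetical example: joint count 5,
      // response count 20, k_j = 3, laplace = 1 gives (5 + 1) / (20 + 3) ≈ 0.261.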

      if(stop_requested() && !timeout()) return false;
      _job.update(1, "Computing mean and standard deviation for numeric cols");
      // Mean and standard deviation of numeric predictor x_j for every level of response y
      for(int col = 0; col < dinfo._nums; col++) {
        for(int i = 0; i < pcond[0].length; i++) {
          int cidx = dinfo._cats + col;
          double num = tsk._rescnt[i];
          double pmean = tsk._jntsum[col][i][0]/num;

          pcond[cidx][i][0] = pmean;
          // double pvar = tsk._jntsum[col][i][1]/num - pmean * pmean;
          double pvar = tsk._jntsum[col][i][1]/(num - 1) - pmean * pmean * num/(num - 1);
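          // The line above is the unbiased sample variance, using the identity
          // var = (sum(x^2) - n*mean^2) / (n - 1); the commented-out line before it is the
          // biased divide-by-n version.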
          pcond[cidx][i][1] = Math.sqrt(pvar);
        }
      }
      model._output._apriori_raw = apriori;
      model._output._pcond_raw = pcond;

      // Create table of conditional probabilities for every predictor
      model._output._pcond = new TwoDimTable[pcond.length];
      String[] rowNames = _response.domain();
      for(int col = 0; col < dinfo._cats; col++) {
        String[] colNames = _train.vec(col).domain();
        String[] colTypes = new String[colNames.length];
        String[] colFormats = new String[colNames.length];
        Arrays.fill(colTypes, "double");
        Arrays.fill(colFormats, "%5f");
        model._output._pcond[col] = new TwoDimTable(_train.name(col), null, rowNames, colNames, colTypes, colFormats,
                "Y_by_" + _train.name(col), new String[rowNames.length][], pcond[col]);
      }

      for(int col = 0; col < dinfo._nums; col++) {
        int cidx = dinfo._cats + col;
        model._output._pcond[cidx] = new TwoDimTable(_train.name(cidx), null, rowNames, new String[] {"Mean", "Std_Dev"},
                new String[] {"double", "double"}, new String[] {"%5f", "%5f"}, "Y_by_" + _train.name(cidx),
                new String[rowNames.length][], pcond[cidx]);
      }

      // Create table of a-priori probabilities for the response
      String[] colTypes = new String[_response.cardinality()];
      String[] colFormats = new String[_response.cardinality()];
      Arrays.fill(colTypes, "double");
      Arrays.fill(colFormats, "%5f");
      model._output._apriori = new TwoDimTable("A Priori Response Probabilities", null, new String[1], _response.domain(), colTypes, colFormats, "",
              new String[1][], new double[][] {apriori});
      model._output._model_summary = createModelSummaryTable(model._output);

      if(stop_requested() && !timeout()) return false;
      _job.update(1, "Scoring and computing metrics on training data");
      if (_parms._compute_metrics) {
        model.score(_parms.train()).delete(); // This scores on the training data and appends a ModelMetrics
        model._output._training_metrics = ModelMetrics.getFromDKV(model,_parms.train());
      }

      // At the end: validation scoring (no need to gather scoring history)
      if(stop_requested() && !timeout()) return false;
      _job.update(1, "Scoring and computing metrics on validation data");
      if (_valid != null) {
        model.score(_parms.valid()).delete(); //this appends a ModelMetrics on the validation set
        model._output._validation_metrics = ModelMetrics.getFromDKV(model,_parms.valid());
      }

      return true;
    }

    @Override
    public void computeImpl() {
      NaiveBayesModel model = null;
      DataInfo dinfo = null;

      try {
        init(true);   // Initialize parameters
        if (error_count() > 0) throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(NaiveBayes.this);
        dinfo = new DataInfo(_train, _valid, 1, false, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, true, false, false, _weights!=null, false, _fold!=null);

        // The model to be built
        model = new NaiveBayesModel(dest(), _parms, new NaiveBayesOutput(NaiveBayes.this));
        model.delete_and_lock(_job);

        _job.update(1, "Begin distributed Naive Bayes calculation");
        NBTask tsk = new NBTask(_job._key, dinfo, _response.cardinality()).doAll(dinfo._adaptedFrame);
        if (computeStatsFillModel(model, dinfo, tsk))
          model.update(_job);
      } finally {
        if (model != null) model.unlock(_job);
        if (dinfo != null) dinfo.remove();
      }
    }
  }

  private TwoDimTable createModelSummaryTable(NaiveBayesOutput output) {
    List<String> colHeaders = new ArrayList<>();
    List<String> colTypes = new ArrayList<>();
    List<String> colFormat = new ArrayList<>();
    colHeaders.add("Number of Response Levels"); colTypes.add("long"); colFormat.add("%d");
    colHeaders.add("Min Apriori Probability"); colTypes.add("double"); colFormat.add("%.5f");
    colHeaders.add("Max Apriori Probability"); colTypes.add("double"); colFormat.add("%.5f");

    double apriori_min = output._apriori_raw[0];
    double apriori_max = output._apriori_raw[0];
    for(int i = 1; i < output._apriori_raw.length; i++) {
      if(output._apriori_raw[i] < apriori_min) apriori_min = output._apriori_raw[i];
      else if(output._apriori_raw[i] > apriori_max) apriori_max = output._apriori_raw[i];
    }

    final int rows = 1;
    TwoDimTable table = new TwoDimTable(
            "Model Summary", null,
            new String[rows],
            colHeaders.toArray(new String[0]),
            colTypes.toArray(new String[0]),
            colFormat.toArray(new String[0]),
            "");
    int row = 0;
    int col = 0;
    table.set(row, col++, output._apriori_raw.length);
    table.set(row, col++, apriori_min);
    table.set(row, col  , apriori_max);
    return table;
  }

  // Note: NA handling differs from R for efficiency purposes
  // R's method: For each predictor x_j, skip counting that row for p(x_j|y) calculation if x_j = NA.
  //             If response y = NA, skip counting row entirely in all calculations
  // H2O's method: Just skip all rows where any x_j = NA or y = NA. Should be more memory-efficient, but results incomparable with R.
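  // Example with a hypothetical row (x1 = NA, x2 = 3.4, y = "A"): H2O drops the row from all counts,
  // while R would still count x2 and y, so fitted probabilities can differ slightly on data with NAs.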
  private static class NBTask extends MRTask<NBTask> {
    final protected Key<Job> _jobKey;
    final DataInfo _dinfo;
    final String[][] _domains;  // Domains of the training frame
    final int _nrescat;         // Number of levels for the response y
    final int _npreds;          // Number of predictors in the training frame

    public int _nobs;                     // Number of rows counted in calculation
    public int[/*nrescat*/] _rescnt;      // Count of each level in the response
    public int[/*npreds*/][/*nrescat*/][] _jntcnt;  // For each categorical predictor, joint count of response and predictor levels
    public double[/*npreds*/][/*nrescat*/][] _jntsum; // For each numeric predictor, sum and squared sum of entries for every response level
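    // e.g. _jntcnt[j][i][l] = number of kept rows where categorical predictor j is at level l and the
    //      response is at level i; _jntsum[j][i] = { sum of x_j, sum of x_j^2 } over kept rows with
    //      response level i (see map() below).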

    public NBTask(Key<Job> jobKey, DataInfo dinfo, int nres) {
      _jobKey = jobKey;
      _dinfo = dinfo;
      _nrescat = nres;
      _domains = dinfo._adaptedFrame.domains();
      _npreds = dinfo._cats + dinfo._nums;
    }

    @Override public void map(Chunk[] chks) {
      if( _jobKey.get().stop_requested() ) return;
      _nobs = 0;
      _rescnt = new int[_nrescat];

      if(_dinfo._cats > 0) {
        _jntcnt = new int[_dinfo._cats][][];
        for (int i = 0; i < _dinfo._cats; i++) {
          _jntcnt[i] = new int[_nrescat][_domains[i].length];
        }
      }

      if(_dinfo._nums > 0) {
        _jntsum = new double[_dinfo._nums][][];
        for (int i = 0; i < _dinfo._nums; i++) {
          _jntsum[i] = new double[_nrescat][2];
        }
      }

      Chunk res = chks[_dinfo.responseChunkId(0)]; //response
      OUTER:
      for(int row = 0; row < chks[0]._len; row++) {
        if (_dinfo._weights && chks[_dinfo.weightChunkId()].atd(row)==0) continue OUTER;
        if (_dinfo._weights && chks[_dinfo.weightChunkId()].atd(row)!=1) throw new IllegalArgumentException("Weights must be either 0 or 1 for Naive Bayes.");
        // Skip row if any entries in it are NA
        for( Chunk chk : chks ) {
          if(Double.isNaN(chk.atd(row))) continue OUTER;
        }

        // Record joint counts of categorical predictors and response
        int rlevel = (int)res.atd(row);
        for(int col = 0; col < _dinfo._cats; col++) {
          int plevel = (int)chks[col].atd(row);
          _jntcnt[col][rlevel][plevel]++;
        }

        // Record sum for each pair of numerical predictors and response
        for(int col = 0; col < _dinfo._nums; col++) {
          int cidx = _dinfo._cats + col;
          double x = chks[cidx].atd(row);
          _jntsum[col][rlevel][0] += x;
          _jntsum[col][rlevel][1] += x*x;
        }
        _rescnt[rlevel]++;
        _nobs++;
      }
    }

    @Override public void reduce(NBTask nt) {
      _nobs += nt._nobs;
      ArrayUtils.add(_rescnt, nt._rescnt);
      if(null != _jntcnt) {
        for (int col = 0; col < _jntcnt.length; col++)
          ArrayUtils.add(_jntcnt[col], nt._jntcnt[col]);
      }
      if(null != _jntsum) {
        for (int col = 0; col < _jntsum.length; col++)
          ArrayUtils.add(_jntsum[col], nt._jntsum[col]);
      }
    }
  }
}