All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.naivebayes.NaiveBayes Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.naivebayes;

import hex.*;
import hex.schemas.ModelBuilderSchema;
import hex.schemas.NaiveBayesV3;
import water.*;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.PrettyPrint;
import water.util.TwoDimTable;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Naive Bayes
 * This is an algorithm for computing the conditional a-posterior probabilities of a categorical
 * response from independent predictors using Bayes rule.
 * Naive Bayes on Wikipedia
 * Lecture Notes by Andrew Ng
 * @author anqi_fu
 *
 */
public class NaiveBayes extends ModelBuilder {
  @Override
  public ModelBuilderSchema schema() {
    return new NaiveBayesV3();
  }

  public boolean isSupervised(){return true;}

  @Override
  public Job trainModel() {
    return start(new NaiveBayesDriver(), 1);
  }

  @Override
  public ModelCategory[] can_build() {
    return new ModelCategory[]{ ModelCategory.Unknown };
  }

  @Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Stable; };

  @Override
  protected void checkMemoryFootPrint() {
    // compute memory usage for pcond matrix
    long mem_usage = (_train.numCols() - 1) * _train.lastVec().cardinality();
    String[][] domains = _train.domains();
    long count = 0;
    for (int i = 0; i < _train.numCols() - 1; i++) {
      count += domains[i] == null ? 2 : domains[i].length;
    }
    mem_usage *= count;
    mem_usage *= 8; //doubles
    long max_mem = H2O.SELF.get_max_mem();
    if (mem_usage > max_mem) {
      String msg = "Conditional probabilities won't fit in the driver node's memory ("
              + PrettyPrint.bytes(mem_usage) + " > " + PrettyPrint.bytes(max_mem)
              + ") - try reducing the number of columns, the number of response classes or the number of categorical factors of the predictors.";
      error("_train", msg);
      cancel(msg);
    }
  }

  // Called from an http request
  public NaiveBayes(NaiveBayesModel.NaiveBayesParameters parms) {
    super("NaiveBayes", parms);
    init(false);
  }

  @Override
  public void init(boolean expensive) {
    super.init(expensive);
    if (_response != null) {
      if (!_response.isEnum()) error("_response", "Response must be a categorical column");
      else if (_response.isConst()) error("_response", "Response must have at least two unique categorical levels");
    }
    if (_parms._laplace < 0) error("_laplace", "Laplace smoothing must be an integer >= 0");
    if (_parms._min_sdev < 1e-10) error("_min_sdev", "Min. standard deviation must be at least 1e-10");
    if (_parms._eps_sdev < 0) error("_eps_sdev", "Threshold for standard deviation must be positive");
    if (_parms._min_prob < 1e-10) error("_min_prob", "Min. probability must be at least 1e-10");
    if (_parms._eps_prob < 0) error("_eps_prob", "Threshold for probability must be positive");
    hide("_balance_classes", "Balance classes is not applicable to NaiveBayes.");
    hide("_class_sampling_factors", "Class sampling factors is not applicable to NaiveBayes.");
    hide("_max_after_balance_size", "Max after balance size is not applicable to NaiveBayes.");
    if (expensive && error_count() == 0) checkMemoryFootPrint();
  }
  private static boolean couldBeBool(Vec v) { return v != null && v.isInt() && v.min()+1==v.max(); }

  class NaiveBayesDriver extends H2O.H2OCountedCompleter {

    public void computeStatsFillModel(NaiveBayesModel model, DataInfo dinfo, NBTask tsk) {
      model._output._levels = _response.domain();
      model._output._rescnt = tsk._rescnt;
      model._output._ncats = dinfo._cats;

      // String[][] domains = dinfo._adaptedFrame.domains();
      String[][] domains = model._output._domains;
      double[] apriori = new double[tsk._nrescat];
      double[][][] pcond = new double[tsk._npreds][][];
      for(int i = 0; i < pcond.length; i++) {
        int ncnt = domains[i] == null ? 2 : domains[i].length;
        pcond[i] = new double[tsk._nrescat][ncnt];
      }

      // A-priori probability of response y
      for(int i = 0; i < apriori.length; i++)
        apriori[i] = ((double)tsk._rescnt[i] + _parms._laplace)/(tsk._nobs + tsk._nrescat * _parms._laplace);
        // apriori[i] = tsk._rescnt[i]/tsk._nobs;     // Note: R doesn't apply laplace smoothing to priors, even though this is textbook definition

      // Probability of categorical predictor x_j conditional on response y
      for(int col = 0; col < dinfo._cats; col++) {
        assert pcond[col].length == tsk._nrescat;
        for(int i = 0; i < pcond[col].length; i++) {
          for(int j = 0; j < pcond[col][i].length; j++)
            pcond[col][i][j] = ((double)tsk._jntcnt[col][i][j] + _parms._laplace)/((double)tsk._rescnt[i] + domains[col].length * _parms._laplace);
        }
      }

      // Mean and standard deviation of numeric predictor x_j for every level of response y
      for(int col = 0; col < dinfo._nums; col++) {
        for(int i = 0; i < pcond[0].length; i++) {
          int cidx = dinfo._cats + col;
          double num = tsk._rescnt[i];
          double pmean = tsk._jntsum[col][i][0]/num;

          pcond[cidx][i][0] = pmean;
          // double pvar = tsk._jntsum[col][i][1]/num - pmean * pmean;
          double pvar = tsk._jntsum[col][i][1]/(num - 1) - pmean * pmean * num/(num - 1);
          pcond[cidx][i][1] = Math.sqrt(pvar);
        }
      }
      model._output._apriori_raw = apriori;
      model._output._pcond_raw = pcond;

      // Create table of conditional probabilities for every predictor
      model._output._pcond = new TwoDimTable[pcond.length];
      String[] rowNames = _response.domain();
      for(int col = 0; col < dinfo._cats; col++) {
        String[] colNames = _train.vec(col).domain();
        String[] colTypes = new String[colNames.length];
        String[] colFormats = new String[colNames.length];
        Arrays.fill(colTypes, "double");
        Arrays.fill(colFormats, "%5f");
        model._output._pcond[col] = new TwoDimTable(_train.name(col), null, rowNames, colNames, colTypes, colFormats,
                "Y_by_" + _train.name(col), new String[rowNames.length][], pcond[col]);
      }

      for(int col = 0; col < dinfo._nums; col++) {
        int cidx = dinfo._cats + col;
        model._output._pcond[cidx] = new TwoDimTable(_train.name(cidx), null, rowNames, new String[] {"Mean", "Std_Dev"},
                new String[] {"double", "double"}, new String[] {"%5f", "%5f"}, "Y_by_" + _train.name(cidx),
                new String[rowNames.length][], pcond[cidx]);
      }

      // Create table of a-priori probabilities for the response
      String[] colTypes = new String[_response.cardinality()];
      String[] colFormats = new String[_response.cardinality()];
      Arrays.fill(colTypes, "double");
      Arrays.fill(colFormats, "%5f");
      model._output._apriori = new TwoDimTable("A Priori Response Probabilities", null, new String[1], _response.domain(), colTypes, colFormats, "",
              new String[1][], new double[][] {apriori});

      model._output._model_summary = createModelSummaryTable(model._output);
      if (_parms._compute_metrics) {
        model.score(_parms.train()).delete(); // This scores on the training data and appends a ModelMetrics
        ModelMetricsSupervised mm = DKV.getGet(model._output._model_metrics[model._output._model_metrics.length - 1]);
        model._output._training_metrics = mm;
      }

      // At the end: validation scoring (no need to gather scoring history)
      if (_valid != null) {
        Frame pred = model.score(_parms.valid()); //this appends a ModelMetrics on the validation set
        model._output._validation_metrics = DKV.getGet(model._output._model_metrics[model._output._model_metrics.length - 1]);
        pred.delete();
      }
    }

    @Override
    protected void compute2() {
      NaiveBayesModel model = null;
      DataInfo dinfo = null;

      try {
        init(true);   // Initialize parameters
        _parms.read_lock_frames(NaiveBayes.this); // Fetch & read-lock input frames
        if (error_count() > 0) throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(NaiveBayes.this);
        dinfo = new DataInfo(Key.make(), _train, _valid, 1, false, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, true, false, false, false);

        // The model to be built
        model = new NaiveBayesModel(dest(), _parms, new NaiveBayesModel.NaiveBayesOutput(NaiveBayes.this));
        model.delete_and_lock(_key);
        _train.read_lock(_key);

        NBTask tsk = new NBTask(dinfo, _response.cardinality()).doAll(dinfo._adaptedFrame);
        computeStatsFillModel(model, dinfo, tsk);
        model.update(_key);
        update(1);

        done();
      } catch (Throwable t) {
        Job thisJob = DKV.getGet(_key);
        if (thisJob._state == JobState.CANCELLED) {
          Log.info("Job cancelled by user.");
        } else {
          t.printStackTrace();
          failed(t);
          throw t;
        }
      } finally {
        _train.unlock(_key);
        if (model != null) model.unlock(_key);
        if (dinfo != null) dinfo.remove();
        _parms.read_unlock_frames(NaiveBayes.this);
      }
      tryComplete();
    }
  }

  private TwoDimTable createModelSummaryTable(NaiveBayesModel.NaiveBayesOutput output) {
    List colHeaders = new ArrayList<>();
    List colTypes = new ArrayList<>();
    List colFormat = new ArrayList<>();
    colHeaders.add("Number of Response Levels"); colTypes.add("long"); colFormat.add("%d");
    colHeaders.add("Min Apriori Probability"); colTypes.add("double"); colFormat.add("%.5f");
    colHeaders.add("Max Apriori Probability"); colTypes.add("double"); colFormat.add("%.5f");

    double apriori_min = output._apriori_raw[0];
    double apriori_max = output._apriori_raw[0];
    for(int i = 1; i < output._apriori_raw.length; i++) {
      if(output._apriori_raw[i] < apriori_min) apriori_min = output._apriori_raw[i];
      else if(output._apriori_raw[i] > apriori_max) apriori_max = output._apriori_raw[i];
    }

    final int rows = 1;
    TwoDimTable table = new TwoDimTable(
            "Model Summary", null,
            new String[rows],
            colHeaders.toArray(new String[0]),
            colTypes.toArray(new String[0]),
            colFormat.toArray(new String[0]),
            "");
    int row = 0;
    int col = 0;
    table.set(row, col++, output._apriori_raw.length);
    table.set(row, col++, apriori_min);
    table.set(row, col++, apriori_max);
    return table;
  }

  // Note: NA handling differs from R for efficiency purposes
  // R's method: For each predictor x_j, skip counting that row for p(x_j|y) calculation if x_j = NA.
  //             If response y = NA, skip counting row entirely in all calculations
  // H2O's method: Just skip all rows where any x_j = NA or y = NA. Should be more memory-efficient, but results incomparable with R.
  private static class NBTask extends MRTask {
    final DataInfo _dinfo;
    final String[][] _domains;  // Domains of the training frame
    final int _nrescat;         // Number of levels for the response y
    final int _npreds;          // Number of predictors in the training frame

    public int _nobs;                     // Number of rows counted in calculation
    public int[/*nrescat*/] _rescnt;      // Count of each level in the response
    public int[/*npreds*/][/*nrescat*/][] _jntcnt;  // For each categorical predictor, joint count of response and predictor levels
    public double[/*npreds*/][/*nrescat*/][] _jntsum; // For each numeric predictor, sum and squared sum of entries for every response level

    public NBTask(DataInfo dinfo, int nres) {
      _dinfo = dinfo;
      _nrescat = nres;
      _domains = dinfo._adaptedFrame.domains();
      _npreds = dinfo._adaptedFrame.numCols()-1;
      assert _npreds == dinfo._nums + dinfo._cats;
      assert _nrescat == _domains[_npreds].length;       // Response in last vec of adapted frame
    }

    @Override public void map(Chunk[] chks) {
      _nobs = 0;
      _rescnt = new int[_nrescat];

      if(_dinfo._cats > 0) {
        _jntcnt = new int[_dinfo._cats][][];
        for (int i = 0; i < _dinfo._cats; i++) {
          _jntcnt[i] = new int[_nrescat][_domains[i].length];
        }
      }

      if(_dinfo._nums > 0) {
        _jntsum = new double[_dinfo._nums][][];
        for (int i = 0; i < _dinfo._nums; i++) {
          _jntsum[i] = new double[_nrescat][2];
        }
      }

      Chunk res = chks[_npreds];    // Response at the end
      OUTER:
      for(int row = 0; row < chks[0]._len; row++) {
        // Skip row if any entries in it are NA
        for(int col = 0; col < chks.length; col++) {
          if(Double.isNaN(chks[col].atd(row))) continue OUTER;
        }

        // Record joint counts of categorical predictors and response
        int rlevel = (int)res.atd(row);
        for(int col = 0; col < _dinfo._cats; col++) {
          int plevel = (int)chks[col].atd(row);
          _jntcnt[col][rlevel][plevel]++;
        }

        // Record sum for each pair of numerical predictors and response
        for(int col = 0; col < _dinfo._nums; col++) {
          int cidx = _dinfo._cats + col;
          double x = chks[cidx].atd(row);
          _jntsum[col][rlevel][0] += x;
          _jntsum[col][rlevel][1] += x*x;
        }
        _rescnt[rlevel]++;
        _nobs++;
      }
    }

    @Override public void reduce(NBTask nt) {
      _nobs += nt._nobs;
      ArrayUtils.add(_rescnt, nt._rescnt);
      if(null != _jntcnt) {
        for (int col = 0; col < _jntcnt.length; col++)
          ArrayUtils.add(_jntcnt[col], nt._jntcnt[col]);
      }
      if(null != _jntsum) {
        for (int col = 0; col < _jntsum.length; col++)
          ArrayUtils.add(_jntsum[col], nt._jntsum[col]);
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy