All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.naivebayes.NaiveBayesModel Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.naivebayes;

import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.genmodel.GenModel;
import hex.schemas.NaiveBayesModelV3;
import water.H2O;
import water.Key;
import water.api.schemas3.ModelSchemaV3;
import water.codegen.CodeGenerator;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.JCodeSB;
import water.util.JCodeGen;
import water.util.SBPrintStream;
import water.util.TwoDimTable;

public class NaiveBayesModel extends Model {

  /**
   * User-settable parameters of the Naive Bayes algorithm.
   */
  public static class NaiveBayesParameters extends Model.Parameters {
    /** Laplace smoothing parameter. */
    public double _laplace = 0.0;

    /** Cutoff below which a standard deviation is replaced with {@link #_min_sdev}. */
    public double _eps_sdev = 0.0;

    /** Minimum standard deviation to use for observations without enough data. */
    public double _min_sdev = 0.001;

    /** Cutoff below which a probability is replaced with {@link #_min_prob}. */
    public double _eps_prob = 0.0;

    /** Minimum conditional probability to use for observations without enough data. */
    public double _min_prob = 0.001;

    /** Should a second pass be made through the data to compute metrics? */
    public boolean _compute_metrics = true;

    /** @return the short algorithm identifier, {@code "NaiveBayes"} */
    public String algoName() {
      return "NaiveBayes";
    }

    /** @return the human-readable algorithm name, {@code "Naive Bayes"} */
    public String fullName() {
      return "Naive Bayes";
    }

    /** @return the fully qualified class name of {@link NaiveBayesModel} */
    public String javaName() {
      return NaiveBayesModel.class.getName();
    }

    /** @return the constant 6 */
    @Override
    public long progressUnits() {
      return 6L;
    }
  }

  /**
   * Fitted state of a Naive Bayes model: class priors, per-predictor conditional
   * distributions and bookkeeping about the response.
   */
  public static class NaiveBayesOutput extends Model.Output {
    /** Class distribution of the response, in table form. */
    public TwoDimTable _apriori;

    /** Class distribution of the response, indexed by response level. */
    public double[] _apriori_raw;

    /**
     * One table per predictor providing, for each attribute level, the conditional
     * probabilities given the target class.
     */
    public TwoDimTable[] _pcond;

    /**
     * Raw conditional distributions, indexed as [predictor][response level][predictor level].
     * For numeric predictors scoring reads slot 0 as the mean and slot 1 as the standard deviation.
     */
    public double[][][] _pcond_raw;

    /** Count of response levels. */
    public int[] _rescnt;

    /** Domain of the response. */
    public String[] _levels;

    /** Number of categorical predictors. */
    public int _ncats;

    /**
     * Creates the output by passing the builder to the {@code Model.Output} constructor.
     *
     * @param b the Naive Bayes builder this output belongs to
     */
    public NaiveBayesOutput(NaiveBayes b) {
      super(b);
    }
  }

  /**
   * Creates a Naive Bayes model by forwarding all three arguments to the {@link Model} constructor.
   *
   * @param selfKey key of this model
   * @param parms   parameters the model was built with
   * @param output  fitted model output
   */
  public NaiveBayesModel(Key selfKey, NaiveBayesParameters parms, NaiveBayesOutput output) {
    super(selfKey, parms, output);
  }

  /** @return a freshly constructed {@link NaiveBayesModelV3} schema instance */
  public ModelSchemaV3 schema() { return new NaiveBayesModelV3(); }

  // TODO: Constant response shouldn't be regression. Need to override getModelCategory()
  /**
   * Picks the metric builder matching the model category reported by the output.
   *
   * @param domain response domain handed to the builder
   * @return a binomial builder for binomial models, a multinomial builder (sized by
   *         {@code domain.length}) for multinomial models
   * @throws RuntimeException whatever {@code H2O.unimpl()} produces, for every other category
   */
  @Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
    final ModelMetrics.MetricBuilder builder;
    switch (_output.getModelCategory()) {
      case Binomial:
        builder = new ModelMetricsBinomial.MetricBuilderBinomial(domain);
        break;
      case Multinomial:
        builder = new ModelMetricsMultinomial.MetricBuilderMultinomial(domain.length, domain);
        break;
      default:
        throw H2O.unimpl();
    }
    return builder;
  }

  // Note: For small probabilities, product may end up zero due to underflow error. Can circumvent by taking logs.
  /**
   * Scores one row: computes the posterior probability of every response level and the
   * predicted class.
   *
   * @param data  predictor values for the row; the first {@code _output._ncats} entries are read
   *              as categorical level indices, the remaining entries as numeric values
   * @param preds output array of length at least (number of response levels + 1); slot 0 receives
   *              the predicted class, slots 1.. the per-level probabilities
   * @return the {@code preds} array that was passed in
   */
  @Override protected double[] score0(double[] data, double[] preds) {
    final int nlevels = _output._levels.length;
    final int ncats = _output._ncats;
    final double[] logJoint = new double[nlevels];   // log(p(x,y)) for every level of y
    assert preds.length >= (nlevels + 1);            // Slot 0 of preds holds the predicted response class

    for (int r = 0; r < nlevels; r++) {
      // Work in log space to avoid underflow:
      // p(x,y) = p(x|y)*p(y)  ->  log(p(x,y)) = log(p(y)) + \sum_{j = 1}^m log(p(x_j|y))
      double acc = Math.log(_output._apriori_raw[r]);

      // Categorical predictors: look up the conditional probability of the observed level
      for (int c = 0; c < ncats; c++) {
        if (Double.isNaN(data[c])) continue;   // NA predictors are left out of the joint
        final int lvl = (int) data[c];
        final double[] condProbs = _output._pcond_raw[c][r];
        final double p;
        if (lvl < condProbs.length) {
          p = condProbs[lvl];
        } else {
          // Level beyond the stored table: fall back to the Laplace-smoothed estimate
          p = _parms._laplace / ((double) _output._rescnt[r] + _parms._laplace * _output._domains[c].length);
        }
        acc += Math.log(p <= _parms._eps_prob ? _parms._min_prob : p);
      }

      // Numeric predictors: Gaussian density with the per-class mean and standard deviation from the model
      for (int c = ncats; c < data.length; c++) {
        if (Double.isNaN(data[c])) continue;   // NA predictors are left out of the joint
        final double x = data[c];
        final double[] stats = _output._pcond_raw[c][r];   // [0] = mean, [1] = standard deviation
        final double mean = Double.isNaN(stats[0]) ? 0 : stats[0];
        final double sdev;
        if (Double.isNaN(stats[1])) {
          sdev = 1.0;
        } else if (stats[1] <= _parms._eps_sdev) {
          sdev = _parms._min_sdev;
        } else {
          sdev = stats[1];
        }
        final double density = Math.exp(-((x - mean) * (x - mean)) / (2. * sdev * sdev)) / (sdev * Math.sqrt(2. * Math.PI));
        acc += Math.log(density <= _parms._eps_prob ? _parms._min_prob : density);
      }

      logJoint[r] = acc;
    }

    // Exponentiating each log(p(x,y)) separately and normalizing is numerically unstable, so use
    // p(y|x) = 1 / ( \Sum_{r = levels of y} exp( log(p(x,y = r)) - log(p(x,y)) ) )
    for (int r = 0; r < nlevels; r++) {
      double denom = 0;
      for (int s = 0; s < nlevels; s++) {
        denom += Math.exp(logJoint[s] - logJoint[r]);
      }
      preds[r + 1] = 1 / denom;
    }

    // Select the predicted class from the computed probabilities
    preds[0] = GenModel.getPrediction(preds, _output._priorClassDist, data, defaultThreshold());
    return preds;
  }

  /**
   * Emits the model-level members of the generated Java class and registers a generator for the
   * helper classes holding the model's arrays (response counts, priors, conditional
   * distributions and categorical domain sizes).
   *
   * @param sb      stream receiving the generated class members
   * @param fileCtx pipeline collecting file-level code generators
   * @return the stream returned by the superclass implementation, after the extra members were written
   */
  @Override protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
    final SBPrintStream stream = super.toJavaInit(sb, fileCtx);
    stream.ip("public boolean isSupervised() { return " + isSupervised() + "; }").nl();
    stream.ip("public int nfeatures() { return " + _output.nfeatures() + "; }").nl();
    stream.ip("public int nclasses() { return " + _output.nclasses() + "; }").nl();

    // Java-identifier form of the model key, used as prefix of the helper class names
    final String modelId = JCodeGen.toJavaId(_key.toString());

    fileCtx.add(new CodeGenerator() {
      @Override
      public void generate(JCodeSB code) {
        JCodeGen.toClassWithArray(code, null, modelId + "_RESCNT", _output._rescnt,
                                  "Count of categorical levels in response.");
        JCodeGen.toClassWithArray(code, null, modelId + "_APRIORI", _output._apriori_raw,
                                  "Apriori class distribution of the response.");
        JCodeGen.toClassWithArray(code, null, modelId + "_PCOND", _output._pcond_raw,
                                  "Conditional probability of predictors.");

        // Domain sizes of the categorical predictors; stays null when there are none
        final int ncats = _output._ncats;
        final double[] domainSizes = ncats > 0 ? new double[ncats] : null;
        for (int c = 0; c < ncats; c++) {
          domainSizes[c] = _output._domains[c].length;
        }
        JCodeGen.toClassWithArray(code, null, modelId + "_DOMLEN", domainSizes,
                                  "Number of unique levels for each categorical predictor.");
      }
    });

    return stream;
  }

  /**
   * Writes the body of the generated scoring method. The emitted code follows {@code score0}:
   * it accumulates log(p(x,y)) per response level over categorical and numeric predictors,
   * converts the log joints into class probabilities and picks the predicted class.
   *
   * @param bodySb      stream receiving the generated statements
   * @param classCtx    class-level generator pipeline (not used here)
   * @param fileCtx     file-level generator pipeline (not used here)
   * @param verboseCode verbosity flag (not used here)
   */
  @Override protected void toJavaPredictBody(SBPrintStream bodySb,
                                             CodeGeneratorPipeline classCtx,
                                             CodeGeneratorPipeline fileCtx,
                                             final boolean verboseCode) {
    // Java-identifier form of the model key; prefix of the generated helper classes
    final String mname = JCodeGen.toJavaId(_key.toString());

    bodySb.i().p("java.util.Arrays.fill(preds,0);").nl();
    bodySb.i().p("double mean, sdev, prob;").nl();
    bodySb.i().p("double[] nums = new double[" + _output._levels.length + "];").nl();

    // Outer loop over response levels: start from the log prior
    bodySb.i().p("for(int i = 0; i < " + _output._levels.length + "; i++) {").nl();
    bodySb.i(1).p("nums[i] = Math.log(").pj(mname+"_APRIORI", "VALUES").p("[i]);").nl();

    // Categorical predictors
    bodySb.i(1).p("for(int j = 0; j < " + _output._ncats + "; j++) {").nl();
    bodySb.i(2).p("if(Double.isNaN(data[j])) continue;").nl();
    bodySb.i(2).p("int level = (int)data[j];").nl();
    // The level must be bounded by the number of levels stored for predictor j and class i,
    // exactly as score0 does. Bounding it by the number of predictors (_pcond_raw.length) would
    // either read past the end of the table or treat observed levels as unseen.
    bodySb.i(2).p("prob = level < " + mname + "_PCOND.VALUES[j][i].length ? " + mname + "_PCOND.VALUES[j][i][level] : ")
                .p(_parms._laplace == 0 ? "0" : _parms._laplace + "/("+mname+"_RESCNT.VALUES[i] + " + _parms._laplace
                                              + "*" + mname + "_DOMLEN.VALUES[j])").p(";").nl();
    bodySb.i(2).p("nums[i] += Math.log(prob <= " + _parms._eps_prob + " ? " + _parms._min_prob + " : prob);").nl();
    bodySb.i(1).p("}").nl();

    // Numeric predictors: Gaussian density from the stored mean ([0]) and standard deviation ([1])
    bodySb.i(1).p("for(int j = " + _output._ncats + "; j < data.length; j++) {").nl();
    bodySb.i(2).p("if(Double.isNaN(data[j])) continue;").nl();
    bodySb.i(2).p("mean = Double.isNaN("+mname+"_PCOND.VALUES[j][i][0]) ? 0 : "+mname+"_PCOND.VALUES[j][i][0];").nl();
    bodySb.i(2).p("sdev = Double.isNaN("+mname+"_PCOND.VALUES[j][i][1]) ? 1 : ("+mname+"_PCOND.VALUES[j][i][1] <= " + _parms._eps_sdev + " ? "
            + _parms._min_sdev + " : "+mname+"_PCOND.VALUES[j][i][1]);").nl();
    bodySb.i(2).p("prob = Math.exp(-((data[j]-mean)*(data[j]-mean))/(2.*sdev*sdev)) / (sdev*Math.sqrt(2.*Math.PI));").nl();
    bodySb.i(2).p("nums[i] += Math.log(prob <= " + _parms._eps_prob + " ? " + _parms._min_prob + " : prob);").nl();
    bodySb.i(1).p("}").nl();
    bodySb.i().p("}").nl();

    // Turn log joints into probabilities: p(y = i | x) = 1 / sum_j exp(nums[j] - nums[i])
    bodySb.i().p("double sum;").nl();
    bodySb.i().p("for(int i = 0; i < nums.length; i++) {").nl();
    bodySb.i(1).p("sum = 0;").nl();
    bodySb.i(1).p("for(int j = 0; j < nums.length; j++) {").nl();
    bodySb.i(2).p("sum += Math.exp(nums[j]-nums[i]);").nl();
    bodySb.i(1).p("}").nl();
    bodySb.i(1).p("preds[i+1] = 1/sum;").nl();
    bodySb.i().p("}").nl();

    // Predicted class goes into slot 0
    bodySb.i().p("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + defaultThreshold()+");").nl();
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy