nlp.NaiveBayes

Natural language processing toolbox using Sigma knowledge engineering system.
package nlp;

/*
Copyright 2014-2015 IPsoft

Author: Adam Pease [email protected]

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place, Suite 330, Boston,
MA  02111-1307 USA

Use the Naive Bayes approach to train a classifier to predict
which class a set of values is likely to be a member of.
http://guidetodatamining.com/guide/ch6/DataMining-ch6.pdf
 */
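
/*
In outline, the decision rule used here is the standard Naive Bayes one:
pick the class c that maximizes

    P(c) * product over features f of P(f = v | c)

where discrete ("disc") features use relative-frequency estimates of
P(v|c) and continuous ("cont") features use a Gaussian density built
from the per-class mean and sample standard deviation.
 */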

import com.google.common.collect.Lists;
import com.articulate.sigma.DocGen;
import com.articulate.sigma.StringUtil;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

public class NaiveBayes {

    public ArrayList<ArrayList<String>> input = null;
    public ArrayList<String> labels = new ArrayList<>();
    public ArrayList<String> types = new ArrayList<>();

    public HashMap<String,HashMap<String,Float>> means = new HashMap<>();
    public HashMap<String,HashMap<String,Float>> ssd = new HashMap<>();
    public HashMap<String,HashMap<String,Float>> totals = new HashMap<>();
    public HashMap<String,HashMap<String,ArrayList<Float>>> numericValues = new HashMap<>();

    public HashMap<String,Integer> priorCounts = new HashMap<>();
    public HashMap<String,HashMap<String,HashMap<String,Integer>>> conditionalCounts = new HashMap<>();
    public HashMap<String,Float> priors = new HashMap<>();
    public HashMap<String,HashMap<String,HashMap<String,Float>>> conds = new HashMap<>();

    /** *************************************************************
     */
    public NaiveBayes(String filename) {

        DocGen dg = DocGen.getInstance();
        if (StringUtil.emptyString(filename))
            input = dg.readSpreadsheetFile("/home/apease/IPsoft/NB/NBdata.txt", ',');
        else
            input = dg.readSpreadsheetFile(filename, ',');
        //System.out.println(input);
        input.remove(0);  // remove "start"
        types.addAll(input.get(0));  // these can be discrete "disc", continuous "cont" or "class"
        labels.addAll(input.get(1));
        input.remove(0);  // remove types
        input.remove(0);  // remove headers
        //input.remove(input.size()-1);
        //System.out.println(input);
    }

    /** *************************************************************
     */
    public NaiveBayes(ArrayList<ArrayList<String>> in,
                      ArrayList<String> labels,
                      ArrayList<String> types) {

        input = new ArrayList<ArrayList<String>>();
        //System.out.println("NaiveBayes with #input: " + in.size());
        input.addAll(in);
        this.types = types; // these can be discrete "disc", continuous "cont" or "class"
        this.labels = labels;
        //System.out.println("NaiveBayes() : starting line: " + input.get(0));
    }

    /** *************************************************************
     * Compute P(x|y) given the mean, sample standard deviation and x
     */
    public static float probDensFunc(float mean, float ssd, float x) {

        float epart = (float) Math.exp((double) -(x-mean)*(x-mean)/(2*ssd*ssd));
        return ((float) 1.0 / ((float) Math.sqrt(2*Math.PI)*ssd)) * epart;
    }
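
    /*
     * For reference, probDensFunc() above evaluates the Gaussian density
     *
     *     P(x|y) = (1 / (sqrt(2*pi) * ssd)) * exp(-(x - mean)^2 / (2 * ssd^2))
     *
     * so, for example, probDensFunc(0f, 1f, 0f) returns the standard normal
     * peak 1/sqrt(2*pi), roughly 0.3989.
     */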

    /** *************************************************************
     * Compute P(x|y) given the mean, sample standard deviation and x
     */
    public float probDensFunc(String clss, String label, float x) {

        float mean = means.get(clss).get(label);
        float dev = ssd.get(clss).get(label);
        return probDensFunc(mean,dev,x);
    }

    /** *************************************************************
     * Count the number of occurrences of each class.
     */
    public void createPriorCounts() {

        //System.out.println("NaiveBayes.createPriorCounts() : starting line: " + input.get(0));
        int classIndex = types.indexOf("class");
        for (ArrayList<String> row : input) {
            String clss = row.get(classIndex);
            if (!priorCounts.containsKey(clss))
                priorCounts.put(clss, 0);
            priorCounts.put(clss, priorCounts.get(clss) + 1);
        }
    }
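
    /*
     * Worked example: with the sample data listed in main() below (6 rows of
     * class i100 and 9 rows of class i500), createPriorCounts() leaves
     * priorCounts = {i100=6, i500=9}.
     */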

    /** *************************************************************
     * Count, for each class, the number of occurrences of each value
     * of each discrete ("disc") variable.  The class column is the one
     * whose entry in types is "class".
     *
     * {i100={interest={appearance=2, health=1, ...}, ...},
     * i500={interest={appearance=3, health=4,...}, ...} }
     */
    public void createConditionalCounts() {

        //System.out.println("NaiveBayes.createConditionalCounts() : starting line: " + input.get(0));
        for (ArrayList<String> row : input) {
            int classIndex = types.indexOf("class");
            String clss = row.get(classIndex);
            //System.out.println("in createConditionalCounts(): " + clss);
            HashMap<String,HashMap<String,Integer>> classInfo = conditionalCounts.get(clss);
            if (classInfo == null) {
                classInfo = new HashMap<String,HashMap<String,Integer>>();
                conditionalCounts.put(clss,classInfo);
            }
            for (String label : labels) {
                int column = labels.indexOf(label);
                if (!types.get(column).equals("disc"))
                    continue;

                HashMap<String,Integer> values = classInfo.get(label);
                if (values == null)
                    values = new HashMap<String,Integer>();
                String value = row.get(column);
                if (values.containsKey(value))
                    values.put(value, values.get(value) + 1);
                else
                    values.put(value, 1);
                classInfo.put(label, values);
            }
            conditionalCounts.put(clss,classInfo);
        }
        //System.out.println("NaiveBayes.createConditionalCounts() : " + conditionalCounts);
    }
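
    /*
     * Worked example, using the sample data in main() and assuming the first
     * column is labelled "interest": the six i100 rows contain the first-column
     * values both (3 times), appearance (2) and health (1), matching the
     * {i100={interest={appearance=2, health=1, ...}}} shape sketched above.
     */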

    /** *************************************************************
     * Create totals per class of each variable that is continuous.
     */
    public void createTotals() {

        System.out.println("NaiveBayes.createTotals() : types: " + types);
        System.out.println("NaiveBayes.createTotals() : labels: " + labels);
        System.out.println("NaiveBayes.createTotals() : starting line: " + input.get(0));
        for (ArrayList<String> row : input) {
            int classIndex = types.indexOf("class");
            String clss = row.get(classIndex);
            HashMap<String,Float> classInfo = totals.get(clss);
            if (classInfo == null) {
                classInfo = new HashMap<String,Float>();
                totals.put(clss, classInfo);
            }
            for (int i = 0; i < row.size(); i++) {
                if (!types.get(i).equals("cont"))
                    continue;
                String column = labels.get(i);
                float value = Float.parseFloat(row.get(i));
                if (classInfo.containsKey(column)) {
                    value += classInfo.get(column);
                }
                classInfo.put(column, value);
            }
        }
    }

    /** *************************************************************
     */
    public void createMeans() {

        for (String clss : totals.keySet()) {
            HashMap<String,Float> classMeanInfo = means.get(clss);
            if (classMeanInfo == null) {
                classMeanInfo = new HashMap<String,Float>();
                means.put(clss, classMeanInfo);
            }
            HashMap<String,Float> classTotalsInfo = totals.get(clss);
            float count = (float) priorCounts.get(clss);
            for (String column : classTotalsInfo.keySet()) {
                float value = classTotalsInfo.get(column);
                classMeanInfo.put(column,value / count);
            }
            //System.out.println("createMeans(): " + clss + ":" + classMeanInfo);
        }
    }

    /** *************************************************************
     * Note that this computes the sample standard deviation
     * sigma = sqrt( (1/(N-1)) sum(1,n,(xi-meanx)*(xi-meanx)))
     */
    public void createStandardDeviation() {

        for (ArrayList<String> row : input) {
            int classIndex = types.indexOf("class");
            String clss = row.get(classIndex);
            HashMap<String,Float> classMeansInfo = means.get(clss);
            HashMap<String,Float> classSsdInfo = ssd.get(clss);
            if (classSsdInfo == null) {
                classSsdInfo = new HashMap<String,Float>();
                ssd.put(clss, classSsdInfo);
            }
            for (String label : classMeansInfo.keySet()) {
                if (label.equals("cont"))
                    continue;
                float mean = classMeansInfo.get(label);
                int colIndex = labels.indexOf(label);
                float value = Float.parseFloat(row.get(colIndex));
                float squaredDiff = (value - mean) * (value - mean);
                float total = squaredDiff;
                if (classSsdInfo.containsKey(label))
                    total = classSsdInfo.get(label) + squaredDiff;
                classSsdInfo.put(label,total);
            }
        }
        for (String clss : ssd.keySet()) {
            HashMap<String,Float> classSsdInfo = ssd.get(clss);
            int classCount = priorCounts.get(clss);
            for (String label : classSsdInfo.keySet()) {
                float variance = classSsdInfo.get(label) / ((float) (classCount - 1));
                variance = (float) Math.sqrt(variance);
                classSsdInfo.put(label,variance);
            }
        }
    }
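
    /*
     * Arithmetic check with illustrative values (not from the data files):
     * for a class whose column values are 1, 2 and 3, the mean is 2 and the
     * sample standard deviation is
     * sqrt(((1-2)^2 + (2-2)^2 + (3-2)^2) / (3 - 1)) = sqrt(1) = 1,
     * using the N-1 denominator noted above.
     */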

    /** *************************************************************
     */
    public int findTrainingSetSize() {

        int sum = 0;
        for (String s : priorCounts.keySet()) {
            sum += priorCounts.get(s);
        }
        return sum;
    }

    /** *************************************************************
     * Calculate the prior probabilities of each class given the
     * numbers of instances of each class.
     */
    public void calcPriors(int sum) {

        for (String s : priorCounts.keySet())
            priors.put(s, (float) priorCounts.get(s) / (float) sum);
    }
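
    /*
     * Continuing the worked example from the sample data in main():
     * findTrainingSetSize() returns 15, and calcPriors(15) gives
     * priors = {i100 = 6/15 = 0.4, i500 = 9/15 = 0.6}.
     */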

    /** *************************************************************
     * Compute conditional probabilities, keyed first by class name and
     * then by column, giving the probability of each value in that
     * column.  In the sketch below, 0.33 is the probability of getting
     * the value "appearance" in column 1 for class 'i500':
     * {'i500': {1: {'appearance':0.33,'health':0.44},
     *           2: {...
     */
    public void calcConditionals(int sum) {

        for (String clss : conditionalCounts.keySet()) {
            int clssCount = priorCounts.get(clss);
            HashMap<String,HashMap<String,Integer>> classCounts = conditionalCounts.get(clss);
            HashMap<String,HashMap<String,Float>> classConditionals = new HashMap<>();
            for (String col : classCounts.keySet()) {
                HashMap<String,Integer> colValues = classCounts.get(col);
                HashMap<String,Float> colConditionals = new HashMap<>();
                for (String val : colValues.keySet()) {
                    float posterior = (float) colValues.get(val).intValue() / clssCount;
                    // m-estimate of probability (Mitchell, Machine Learning, p 179, eq 6.22)
                    //float posterior = (float) colValues.get(val).intValue() + (float) 1.0 /
                    //        ((float) clssCount + (float) colValues.keySet().size());
                    colConditionals.put(val,posterior);
                }
                classConditionals.put(col,colConditionals);
            }
            conds.put(clss,classConditionals);
        }
    }
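
    /*
     * Continuing the worked example: the nine i500 rows of the sample data
     * contain the first-column value "appearance" three times, so
     * calcConditionals() stores 3/9 (about 0.33) for it, which is the 0.33
     * shown in the comment sketch above.
     */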

    /** *************************************************************
     * Given the conditional and prior probabilities, and a particular
     * instance set of attributes, compute which class that instance is
     * most likely to fall into.
     */
    public String classify(List<String> values) {

        if (values == null) {
            System.out.println("Error in NaiveBayes.classify: null input");
            return "";
        }
        int classIndex = types.indexOf("class");
        int indexMod = 0;  // if class name is not the last element
        if (classIndex == 0)
            indexMod = 1;
        float maxProb = 0;
        String maxClass = "";
        HashMap<String,Float> probs = new HashMap<String,Float>();
        for (String clss : conds.keySet()) {
            HashMap<String,HashMap<String,Float>> posteriors = conds.get(clss);
            float prior = priors.get(clss);
            float prob = prior;
            for (String label : labels) {
                if (label.equals("class"))
                    continue;
                int index = labels.indexOf(label);
                String type = types.get(index);
                if (type.equals("disc")) {
                    HashMap<String,Float> conditCol = posteriors.get(label);
                    if (conditCol != null) { // trap unseen features
                        String value = values.get(index);
                        if (value == null || value.isEmpty() || conditCol.get(value) == null) {
                            System.out.println("Error in NaiveBayes.classify: " + label +
                                    " index: " + index + " values: " + values + " value: " + value);
                            System.out.println(conds.get(clss));
                        }
                        else {
                            float conditional = conditCol.get(value);
                            prob = prob * conditional;
                        }
                    }
                }
                if (type.equals("cont")) {
                    float value = Float.parseFloat(values.get(index));
                    float conditional = probDensFunc(clss,label,value);
                    prob = prob * conditional;
                }
            }
            probs.put(clss, prob);
            if (prob > maxProb) {
                maxProb = prob;
                maxClass = clss;
            }
        }
        //System.out.println("NaiveBayes.classify(): probabilities: " + probs);
        return maxClass;
    }
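
    /*
     * A worked classification over the sample data in main(), assuming the
     * column order interest, exercise, motivation, tech, class: for the
     * instance ("health", "moderate", "moderate", "yes") the two products are
     *
     *     i500: 9/15 * 4/9 * 3/9 * 3/9 * 6/9  which is about 0.0198
     *     i100: 6/15 * 1/6 * 1/6 * 5/6 * 2/6  which is about 0.0031
     *
     * so classify() returns "i500".  The counts come from the rows listed in
     * main(); the column labels are only an assumption for readability.
     */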

    /** *************************************************************
     */
    public void initialize() {

        System.out.println("NaiveBayes.initialize(): first line: " + input.get(0));
        createPriorCounts();
        System.out.println("NaiveBayes.initialize() : priorCounts: " + priorCounts);
        int sum = findTrainingSetSize();
        System.out.println("NaiveBayes.initialize() : sum: " + sum);
        // class name key value map of column number key
        createConditionalCounts();
        System.out.println("NaiveBayes.initialize() : conditionalCounts: " + conditionalCounts);
        createTotals();
        System.out.println("NaiveBayes.initialize() : totals: " + totals);
        createMeans();
        System.out.println("NaiveBayes.initialize() : means: " + means);
        createStandardDeviation();
        System.out.println("NaiveBayes.initialize() : standard deviation: " + ssd);
        calcPriors(sum);
        System.out.println("NaiveBayes.initialize() : priors: " + priors);
        calcConditionals(sum);
        System.out.println("NaiveBayes.initialize() : conds: " + conds);
    }

    /** *************************************************************
     * Take an optional data file name as a command line argument.  If
     * one is given, train on that file and classify a grid of
     * two-attribute instances; otherwise train on a default file and
     * classify a single hard-coded instance.
     */
    public static void main(String[] args) {

        /* Sample Data
        both,sedentary,moderate,yes,i100
        both,sedentary,moderate,no,i100
        health,sedentary,moderate,yes,i500
        appearance,active,moderate,yes,i500
        appearance,moderate,aggressive,yes,i500
        appearance,moderate,aggressive,no,i100
        health,moderate,aggressive,no,i500
        both,active,moderate,yes,i100
        both,moderate,aggressive,yes,i500
        appearance,active,aggressive,yes,i500
        both,active,aggressive,no,i500
        health,active,moderate,no,i500
        health,sedentary,aggressive,yes,i500
        appearance,active,moderate,no,i100
        health,sedentary,moderate,no,i100
         */

        // read from a file assuming a list of attributes and a class name last on each line
        DocGen dg = DocGen.getInstance();
        //NaiveBayes nb = new NaiveBayes("/home/apease/IPsoft/NB/NBdata.txt");
        //NaiveBayes nb = new NaiveBayes("/home/apease/IPsoft/NB/house-votes-84.data");
        NaiveBayes nb = null;
        ArrayList<String> values = null;
        if (args.length >= 1) {
            nb = new NaiveBayes(args[0]);
            nb.initialize();
            for (int i = -1; i <= 10; i++) {
                for (int j = -1; j <= 10; j++) {
                    values = new ArrayList<>();
                    values.add(Integer.toString(i));
                    values.add(Integer.toString(j));
                    System.out.println(values.get(0) + ", " + values.get(1) + ", " + nb.classify(values));
                }
            }
        }
        else {
            nb = new NaiveBayes("/home/apease/IPsoft/NB/pima-indians-diabetes.data");
            values = Lists.newArrayList("4","111","72","47","207","37.1","1.390","56");
            nb.initialize();
            System.out.println("main(): most likely class: " + nb.classify(values));
        }
        //ArrayList<String> values = Lists.newArrayList("health","moderate","moderate","yes","class");
        //ArrayList<String> values = Lists.newArrayList("y","y","y","n","n","n","y","y","y","n","n","n","y","n","y","y");
        // ArrayList<String> values = Lists.newArrayList("both","sedentary","aggressive","no","class");
        //ArrayList<String> values = Lists.newArrayList("7","81","88","40","48","46.7","0.261","52");
    }
}



