All downloads are free. Search and download functionality uses the official Maven repository.

com.github.chen0040.trees.id3.ID3TreeNode Maven / Gradle / Ivy

package com.github.chen0040.trees.id3;

import com.github.chen0040.data.frame.BasicDataFrame;
import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.data.utils.CountRepository;

import java.util.*;


/**
 * A node of an ID3 decision tree built over categorical columns.
 *
 * <p>The building constructor is recursive: every node records the most frequent class label
 * among the rows it receives and, unless a stopping condition holds, splits those rows on the
 * column with the highest information gain, creating one child per value observed in that
 * column.
 *
 * <p>Branches are identified by keys of the form {@code "<column>=<value>"}; each child stores
 * the key of the branch leading to it in {@link #attributeValue}.
 *
 * Created by xschen on 17/8/15.
 */
public class ID3TreeNode {
    /** Number of rows this node was built from. */
    private int rowCount;
    /**
     * Index into {@link #columns} of the split column; the building constructor leaves it at -1
     * when the node does not split.
     */
    private int splitAttributeIndex;
    /** Key ({@code "<column>=<value>"}) of the branch from the parent to this node; "" if built directly. */
    private String attributeValue;
    /** Children, in the order the split column's values are returned by its CountRepository. */
    private final List<ID3TreeNode> childNodes = new ArrayList<>();
    /** Most frequent class label among this node's rows ("" when there were none). */
    private String classLabel;
    /** Names of the columns considered for splitting. */
    private final List<String> columns = new ArrayList<>();


    /**
     * Creates an empty node. No fields are initialised, so they keep their Java defaults
     * (e.g. {@code classLabel} and {@code attributeValue} are null).
     */
    public ID3TreeNode(){

    }

    /**
     * Builds the (sub)tree for the given rows.
     *
     * <p>Growth stops (this node becomes a leaf) when there is at most one row, when
     * {@code height == maxHeight} (exact equality), when all rows share one class label, or
     * when no column yields positive information gain.
     *
     * @param batch     rows reaching this node; each row's categorical target is its class label
     * @param random    not used by this implementation; passed through to children unchanged
     * @param height    depth of this node
     * @param maxHeight depth at which growth stops
     * @param columns   names of the categorical columns considered for splitting
     */
    public ID3TreeNode(DataFrame batch, Random random, int height, int maxHeight, List<String> columns){
        this.columns.addAll(columns);

        rowCount = batch.rowCount();
        splitAttributeIndex = -1;
        attributeValue = "";
        classLabel = "";
        updateClassLabel(batch);

        if(rowCount <= 1 || height == maxHeight){
            return;
        }

        int n = this.columns.size();
        CountRepository[] columnCounts = new CountRepository[n];
        for(int i = 0; i < n; ++i){
            columnCounts[i] = new CountRepository(String.format("field%d", i));
        }
        CountRepository classCounts = new CountRepository();
        countEvents(batch, columnCounts, classCounts);

        double entropyS = classEntropy(classCounts);
        if(entropyS == 0){ // perfectly classified: every row has the same label
            return;
        }

        int best = selectSplitAttribute(columnCounts, entropyS);
        if(best < 0){ // no column reduces entropy
            return;
        }
        splitAttributeIndex = best;

        List<String> branchKeys = columnCounts[best].getSubEventNames();
        DataFrame[] partitions = partition(batch, branchKeys);
        for(int i = 0; i < partitions.length; ++i){
            partitions[i].lock();
            ID3TreeNode child = new ID3TreeNode(partitions[i], random, height + 1, maxHeight, columns);
            // Overwrites the "" the child constructor assigned.
            child.attributeValue = branchKeys.get(i);
            childNodes.add(child);
        }
    }

    /**
     * Fills the per-column counters and the class counter from every row of {@code batch}.
     * For column j the counter records "column=value" events and "column=value"+"ClassLabel=label"
     * joint events; {@code classCounts} records "ClassLabel=label" events.
     */
    private void countEvents(DataFrame batch, CountRepository[] columnCounts, CountRepository classCounts){
        for(int i = 0; i < rowCount; ++i){
            DataRow row = batch.row(i);
            String classEventName = "ClassLabel=" + row.categoricalTarget();

            for(int j = 0; j < columnCounts.length; ++j){
                String key = branchKey(row, j);
                columnCounts[j].addSupportCount(key, classEventName);
                columnCounts[j].addSupportCount(key);
                columnCounts[j].addSupportCount();
            }
            classCounts.addSupportCount(classEventName);
            classCounts.addSupportCount();
        }
    }

    /** Shannon entropy (base 2) of the class-label distribution held in {@code classCounts}. */
    private static double classEntropy(CountRepository classCounts){
        double entropy = 0;
        for(String classEventName : classCounts.getSubEventNames()){
            double p = classCounts.getProbability(classEventName);
            entropy += (-p * log2(p));
        }
        return entropy;
    }

    /**
     * Expected class entropy after splitting on one column:
     * sum over values t of P(t) * H(class | t), using the column's counter.
     */
    private static double conditionalEntropy(CountRepository counts){
        double result = 0;
        for(String value : counts.getSubEventNames()){
            double pValue = counts.getProbability(value);
            double entropyGivenValue = 0;
            for(String className : counts.getSubEventNames(value)){
                double p = counts.getConditionalProbability(value, className);
                entropyGivenValue += (-p * log2(p));
            }
            result += pValue * entropyGivenValue;
        }
        return result;
    }

    /**
     * Returns the index of the column with the largest strictly positive information gain, or
     * -1 if no column has positive gain. Ties go to the lowest index.
     */
    private static int selectSplitAttribute(CountRepository[] columnCounts, double entropyS){
        int best = -1;
        double bestGain = 0;
        for(int i = 0; i < columnCounts.length; ++i){
            double gain = entropyS - conditionalEntropy(columnCounts[i]);
            if(gain > bestGain){
                bestGain = gain;
                best = i;
            }
        }
        return best;
    }

    /**
     * Distributes the rows of {@code batch} into one new frame per branch key, using the value
     * of the current split column. Frame i receives the rows whose key equals branchKeys.get(i).
     */
    private DataFrame[] partition(DataFrame batch, List<String> branchKeys){
        DataFrame[] partitions = new DataFrame[branchKeys.size()];
        for(int i = 0; i < partitions.length; ++i){
            partitions[i] = new BasicDataFrame();
        }
        for(int i = 0; i < rowCount; ++i){
            DataRow row = batch.row(i);
            partitions[branchKeys.indexOf(branchKey(row, splitAttributeIndex))].addRow(row);
        }
        return partitions;
    }

    /** Builds the {@code "<column>=<value>"} key of {@code row} for the column at {@code columnIndex}. */
    private String branchKey(DataRow row, int columnIndex){
        String column = columns.get(columnIndex);
        return column + "=" + row.getCategoricalCell(column);
    }

    /**
     * Returns the child whose branch key matches {@code row} on the split column, or null if
     * this node has no children or no child matches (e.g. a value unseen during training).
     */
    private ID3TreeNode findChild(DataRow row){
        if(childNodes.isEmpty()){
            return null;
        }
        String key = branchKey(row, splitAttributeIndex);
        for(ID3TreeNode child : childNodes){
            if(child.attributeValue.equals(key)){
                return child;
            }
        }
        return null;
    }

    /**
     * Average-path-length correction 2*(ln(n-1) + gamma) - 2*(n-1)/n, where gamma is the
     * Euler-Mascheroni constant. This appears to be the c(n) term used by Isolation Forest.
     *
     * @param n number of rows
     * @return the correction, or 0 when {@code n <= 1}
     */
    public static double heuristicCost(double n){
        if(n <= 1.0) return 0;
        return 2 * (Math.log(n - 1) + 0.5772156649) - (2 * (n - 1) / n);
    }

    /** Base-2 logarithm. */
    private static double log2(double val){
        return Math.log(val) / Math.log(2);
    }

    /**
     * Sets {@link #classLabel} to the most frequent categorical target in {@code batch}. It is
     * left unchanged when the batch is empty. Ties are decided by HashMap iteration order.
     */
    private void updateClassLabel(DataFrame batch){
        Map<String, Integer> classLabelCounts = new HashMap<>();
        for(int i = 0; i < batch.rowCount(); ++i){
            String label = batch.row(i).categoricalTarget();
            Integer count = classLabelCounts.get(label);
            classLabelCounts.put(label, count == null ? 1 : count + 1);
        }
        int maxCount = 0;
        for(Map.Entry<String, Integer> entry : classLabelCounts.entrySet()){
            if(entry.getValue() > maxCount){
                maxCount = entry.getValue();
                classLabel = entry.getKey();
            }
        }
    }

    /**
     * Predicts the class label of {@code row}.
     *
     * <p>The row follows matching branches down the tree. The result is the class label of the
     * deepest node reached: either a leaf, or a node with no child for the row's value.
     */
    public String predict(DataRow row){
        ID3TreeNode child = findChild(row);
        return child != null ? child.predict(row) : classLabel;
    }

    /**
     * Returns the number of edges from this node to the deepest node that {@code row} reaches,
     * plus {@link #heuristicCost(double)} of that node's row count.
     *
     * <p>Branch matching uses the same {@code "<column>=<value>"} key as {@link #predict}.
     * Previously only the bare column name was compared, so the row never descended.
     */
    protected double pathLength(DataRow row){
        ID3TreeNode child = findChild(row);
        return child != null ? child.pathLength(row) + 1.0 : heuristicCost(rowCount);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy