All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.classifier.df.split.DefaultIgSplit Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.df.split;

import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.conditions.Condition;

import java.util.Arrays;

/**
 * Default, non-optimized implementation of {@link IgSplit}.
 *
 * <p>Every candidate split re-partitions the data with {@link Data#subset} and recomputes the
 * label entropies from scratch.
 *
 * <p>Instances are not thread-safe: {@link #entropy(Data)} reuses a scratch array stored in a
 * field.
 */
public class DefaultIgSplit extends IgSplit {

  /** Scratch buffer of per-label counts reused by {@link #entropy(Data)}; grown on demand. */
  private int[] counts;

  /**
   * Computes the split of {@code data} on attribute {@code attr}.
   *
   * <p>For a numerical attribute, each value returned by {@code data.values(attr)} is tried as a
   * threshold and the one with the highest information gain is kept. For a categorical attribute,
   * the information gain of partitioning on the attribute's values is returned.
   *
   * @param data the data to split
   * @param attr index of the attribute to evaluate
   * @return the resulting split; for a numerical attribute with no candidate values the gain is
   *     {@code -1} and the split value {@code 0.0}
   */
  @Override
  public Split computeSplit(Data data, int attr) {
    if (data.getDataset().isNumerical(attr)) {
      double[] values = data.values(attr);
      // H(Y) does not depend on the threshold: compute it once instead of once per candidate.
      double hy = entropy(data);
      double bestIg = -1;
      double bestSplit = 0.0;

      for (double value : values) {
        double ig = numericalIg(data, attr, value, hy);
        if (ig > bestIg) {
          bestIg = ig;
          bestSplit = value;
        }
      }

      return new Split(attr, bestIg, bestSplit);
    } else {
      double ig = categoricalIg(data, attr);

      return new Split(attr, ig);
    }
  }

  /**
   * Computes the Information Gain for a CATEGORICAL attribute: {@code H(Y) - H(Y|X)}, where
   * {@code H(Y|X)} is the size-weighted sum of the entropies of the subsets selected by
   * {@code Condition.equals(attr, value)} for each value in {@code data.values(attr)}.
   *
   * <p>NOTE(review): this assumes {@code data.values(attr)} yields each distinct value once;
   * duplicates would be counted more than once in {@code H(Y|X)} — confirm against {@code Data}.
   *
   * @param data the data to evaluate
   * @param attr index of the categorical attribute
   * @return the information gain, in bits
   */
  double categoricalIg(Data data, int attr) {
    double[] values = data.values(attr);
    double hy = entropy(data); // H(Y)
    double hyx = 0.0; // H(Y|X)
    double invDataSize = 1.0 / data.size();

    for (double value : values) {
      Data subset = data.subset(Condition.equals(attr, value));
      hyx += subset.size() * invDataSize * entropy(subset);
    }

    return hy - hyx;
  }

  /**
   * Computes the Information Gain for a NUMERICAL attribute given a splitting value.
   *
   * <p>The data is partitioned into a LO subset ({@code Condition.lesser(attr, split)}) and a HI
   * subset ({@code Condition.greaterOrEquals(attr, split)}). The result is {@code H(Y)} minus the
   * size-weighted entropies of both subsets.
   *
   * @param data the data to evaluate
   * @param attr index of the numerical attribute
   * @param split the threshold value
   * @return the information gain, in bits
   */
  double numericalIg(Data data, int attr, double split) {
    return numericalIg(data, attr, split, entropy(data));
  }

  /**
   * Same as {@link #numericalIg(Data, int, double)}, but with {@code H(Y)} already computed so
   * that callers evaluating many thresholds on the same data compute it only once.
   */
  private double numericalIg(Data data, int attr, double split, double hy) {
    double invDataSize = 1.0 / data.size();
    double ig = hy;

    // LO subset
    Data lo = data.subset(Condition.lesser(attr, split));
    ig -= lo.size() * invDataSize * entropy(lo);

    // HI subset
    Data hi = data.subset(Condition.greaterOrEquals(attr, split));
    ig -= hi.size() * invDataSize * entropy(hi);

    return ig;
  }

  /**
   * Computes the Shannon entropy, in bits, of the label distribution of {@code data}.
   *
   * <p>Labels with a zero count contribute nothing. An empty {@code data} therefore yields
   * {@code 0.0}.
   *
   * @param data the data whose labels are counted
   * @return the label entropy, in bits
   */
  protected double entropy(Data data) {
    int nbLabels = data.getDataset().nblabels();
    double invDataSize = 1.0 / data.size();

    // The buffer is reused across calls. Grow it when this dataset has more labels than the one it
    // was first sized for; otherwise the loop below would index past the end of the array.
    if (counts == null || counts.length < nbLabels) {
      counts = new int[nbLabels];
    }

    Arrays.fill(counts, 0);
    data.countLabels(counts);

    double entropy = 0.0;
    for (int label = 0; label < nbLabels; label++) {
      int count = counts[label];
      if (count == 0) {
        continue; // 0 * log(0) is taken as 0; evaluating it would give NaN
      }
      double p = count * invDataSize;
      entropy += -p * Math.log(p) / LOG2;
    }

    return entropy;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy