All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.classifier.df.split.OptIgSplit Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.df.split;

import org.apache.commons.math3.stat.descriptive.rank.Percentile;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataUtils;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;

import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.TreeSet;

/**
 * 

Optimized implementation of IgSplit. * This class can be used when the criterion variable is the categorical attribute.

* *

This code was changed in MAHOUT-1419 to deal in sampled splits among numeric * features to fix a performance problem. To generate some synthetic data that exercises * the issue, try for example generating 4 features of Normal(0,1) values with a random * boolean 0/1 categorical feature. In Scala:

* * {@code * val r = new scala.util.Random() * val pw = new java.io.PrintWriter("random.csv") * (1 to 10000000).foreach(e => * pw.println(r.nextDouble() + "," + * r.nextDouble() + "," + * r.nextDouble() + "," + * r.nextDouble() + "," + * (if (r.nextBoolean()) 1 else 0)) * ) * pw.close() * } */ public class OptIgSplit extends IgSplit { private static final int MAX_NUMERIC_SPLITS = 16; @Override public Split computeSplit(Data data, int attr) { if (data.getDataset().isNumerical(attr)) { return numericalSplit(data, attr); } else { return categoricalSplit(data, attr); } } /** * Computes the split for a CATEGORICAL attribute */ private static Split categoricalSplit(Data data, int attr) { double[] values = data.values(attr).clone(); double[] splitPoints = chooseCategoricalSplitPoints(values); int numLabels = data.getDataset().nblabels(); int[][] counts = new int[splitPoints.length][numLabels]; int[] countAll = new int[numLabels]; computeFrequencies(data, attr, splitPoints, counts, countAll); int size = data.size(); double hy = entropy(countAll, size); // H(Y) double hyx = 0.0; // H(Y|X) double invDataSize = 1.0 / size; for (int index = 0; index < splitPoints.length; index++) { size = DataUtils.sum(counts[index]); hyx += size * invDataSize * entropy(counts[index], size); } double ig = hy - hyx; return new Split(attr, ig); } static void computeFrequencies(Data data, int attr, double[] splitPoints, int[][] counts, int[] countAll) { Dataset dataset = data.getDataset(); for (int index = 0; index < data.size(); index++) { Instance instance = data.get(index); int label = (int) dataset.getLabel(instance); double value = instance.get(attr); int split = 0; while (split < splitPoints.length && value > splitPoints[split]) { split++; } if (split < splitPoints.length) { counts[split][label]++; } // Otherwise it's in the last split, which we don't need to count countAll[label]++; } } /** * Computes the best split for a NUMERICAL attribute */ static Split numericalSplit(Data data, int attr) { double[] values = data.values(attr).clone(); Arrays.sort(values); double[] splitPoints = chooseNumericSplitPoints(values); int numLabels = data.getDataset().nblabels(); int[][] counts = new int[splitPoints.length][numLabels]; int[] countAll = new int[numLabels]; int[] countLess = new int[numLabels]; computeFrequencies(data, attr, splitPoints, counts, countAll); int size = data.size(); double hy = entropy(countAll, size); double invDataSize = 1.0 / size; int best = -1; double bestIg = -1.0; // try each possible split value for (int index = 0; index < splitPoints.length; index++) { double ig = hy; DataUtils.add(countLess, counts[index]); DataUtils.dec(countAll, counts[index]); // instance with attribute value < values[index] size = DataUtils.sum(countLess); ig -= size * invDataSize * entropy(countLess, size); // instance with attribute value >= values[index] size = DataUtils.sum(countAll); ig -= size * invDataSize * entropy(countAll, size); if (ig > bestIg) { bestIg = ig; best = index; } } if (best == -1) { throw new IllegalStateException("no best split found !"); } return new Split(attr, bestIg, splitPoints[best]); } /** * @return an array of values to split the numeric feature's values on when * building candidate splits. When input size is <= MAX_NUMERIC_SPLITS + 1, it will * return the averages between success values as split points. When larger, it will * return MAX_NUMERIC_SPLITS approximate percentiles through the data. */ private static double[] chooseNumericSplitPoints(double[] values) { if (values.length <= 1) { return values; } if (values.length <= MAX_NUMERIC_SPLITS + 1) { double[] splitPoints = new double[values.length - 1]; for (int i = 1; i < values.length; i++) { splitPoints[i-1] = (values[i] + values[i-1]) / 2.0; } return splitPoints; } Percentile distribution = new Percentile(); distribution.setData(values); double[] percentiles = new double[MAX_NUMERIC_SPLITS]; for (int i = 0 ; i < percentiles.length; i++) { double p = 100.0 * ((i + 1.0) / (MAX_NUMERIC_SPLITS + 1.0)); percentiles[i] = distribution.evaluate(p); } return percentiles; } private static double[] chooseCategoricalSplitPoints(double[] values) { // There is no great reason to believe that categorical value order matters, // but the original code worked this way, and it's not terrible in the absence // of more sophisticated analysis Collection uniqueOrderedCategories = new TreeSet(); for (double v : values) { uniqueOrderedCategories.add(v); } double[] uniqueValues = new double[uniqueOrderedCategories.size()]; Iterator it = uniqueOrderedCategories.iterator(); for (int i = 0; i < uniqueValues.length; i++) { uniqueValues[i] = it.next(); } return uniqueValues; } /** * Computes the Entropy * * @param counts counts[i] = numInstances with label i * @param dataSize numInstances */ private static double entropy(int[] counts, int dataSize) { if (dataSize == 0) { return 0.0; } double entropy = 0.0; for (int count : counts) { if (count > 0) { double p = count / (double) dataSize; entropy -= p * Math.log(p); } } return entropy / LOG2; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy