/*===============================================================================
* Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved.
*
* Use of the RankLib package is subject to the terms of the software license set
* forth in the LICENSE file included with this software, and also available at
* http://people.cs.umass.edu/~vdang/ranklib_license.html
*===============================================================================
*/
package ciir.umass.edu.learning.tree;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.WorkerThread;
/**
* @author vdang
*/
public class FeatureHistogram {
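//Holds the best split candidate found so far: the feature index, the threshold index, the score S being maximized and the corresponding error reduction.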
class Config {
int featureIdx = -1;
int thresholdIdx = -1;
double S = -1;
double errReduced = -1;
}
//Parameter
public static float samplingRate = 1;
//Variables
public float[] accumFeatureImpact = null;
public int[] features = null;
public float[][] thresholds = null;
public double[][] sum = null;
public double sumResponse = 0;
public double sqSumResponse = 0;
public int[][] count = null;
public int[][] sampleToThresholdMap = null;
public double[] impacts;
//Whether to re-use the parent's @sum and @count arrays instead of cleaning up the parent and re-allocating them for the children.
//@sum and @count of any intermediate tree node (except the root) can be re-used.
private boolean reuseParent = false;
public FeatureHistogram() {
}
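//Builds the histogram for the root node: for each feature, stores the cumulative label sum and sample count per threshold bucket, and remembers which bucket each sample falls into.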
public void construct(final DataPoint[] samples, final double[] labels, final int[][] sampleSortedIdx, final int[] features,
final float[][] thresholds, final double[] impacts) {
this.features = features;
this.thresholds = thresholds;
this.impacts = impacts;
sumResponse = 0;
sqSumResponse = 0;
sum = new double[features.length][];
count = new int[features.length][];
sampleToThresholdMap = new int[features.length][];
final MyThreadPool p = MyThreadPool.getInstance();
if (p.size() == 1) {
construct(samples, labels, sampleSortedIdx, thresholds, 0, features.length - 1);
} else {
p.execute(new Worker(this, samples, labels, sampleSortedIdx, thresholds), features.length);
}
}
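//Sequential part of the root construction, restricted to the feature range [start, end].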
protected void construct(final DataPoint[] samples, final double[] labels, final int[][] sampleSortedIdx, final float[][] thresholds,
final int start, final int end) {
for (int i = start; i <= end; i++) {
final int fid = features[i];
//get the list of samples associated with this node (sorted in ascending order with respect to the current feature)
final int[] idx = sampleSortedIdx[i];
double sumLeft = 0;
final float[] threshold = thresholds[i];
final double[] sumLabel = new double[threshold.length];
final int[] c = new int[threshold.length];
final int[] stMap = new int[samples.length];
int last = -1;
for (int t = 0; t < threshold.length; t++) {
int j = last + 1;
//find the first sample that exceeds the current threshold
for (; j < idx.length; j++) {
final int k = idx[j];
if (samples[k].getFeatureValue(fid) > threshold[t]) {
break;
}
sumLeft += labels[k];
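//accumulate the global response statistics only once, while processing the first feature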
if (i == 0) {
sumResponse += labels[k];
sqSumResponse += labels[k] * labels[k];
}
stMap[k] = t;
}
last = j - 1;
sumLabel[t] = sumLeft;
c[t] = last + 1;
}
sampleToThresholdMap[i] = stMap;
sum[i] = sumLabel;
count[i] = c;
}
}
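//Recomputes the per-threshold label sums for new labels; the sample-to-threshold mapping and the counts stay the same.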
protected void update(final double[] labels) {
sumResponse = 0;
sqSumResponse = 0;
final MyThreadPool p = MyThreadPool.getInstance();
if (p.size() == 1) {
update(labels, 0, features.length - 1);
} else {
p.execute(new Worker(this, labels), features.length);
}
}
protected void update(final double[] labels, final int start, final int end) {
for (int f = start; f <= end; f++) {
Arrays.fill(sum[f], 0);
}
for (int k = 0; k < labels.length; k++) {
for (int f = start; f <= end; f++) {
final int t = sampleToThresholdMap[f][k];
sum[f][t] += labels[k];
if (f == 0) {
sumResponse += labels[k];
sqSumResponse += labels[k] * labels[k];
}
//count doesn't change, so no need to re-compute
}
}
for (int f = start; f <= end; f++) {
for (int t = 1; t < thresholds[f].length; t++) {
sum[f][t] += sum[f][t - 1];
}
}
}
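//Builds the histogram of a child node from its parent's histogram, using only the subset of samples whose indices are listed in soi.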
public void construct(final FeatureHistogram parent, final int[] soi, final double[] labels) {
this.features = parent.features;
this.thresholds = parent.thresholds;
this.impacts = parent.impacts;
sumResponse = 0;
sqSumResponse = 0;
sum = new double[features.length][];
count = new int[features.length][];
sampleToThresholdMap = parent.sampleToThresholdMap;
final MyThreadPool p = MyThreadPool.getInstance();
if (p.size() == 1) {
construct(parent, soi, labels, 0, features.length - 1);
} else {
p.execute(new Worker(this, parent, soi, labels), features.length);
}
}
protected void construct(final FeatureHistogram parent, final int[] soi, final double[] labels, final int start, final int end) {
//init
for (int i = start; i <= end; i++) {
final float[] threshold = thresholds[i];
sum[i] = new double[threshold.length];
count[i] = new int[threshold.length];
Arrays.fill(sum[i], 0);
Arrays.fill(count[i], 0);
}
//update
for (final int k : soi) {
for (int f = start; f <= end; f++) {
final int t = sampleToThresholdMap[f][k];
sum[f][t] += labels[k];
count[f][t]++;
if (f == 0) {
sumResponse += labels[k];
sqSumResponse += labels[k] * labels[k];
}
}
}
for (int f = start; f <= end; f++) {
for (int t = 1; t < thresholds[f].length; t++) {
sum[f][t] += sum[f][t - 1];
count[f][t] += count[f][t - 1];
}
}
}
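//Builds the histogram of the other child by subtracting the left sibling's histogram from the parent's; if reuseParent is true, the parent's @sum and @count arrays are overwritten in place instead of allocating new ones.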
public void construct(final FeatureHistogram parent, final FeatureHistogram leftSibling, final boolean reuseParent) {
this.reuseParent = reuseParent;
this.features = parent.features;
this.thresholds = parent.thresholds;
this.impacts = parent.impacts;
sumResponse = parent.sumResponse - leftSibling.sumResponse;
sqSumResponse = parent.sqSumResponse - leftSibling.sqSumResponse;
if (reuseParent) {
sum = parent.sum;
count = parent.count;
} else {
sum = new double[features.length][];
count = new int[features.length][];
}
sampleToThresholdMap = parent.sampleToThresholdMap;
final MyThreadPool p = MyThreadPool.getInstance();
if (p.size() == 1) {
construct(parent, leftSibling, 0, features.length - 1);
} else {
p.execute(new Worker(this, parent, leftSibling), features.length);
}
}
protected void construct(final FeatureHistogram parent, final FeatureHistogram leftSibling, final int start, final int end) {
for (int f = start; f <= end; f++) {
final float[] threshold = thresholds[f];
if (!reuseParent) {
sum[f] = new double[threshold.length];
count[f] = new int[threshold.length];
}
for (int t = 0; t < threshold.length; t++) {
sum[f][t] = parent.sum[f][t] - leftSibling.sum[f][t];
count[f][t] = parent.count[f][t] - leftSibling.count[f][t];
}
}
}
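//Scans every candidate threshold of the given features and keeps the split that maximizes S = sumLeft^2/countLeft + sumRight^2/countRight, which is equivalent to minimizing the squared error of the two children; splits leaving fewer than minLeafSupport samples on either side are skipped. errReduced records the error reduction associated with the chosen split.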
protected Config findBestSplit(final int[] usedFeatures, final int minLeafSupport, final int start, final int end) {
final Config cfg = new Config();
final int totalCount = count[start][count[start].length - 1];
for (int f = start; f <= end; f++) {
final int i = usedFeatures[f];
final float[] threshold = thresholds[i];
for (int t = 0; t < threshold.length; t++) {
final int countLeft = count[i][t];
final int countRight = totalCount - countLeft;
if (countLeft < minLeafSupport || countRight < minLeafSupport) {
continue;
}
final double sumLeft = sum[i][t];
final double sumRight = sumResponse - sumLeft;
final double S = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight;
final double errST = (sqSumResponse / totalCount) * (S / totalCount);
if (cfg.S < S) {
cfg.S = S;
cfg.featureIdx = i;
cfg.thresholdIdx = t;
cfg.errReduced = errST;
}
}
}
return cfg;
}
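//Finds the best split for the node sp: optionally sub-samples the features, searches the candidate splits (in parallel when a thread pool is available), partitions the node's samples into left/right children, and attaches the child histograms and statistics to sp. Returns false if no valid split exists.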
public boolean findBestSplit(final Split sp, final double[] labels, final int minLeafSupport) {
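//a node with zero deviance is already pure: all responses are equal and no split can reduce the error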
if (sp.getDeviance() >= 0.0 && sp.getDeviance() <= 0.0) {
return false;//no need to split
}
int[] usedFeatures = null;//indices of the features to be considered for splitting
if (samplingRate < 1)//need to do sub-sampling (feature sampling)
{
final int size = (int) (samplingRate * features.length);
usedFeatures = new int[size];
//put all features into a pool
final List<Integer> fpool = new ArrayList<>();
for (int i = 0; i < features.length; i++) {
fpool.add(i);
}
//do sampling, without replacement
final Random r = new Random();
for (int i = 0; i < size; i++) {
final int sel = r.nextInt(fpool.size());
usedFeatures[i] = fpool.get(sel);
fpool.remove(sel);
}
} else//no sub-sampling, all features will be used
{
usedFeatures = new int[features.length];
for (int i = 0; i < features.length; i++) {
usedFeatures[i] = i;
}
}
//find the best split
Config best = new Config();
final MyThreadPool p = MyThreadPool.getInstance();
if (p.size() == 1) {
best = findBestSplit(usedFeatures, minLeafSupport, 0, usedFeatures.length - 1);
} else {
final WorkerThread[] workers = p.execute(new Worker(this, usedFeatures, minLeafSupport), usedFeatures.length);
for (final WorkerThread worker : workers) {
final Worker wk = (Worker) worker;
if (best.S < wk.cfg.S) {
best = wk.cfg;
}
}
}
if (best.S == -1) {
return false;
}
// cumulative label sums and sample counts of the best feature
final double[] bestFeaturesHist = sum[best.featureIdx];
final int[] sampleCount = count[best.featureIdx];
final double s = bestFeaturesHist[bestFeaturesHist.length - 1];
final int c = sampleCount[bestFeaturesHist.length - 1];
final double sumLeft = bestFeaturesHist[best.thresholdIdx];
final int countLeft = sampleCount[best.thresholdIdx];
final double sumRight = s - sumLeft;
final int countRight = c - countLeft;
final int[] left = new int[countLeft];
final int[] right = new int[countRight];
int l = 0;
int r = 0;
int k = 0;
final int[] idx = sp.getSamples();
for (final int element : idx) {
k = element;
if (sampleToThresholdMap[best.featureIdx][k] <= best.thresholdIdx) {
left[l++] = k;
} else {
right[r++] = k;
}
}
final FeatureHistogram lh = new FeatureHistogram();
lh.construct(sp.hist, left, labels);
final FeatureHistogram rh = new FeatureHistogram();
rh.construct(sp.hist, lh, !sp.isRoot());
final double var = sqSumResponse - sumResponse * sumResponse / idx.length;
final double varLeft = lh.sqSumResponse - lh.sumResponse * lh.sumResponse / left.length;
final double varRight = rh.sqSumResponse - rh.sumResponse * rh.sumResponse / right.length;
sp.set(features[best.featureIdx], thresholds[best.featureIdx][best.thresholdIdx], var);
sp.setLeft(new Split(left, lh, varLeft, sumLeft));
sp.setRight(new Split(right, rh, varRight, sumRight));
sp.clearSamples();
return true;
}
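//Worker thread used to parallelize the histogram operations by partitioning the feature (or used-feature) range across threads; the type field selects which operation run() performs.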
class Worker extends WorkerThread {
FeatureHistogram fh = null;
int type = -1;
//find best split (type == 0)
int[] usedFeatures = null;
int minLeafSup = -1;
Config cfg = null;
//update (type = 1)
double[] labels = null;
//construct (type = 2)
FeatureHistogram parent = null;
int[] soi = null;
//construct (type = 3)
FeatureHistogram leftSibling = null;
//construct (type = 4)
DataPoint[] samples;
int[][] sampleSortedIdx;
float[][] thresholds;
public Worker() {
}
public Worker(final FeatureHistogram fh, final int[] usedFeatures, final int minLeafSup) {
type = 0;
this.fh = fh;
this.usedFeatures = usedFeatures;
this.minLeafSup = minLeafSup;
}
public Worker(final FeatureHistogram fh, final double[] labels) {
type = 1;
this.fh = fh;
this.labels = labels;
}
public Worker(final FeatureHistogram fh, final FeatureHistogram parent, final int[] soi, final double[] labels) {
type = 2;
this.fh = fh;
this.parent = parent;
this.soi = soi;
this.labels = labels;
}
public Worker(final FeatureHistogram fh, final FeatureHistogram parent, final FeatureHistogram leftSibling) {
type = 3;
this.fh = fh;
this.parent = parent;
this.leftSibling = leftSibling;
}
public Worker(final FeatureHistogram fh, final DataPoint[] samples, final double[] labels, final int[][] sampleSortedIdx,
final float[][] thresholds) {
type = 4;
this.fh = fh;
this.samples = samples;
this.labels = labels;
this.sampleSortedIdx = sampleSortedIdx;
this.thresholds = thresholds;
}
@Override
public void run() {
if (type == 0) {
cfg = fh.findBestSplit(usedFeatures, minLeafSup, start, end);
} else if (type == 1) {
fh.update(labels, start, end);
} else if (type == 2) {
fh.construct(parent, soi, labels, start, end);
} else if (type == 3) {
fh.construct(parent, leftSibling, start, end);
} else if (type == 4) {
fh.construct(samples, labels, sampleSortedIdx, thresholds, start, end);
}
}
@Override
public WorkerThread clone() {
final Worker wk = new Worker();
wk.fh = fh;
wk.type = type;
//find best split (type == 0)
wk.usedFeatures = usedFeatures;
wk.minLeafSup = minLeafSup;
//update (type = 1)
wk.labels = labels;
//construct (type = 2)
wk.parent = parent;
wk.soi = soi;
//construct (type = 3)
wk.leftSibling = leftSibling;
//construct (type = 4)
wk.samples = samples;
wk.sampleSortedIdx = sampleSortedIdx;
wk.thresholds = thresholds;
return wk;
}
}
}