// ciir.umass.edu.learning.tree.FeatureHistogram (Maven / Gradle / Ivy)
/*===============================================================================
* Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved.
*
* Use of the RankLib package is subject to the terms of the software license set
* forth in the LICENSE file included with this software, and also available at
* http://people.cs.umass.edu/~vdang/ranklib_license.html
*===============================================================================
*/
package ciir.umass.edu.learning.tree;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.WorkerThread;
/**
* @author vdang
*/
public class FeatureHistogram {
/**
 * Small mutable record describing one candidate split: which feature, which of
 * that feature's thresholds, and an associated score.
 * Every field is initialised to -1.
 */
class Config {
// Position of the feature; the enclosing class uses it to index its per-feature arrays (sum, count).
int featureIdx = -1;
// Position of the threshold inside the selected feature's per-threshold arrays.
int thresholdIdx = -1;
// NOTE(review): presumably the score of the split; the code that computes it is not visible here - confirm.
double S = -1;
}
//Parameter
// Fraction of the features used when looking for a split. When it is below 1 the split search
// works on only (int)(samplingRate * features.length) features (feature sub-sampling).
public static float samplingRate = 1;
//Variables
// Feature ids covered by this histogram, exactly as handed to construct(...).
public int[] features = null;
// thresholds[i] holds the candidate thresholds for features[i].
public float[][] thresholds = null;
// sum[i][t]: running total of the labels accumulated up to and including threshold t of feature slot i.
public double[][] sum = null;
// Sum of the labels; accumulated while feature slot 0 is processed.
public double sumResponse = 0;
// Sum of the squared labels; accumulated alongside sumResponse.
public double sqSumResponse = 0;
// count[i][t]: number of samples accumulated up to and including threshold t of feature slot i.
public int[][] count = null;
// sampleToThresholdMap[i][k]: threshold index that sample k was assigned to for feature slot i.
public int[][] sampleToThresholdMap = null;
// Whether to re-use the parent's @sum and @count arrays instead of cleaning up the parent and
// re-allocating them for the children.
// @sum and @count of any intermediate tree node (except for the root) can be re-used.
private boolean reuseParent = false;
/**
 * Creates an empty histogram. The array fields stay {@code null} until
 * {@code construct(...)} allocates them.
 */
public FeatureHistogram() {
}
/**
 * Builds this histogram from scratch over every feature in {@code features}.
 *
 * Stores the feature and threshold arrays, zeroes the response totals, allocates one
 * (still empty) row per feature in {@code sum}, {@code count} and
 * {@code sampleToThresholdMap}, and then fills those rows.
 *
 * @param samples the data points the histogram is built from
 * @param labels the target value of each sample
 * @param sampleSortedIdx one array of sample indices per feature slot
 * @param features the feature ids to cover; kept by reference in {@link #features}
 * @param thresholds candidate thresholds per feature slot; kept by reference in {@link #thresholds}
 */
public void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, int[] features, float[][] thresholds)
{
    // Start from clean totals; the per-feature pass accumulates into them.
    sumResponse = 0;
    sqSumResponse = 0;

    this.features = features;
    this.thresholds = thresholds;

    final int featureCount = features.length;
    sum = new double[featureCount][];
    count = new int[featureCount][];
    sampleToThresholdMap = new int[featureCount][];

    MyThreadPool pool = MyThreadPool.getInstance();
    if(pool.size() != 1)
    {
        // More than one pool thread: hand the per-feature work to the pool.
        pool.execute(new Worker(this, samples, labels, sampleSortedIdx, thresholds), featureCount);
        return;
    }
    // Single-threaded pool: process all feature slots directly.
    construct(samples, labels, sampleSortedIdx, thresholds, 0, featureCount-1);
}
protected void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, float[][] thresholds, int start, int end)
{
for(int i=start;i<=end;i++)
{
int fid = features[i];
//get the list of samples associated with this node (sorted in ascending order with respect to the current feature)
int[] idx = sampleSortedIdx[i];
double sumLeft = 0;
float[] threshold = thresholds[i];
double[] sumLabel = new double[threshold.length];
int[] c = new int[threshold.length];
int[] stMap = new int[samples.length];
int last = -1;
for(int t=0;t threshold[t])
break;
sumLeft += labels[k];
if(i == 0)
{
sumResponse += labels[k];
sqSumResponse += labels[k] * labels[k];
}
stMap[k] = t;
}
last = j-1;
sumLabel[t] = sumLeft;
c[t] = last+1;
}
sampleToThresholdMap[i] = stMap;
sum[i] = sumLabel;
count[i] = c;
}
}
/**
 * Refreshes the histogram for a new label vector across all feature slots.
 *
 * The response totals are zeroed first. The work then goes to the ranged
 * {@code update(labels, start, end)} overload when the pool has a single thread,
 * or to a {@code Worker} submitted to the pool otherwise.
 *
 * @param labels the new target value of each sample
 */
protected void update(double[] labels)
{
    sqSumResponse = 0;
    sumResponse = 0;

    final MyThreadPool pool = MyThreadPool.getInstance();
    final boolean singleThreaded = (pool.size() == 1);
    if(singleThreaded)
    {
        update(labels, 0, features.length-1);
    }
    else
    {
        pool.execute(new Worker(this, labels), features.length);
    }
}
protected void update(double[] labels, int start, int end)
{
for(int f=start;f<=end;f++)
Arrays.fill(sum[f], 0);
for(int k=0;k= 0.0 && sp.getDeviance() <= 0.0)//equals 0
return false;//no need to split
int[] usedFeatures = null;//index of the features to be used for tree splitting
if(samplingRate < 1)//need to do sub sampling (feature sampling)
{
int size = (int)(samplingRate * features.length);
usedFeatures = new int[size];
//put all features into a pool
List fpool = new ArrayList();
for(int i=0;i= sp.getDeviance())
//return null;
double[] sumLabel = sum[best.featureIdx];
int[] sampleCount = count[best.featureIdx];
double s = sumLabel[sumLabel.length-1];
int c = sampleCount[sumLabel.length-1];
double sumLeft = sumLabel[best.thresholdIdx];
int countLeft = sampleCount[best.thresholdIdx];
double sumRight = s - sumLeft;
int countRight = c - countLeft;
int[] left = new int[countLeft];
int[] right = new int[countRight];
int l = 0;
int r = 0;
int k = 0;
int[] idx = sp.getSamples();
for(int j=0;j