All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.ExactSplitPoints Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree;

import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.IcedDouble;
import water.util.IcedHashSet;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;

/**
 * Finds exact split points for low-cardinality columns.
 */
/**
 * Finds exact split points for low-cardinality numeric columns.
 * <p>
 * For every eligible column (numeric, not categorical, not binary, not constant) the task collects
 * the set of distinct non-NA values. A column that turns out to have more than
 * {@code maxCardinality} distinct values is dropped (its entry becomes {@code null}); for the
 * remaining columns the sorted distinct values are returned as split points.
 */
public class ExactSplitPoints extends MRTask<ExactSplitPoints> {

    /** Maximum number of distinct non-NA values a column may have to keep exact split points. */
    private final int _maxCardinality;
    /**
     * Per-column distinct values collected so far. An entry is set to {@code null} once the column
     * exceeds {@link #_maxCardinality}. {@link #reduce} treats an identical array reference as
     * "same node, already merged".
     */
    private final IcedHashSet<IcedDouble>[] _values;

    /**
     * Computes exact split points for the eligible columns of {@code trainFr}.
     *
     * @param trainFr        training frame
     * @param maxCardinality maximum number of distinct non-NA values a column may have
     * @return an array indexed like the columns of {@code trainFr}; each entry is either the sorted,
     *         strictly increasing distinct non-NA values of that column, or {@code null} when the
     *         column was not eligible or exceeded {@code maxCardinality}
     */
    static double[][] splitPoints(Frame trainFr, int maxCardinality) {
        final Frame fr = new Frame();
        // frToTrain[j] = index in trainFr of the j-th column added to fr
        final int[] frToTrain = new int[trainFr.numCols()];
        for (int i = 0; i < trainFr.numCols(); ++i) {
            if (!trainFr.vec(i).isNumeric() || trainFr.vec(i).isCategorical() ||
                    trainFr.vec(i).isBinary() || trainFr.vec(i).isConst()) {
                continue;
            }
            frToTrain[fr.numCols()] = i;
            fr.add(trainFr.name(i), trainFr.vec(i));
        }
        final double[][] splitPoints = new double[trainFr.numCols()][];
        if (fr.numCols() == 0) {
            // no eligible columns: every entry stays null, no need to launch the task
            return splitPoints;
        }
        final IcedHashSet<IcedDouble>[] values = new ExactSplitPoints(maxCardinality, fr.numCols())
                .doAll(fr)._values;
        for (int i = 0; i < values.length; i++) {
            if (values[i] == null) {
                continue; // too many distinct values
            }
            final double[] vals = new double[values[i].size()];
            int valsSize = 0;
            for (IcedDouble wrapper : values[i]) {
                vals[valsSize++] = wrapper._val;
            }
            assert valsSize == vals.length;
            Arrays.sort(vals);
            assert isUniqueSequence(vals);
            splitPoints[frToTrain[i]] = vals;
        }
        return splitPoints;
    }

    /**
     * Checks that {@code seq} is strictly increasing, i.e. sorted and free of duplicates.
     * Empty and single-element sequences are trivially unique.
     *
     * @param seq values to check
     * @return {@code true} if no element is {@code >=} its successor
     */
    static boolean isUniqueSequence(double[] seq) {
        // starting at 1 handles seq.length == 0 (previously seq[0] threw) and seq.length == 1
        for (int i = 1; i < seq.length; i++) {
            if (seq[i - 1] >= seq[i])
                return false;
        }
        return true;
    }

    @SuppressWarnings("unchecked") // generic array creation for _values
    private ExactSplitPoints(int maxCardinality, int nCols) {
        _maxCardinality = maxCardinality;
        _values = new IcedHashSet[nCols];
        for (int i = 0; i < _values.length; i++) {
            _values[i] = new IcedHashSet<>();
        }
    }

    /**
     * Collects the distinct non-NA values of each column of the chunk and merges them into
     * {@link #_values}. A column whose distinct values exceed the limit is marked {@code null}.
     */
    @Override
    public void map(Chunk[] cs) {
        final Set<IcedDouble> localValues = new HashSet<>(_maxCardinality);
        for (int col = 0; col < cs.length; col++) {
            if (_values[col] == null) // column already exceeded the cardinality limit
                continue;
            localValues.clear();
            final Chunk c = cs[col];
            IcedDouble wrapper = new IcedDouble();
            // Tracked explicitly instead of relying on the initial value of a fresh IcedDouble;
            // NaN never compares equal, so the first value of the chunk is never skipped.
            double prev = Double.NaN;
            for (int i = 0; i < c._len; i++) {
                final double num = c.atd(i);
                // skip NAs and runs of an already-recorded value
                if (Double.isNaN(num) || num == prev)
                    continue;
                prev = num;
                wrapper.setVal(num);
                if (localValues.add(wrapper)) {
                    if (localValues.size() > _maxCardinality) {
                        _values[col] = null;
                        break;
                    }
                    // the wrapper is now held by the set; mutate a fresh one from here on
                    wrapper = new IcedDouble();
                }
            }
            merge(col, localValues);
        }
    }

    /**
     * Adds {@code localValues} to the accumulated set of column {@code col}, dropping the column
     * (setting it to {@code null}) when the merged set exceeds {@link #_maxCardinality}.
     */
    private void merge(int col, Collection<IcedDouble> localValues) {
        final Set<IcedDouble> allValues = _values[col];
        if (allValues == null)
            return;
        allValues.addAll(localValues);
        if (allValues.size() > _maxCardinality) {
            _values[col] = null;
        }
    }

    /**
     * Merges the per-column sets of another task instance. Instances sharing the same
     * {@code _values} array already wrote into the same sets, so nothing is done for them.
     */
    @Override
    public void reduce(ExactSplitPoints mrt) {
        if (mrt._values != _values) { // merging with a result from a different node
            for (int col = 0; col < _values.length; col++) {
                if (_values[col] == null || mrt._values[col] == null)
                    _values[col] = null;
                else {
                    merge(col, mrt._values[col]);
                }
            }
        } // else: nothing to do on the same node
    }

    /** Final safety pass: drop any column whose collected set exceeds the cardinality limit. */
    @Override
    protected void postGlobal() {
        for (int col = 0; col < _values.length; col++) {
            if (_values[col] != null && _values[col].size() > _maxCardinality) {
                _values[col] = null;
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy