All Downloads are FREE. Search and download functionality uses the official Maven repository.

hex.tree.isoforextended.isolationtree.IsolationTree Maven / Gradle / Ivy

package hex.tree.isoforextended.isolationtree;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import water.util.ArrayUtils;
import water.util.RandomUtils;

import java.util.Random;

/**
 * IsolationTree class implements Algorithm 2 (iTree).
 * Naming convention comes from the Extended Isolation Forest paper.
 * <p>
 * Nodes are stored in level order in a flat array: the children of the node at index {@code i} are at
 * {@code 2 * i + 1} (left) and {@code 2 * i + 2} (right). Every call to {@link #buildTree} resets the build
 * statistics exposed by the getters. Instances keep mutable per-build state without any synchronization.
 *
 * @author Adam Valenta
 */
public class IsolationTree {
    private static final Logger LOG = Logger.getLogger(IsolationTree.class);

    private Node[] _nodes;

    private final int _heightLimit;
    private final int _extensionLevel;

    /** Number of leaves that hold exactly one row. */
    private int _isolatedPoints = 0;
    /** Total number of rows in leaves that hold more than one row. */
    private long _notIsolatedPoints = 0;
    /** Number of split sides that received no rows. */
    private int _zeroSplits = 0;
    /** Number of leaves registered, including empty split sides. */
    private int _leaves = 0;
    /** Height of the children created by the last (and, due to level order, deepest) split. */
    private int _depth = 0;

    public IsolationTree(int heightLimit, int extensionLevel) {
        this._heightLimit = heightLimit;
        this._extensionLevel = extensionLevel;
    }

    /**
     * Implementation of Algorithm 2 (iTree) from paper.
     * <p>
     * Nodes are processed in index (level) order. A node becomes a leaf when it reaches the height limit or holds
     * at most one row; otherwise its rows are split by a random hyperplane into a left and a right child.
     *
     * @param data    training sample in column-major layout ({@code data[col][row]}); must have at least one column
     * @param seed    base seed; the node at index {@code i} draws its split with {@code seed + i}
     * @param treeNum tree number, used only in trace log messages
     * @return compressed representation of the built tree
     * @throws IllegalArgumentException if {@code data} has no columns
     */
    public CompressedIsolationTree buildTree(double[][] data, final long seed, final int treeNum) {
        if (data.length == 0) {
            throw new IllegalArgumentException("Isolation tree cannot be built from data with no columns.");
        }
        // A complete binary tree of height _heightLimit has 2^(_heightLimit + 1) - 1 nodes.
        int maxNumNodesInTree = (int) Math.pow(2, _heightLimit + 1) - 1;
        _isolatedPoints = 0;
        _notIsolatedPoints = 0;
        _zeroSplits = 0;
        _leaves = 0;
        _depth = 0;
        this._nodes = new Node[maxNumNodesInTree];
        CompressedIsolationTree compressedIsolationTree = new CompressedIsolationTree(_heightLimit);

        _nodes[0] = new Node(data, data[0].length, 0);
        for (int i = 0; i < _nodes.length; i++) {
            // Guarded so the message is not concatenated for every slot when trace logging is disabled.
            if (LOG.isTraceEnabled()) {
                LOG.trace((i + 1) + " from " + _nodes.length + " is being prepared on tree " + treeNum);
            }
            Node node = _nodes[i];
            if (node == null || node._external) {
                // Unused slot, or an empty split side that was already registered as a leaf.
                continue;
            }
            double[][] nodeData = node._data;
            int numRows = nodeData[0].length;
            if (node._height >= _heightLimit || numRows <= 1) {
                node._external = true;
                node._numRows = numRows;
                node._data = null; // attempt to inform Java GC the data are no longer needed
                compressedIsolationTree.getNodes()[i] = new CompressedLeaf(node);
                if (numRows == 1) {
                    _isolatedPoints++;
                } else if (numRows > 1) {
                    _notIsolatedPoints += numRows;
                }
                _leaves++;
            } else if (rightChildIndex(i) < _nodes.length) {
                int childHeight = node._height + 1;
                _depth = childHeight;

                node._p = ArrayUtils.uniformDistFromArray(nodeData, seed + i);
                // (columns - _extensionLevel - 1) slope components are forced to zero.
                node._n = gaussianVector(nodeData.length, nodeData.length - _extensionLevel - 1, seed + i);

                FilteredData split = extendedIsolationForestSplit(nodeData, node._p, node._n);

                compressedIsolationTree.getNodes()[i] = new CompressedNode(node);
                addChild(compressedIsolationTree, leftChildIndex(i), split.left, childHeight);
                addChild(compressedIsolationTree, rightChildIndex(i), split.right, childHeight);
                node._data = null; // attempt to inform Java GC the data are no longer needed
            } else {
                // Defensive: not expected with the array sized above, because every node below the height
                // limit has room for both of its children.
                compressedIsolationTree.getNodes()[i] = new CompressedLeaf(node);
                _leaves++;
                node._data = null; // attempt to inform Java GC the data are no longer needed
            }
        }
        return compressedIsolationTree;
    }

    /**
     * Registers the child at {@code childIndex} produced by a split.
     * A non-empty side becomes a pending node; its compressed entry is a placeholder that is overwritten
     * when the child itself is processed by {@link #buildTree}.
     * An empty side ({@code childData == null}) becomes an external zero-row leaf and is counted as a zero split.
     */
    private void addChild(CompressedIsolationTree compressedIsolationTree, int childIndex,
                          double[][] childData, int childHeight) {
        if (childData != null) {
            Node child = new Node(childData, childData[0].length, childHeight);
            _nodes[childIndex] = child;
            compressedIsolationTree.getNodes()[childIndex] = new CompressedNode(child);
        } else {
            Node child = new Node(null, 0, childHeight);
            child._external = true;
            _nodes[childIndex] = child;
            compressedIsolationTree.getNodes()[childIndex] = new CompressedLeaf(child);
            _leaves++;
            _zeroSplits++;
        }
    }

    private int leftChildIndex(int i) {
        return 2 * i + 1;
    }

    private int rightChildIndex(int i) {
        return 2 * i + 2;
    }

    /**
     * Helper method. Logs the number of rows of every node slot in level order; empty slots are printed as ".".
     *
     * @param level log level to use
     */
    public void logNodesNumRows(Level level) {
        if (_nodes == null) {
            LOG.log(level, "IsolationTree has not been built yet.");
            return;
        }
        StringBuilder logMessage = new StringBuilder();
        for (Node node : _nodes) {
            if (node == null) {
                logMessage.append(". ");
            } else {
                logMessage.append(node._numRows).append(' ');
            }
        }
        LOG.log(level, logMessage.toString());
    }

    /**
     * Helper method. Logs the height (length of path from root) of every node slot in level order. Root is 0;
     * empty slots are printed as ".".
     *
     * @param level log level to use
     */
    public void logNodesHeight(Level level) {
        if (_nodes == null) {
            LOG.log(level, "IsolationTree has not been built yet.");
            return;
        }
        StringBuilder logMessage = new StringBuilder();
        for (Node node : _nodes) {
            if (node == null) {
                logMessage.append(". ");
            } else {
                logMessage.append(node._height).append(' ');
            }
        }
        LOG.log(level, logMessage.toString());
    }

    /**
     * IsolationTree Node. Naming convention comes from Algorithm 2 (iTree) in paper.
     * _data should be always null after buildTree() method because only number of rows in data is needed for
     * scoring (evaluation) stage.
     */
    public static class Node {

        /**
         * Data in this node. After computation should be null, because only _numRows is important.
         */
        private double[][] _data;

        /**
         * Random slope (normal vector of the splitting hyperplane); set only on split nodes.
         */
        private double[] _n;

        /**
         * Random intercept point (a point on the splitting hyperplane); set only on split nodes.
         */
        private double[] _p;

        private int _height;
        private boolean _external = false;
        private int _numRows;

        public Node(double[][] data, int numRows, int currentHeight) {
            this._data = data;
            this._numRows = numRows;
            this._height = currentHeight;
        }

        public double[] getN() {
            return _n;
        }

        public double[] getP() {
            return _p;
        }

        public int getHeight() {
            return _height;
        }

        public int getNumRows() {
            return _numRows;
        }
    }

    /**
     * Compute Extended Isolation Forest split point and filter input data with this split point at the same time.
     * <p>
     * See Algorithm 2 (iTree) in the paper. A row {@code x} goes to the left branch when
     * {@code sum over col of (x[col] - p[col]) * n[col] <= 0}, and to the right branch otherwise (including
     * when the sum is NaN). Row order is preserved within each branch.
     *
     * @param data column-major data ({@code data[col][row]}) with at least one column
     * @param p    random intercept point, at least one value per column
     * @param n    random slope, at least one value per column
     * @return Object containing data for Left and Right branch of the tree; a branch that receives no rows is
     *         {@code null}.
     */
    public static FilteredData extendedIsolationForestSplit(double[][] data, double[] p, double[] n) {
        final int numCols = data.length;
        final int numRows = data[0].length;
        double[] res = new double[numRows];
        int leftLength = 0;
        for (int row = 0; row < numRows; row++) {
            for (int col = 0; col < numCols; col++) {
                res[row] += (data[col][row] - p[col]) * n[col];
            }
            if (res[row] <= 0) {
                leftLength++;
            }
        }
        int rightLength = numRows - leftLength;
        double[][] left = leftLength > 0 ? new double[numCols][leftLength] : null;
        double[][] right = rightLength > 0 ? new double[numCols][rightLength] : null;
        for (int row = 0, rowLeft = 0, rowRight = 0; row < numRows; row++) {
            if (res[row] <= 0) {
                for (int col = 0; col < numCols; col++) {
                    left[col][rowLeft] = data[col][row];
                }
                rowLeft++;
            } else {
                for (int col = 0; col < numCols; col++) {
                    right[col][rowRight] = data[col][row];
                }
                rowRight++;
            }
        }
        return new FilteredData(left, right);
    }

    /**
     * Result of {@link #extendedIsolationForestSplit}: column-major rows of each branch, or {@code null} for a
     * branch that received no rows.
     */
    public static class FilteredData {
        private final double[][] left;
        private final double[][] right;

        public FilteredData(double[][] left, double[][] right) {
            this.left = left;
            this.right = right;
        }

        public double[][] getLeft() {
            return left;
        }

        public double[][] getRight() {
            return right;
        }
    }

    /**
     * Make a new array initialized to random Gaussian N(0,1) values with the given seed.
     * Make randomly selected {@code zeroNum} items zeros (based on extensionLevel value).
     * <p>
     * NOTE(review): the same {@code seed} drives both the Gaussian values and the choice of zeroed positions;
     * kept unchanged so that seeded models stay reproducible.
     *
     * @param n       length of generated vector
     * @param zeroNum set randomly selected {@code zeroNum} items of vector to zero; values {@code <= 0} zero nothing
     * @param seed    random seed
     * @return array with gaussian values. Randomly selected {@code zeroNum} item values are zeros.
     * @throws IllegalArgumentException if {@code zeroNum > n} (previously this looped forever)
     */
    public static double[] gaussianVector(int n, int zeroNum, long seed) {
        if (zeroNum > n) {
            throw new IllegalArgumentException(
                    "Cannot set " + zeroNum + " items to zero in a vector of length " + n + ".");
        }
        double[] gaussian = ArrayUtils.gaussianVector(n, seed);
        Random r = RandomUtils.getRNG(seed);
        boolean[] zeroed = new boolean[n];
        int remaining = zeroNum;
        // Rejection sampling of distinct positions; draws the same RNG sequence as marking positions with NaN.
        while (remaining > 0) {
            int pos = r.nextInt(n);
            if (!zeroed[pos]) {
                zeroed[pos] = true;
                gaussian[pos] = 0;
                remaining--;
            }
        }
        return gaussian;
    }

    public int getIsolatedPoints() {
        return _isolatedPoints;
    }

    public long getNotIsolatedPoints() {
        return _notIsolatedPoints;
    }

    public int getZeroSplits() {
        return _zeroSplits;
    }

    public int getLeaves() {
        return _leaves;
    }

    public int getDepth() {
        return _depth;
    }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy