All downloads are free. Search and download functionality uses the official Maven repository.

org.bigml.mimir.forest.Path Maven / Gradle / Ivy

The newest version!
package org.bigml.mimir.forest;

/**
 * Static helpers that compute SHAP feature attributions by recursively
 * walking a tree of {@link ShapNode}s while maintaining a "unique path" of
 * the features split on so far.
 *
 * <p>The extend / unwind / unwound-path-sum structure follows the TreeSHAP
 * algorithm of Lundberg et al. NOTE(review): check against that reference
 * before changing any of the arithmetic below.
 *
 * <p>A path is represented as four parallel {@code double[]} arrays:
 * <ul>
 *   <li>{@code path[0]} - feature index of each element (stored as a double),</li>
 *   <li>{@code path[1]} - zero fraction of each element,</li>
 *   <li>{@code path[2]} - one fraction of each element,</li>
 *   <li>{@code path[3]} - weight of each element.</li>
 * </ul>
 * Only entries {@code 0..uniqueDepth} of a path are meaningful at any time.
 */
public class Path {
    /**
     * Copies {@code src} into {@code dest} as if {@code dest} were a view of
     * {@code src} shifted by {@code udp1} positions. First
     * {@code src[udp1..]} is copied to the start of {@code dest}. Then the
     * prefix {@code src[0..udp1)} overwrites {@code dest[0..udp1)}.
     *
     * <p>This appears to mirror a pointer-offset buffer scheme. It is no
     * longer used by {@link #newPath}, which copies only the live prefix, but
     * it is kept unchanged for any external callers.
     *
     * @param src  source array
     * @param dest destination array; must have room for both copies
     * @param udp1 number of leading entries to preserve (unique depth + 1)
     */
    public static void doCopy(double[] src, double[] dest, int udp1) {
        System.arraycopy(src, udp1, dest, 0, src.length - udp1);
        System.arraycopy(src, 0, dest, 0, udp1);
    }

    /**
     * Creates the path arrays for a node.
     *
     * <p>When {@code path} is {@code null}, {@code uniqueOrMaxDepth} is taken
     * to be the maximum depth {@code m}. The method then allocates a fresh,
     * zeroed path whose four arrays have length {@code (m+2)(m+3)/2}.
     *
     * <p>Otherwise {@code uniqueOrMaxDepth} is the current unique depth. The
     * result has the same array length as {@code path`, with entries
     * {@code 0..uniqueOrMaxDepth} copied from it. The remaining entries are
     * left at zero because they are never read before being rewritten.
     *
     * @param path             parent path, or {@code null} to start a new one
     * @param uniqueOrMaxDepth maximum depth (when {@code path} is null) or
     *                         current unique depth (otherwise)
     * @return a newly allocated 4-row path; {@code path} is not modified
     * @throws IndexOutOfBoundsException if {@code uniqueOrMaxDepth + 1}
     *         exceeds the length of the parent arrays
     */
    public static double[][] newPath(double[][] path, int uniqueOrMaxDepth) {
        if (path == null) {
            // Starting a new path: size it from the maximum depth.
            int maxDepth = uniqueOrMaxDepth + 2;
            int s = (maxDepth * (maxDepth + 1)) / 2;
            return new double[4][s];
        }

        // Extending an existing path. Keep the parent's array length so that
        // deeper descendants still fit.
        double[][] newPath = new double[4][path[0].length];
        int udp1 = uniqueOrMaxDepth + 1;

        // Copy only the live prefix 0..uniqueDepth. Entries above it are
        // always overwritten (by pathExtend) before they are read. Shifting
        // the whole parent tail into the new array would be wasted work.
        for (int i = 0; i < 4; i++) {
            System.arraycopy(path[i], 0, newPath[i], 0, udp1);
        }

        return newPath;
    }

    /**
     * Adds a leaf's contribution to {@code phi} for every path element
     * except element 0.
     *
     * <p>For each element {@code i} in {@code 1..udepth}, this computes the
     * sum of the path weights as if element {@code i} were unwound. The
     * weight arrays themselves are not modified. It then adds
     * {@code total * (oneFracs[i] - zeroFracs[i]) * value[j]} to
     * {@code phi[(int) featureIdxs[i]][j]} for each column {@code j} of that
     * row.
     *
     * <p>Element 0 is skipped. NOTE(review): it is presumably the placeholder
     * element added by the root call; confirm with the caller.
     *
     * @param value       the leaf's output values; must be at least as long
     *                    as each row of {@code phi}
     * @param phi         attribution matrix indexed {@code [feature][column]};
     *                    updated in place
     * @param udepth      current unique depth (index of the last element)
     * @param featureIdxs path feature indices ({@code path[0]})
     * @param zeroFracs   path zero fractions ({@code path[1]})
     * @param oneFracs    path one fractions ({@code path[2]})
     * @param pweights    path weights ({@code path[3]})
     */
    public static void pathSum(
            double[] value,
            double[][] phi,
            int udepth,
            double[] featureIdxs,
            double[] zeroFracs,
            double[] oneFracs,
            double[] pweights) {

        int udp1 = udepth + 1, i, j, pathFidx, lenPhiFid;
        // udmj is a double so that (udmj / udp1) below is floating-point
        // division rather than integer division.
        double oneFrac, zeroFrac, fracDiff, nextOne, total, tmp, udmj;

        for (i = 1; i < udp1; i++) {
            oneFrac = oneFracs[i];
            zeroFrac = zeroFracs[i];

            nextOne = pweights[udepth];
            total = 0;

            // Sum the weights that would result from unwinding element i,
            // without writing them back.
            for (j = udepth - 1; j >= 0; j--) {
                udmj = udepth - j;

                if (oneFrac != 0.0) {
                    tmp = nextOne * udp1 / ((j + 1.0) * oneFrac);
                    total += tmp;
                    nextOne = pweights[j] - tmp * zeroFrac * (udmj / udp1);
                }
                else if (zeroFrac != 0.0) {
                    total += (pweights[j] / zeroFrac) / (udmj / udp1);
                }
                // If both fractions are 0, element i contributes nothing.
            }

            pathFidx = (int)featureIdxs[i];
            lenPhiFid = phi[pathFidx].length;
            fracDiff = oneFrac - zeroFrac;

            for (j = 0; j < lenPhiFid; j++) {
                phi[pathFidx][j] += total * fracDiff * value[j];
            }
        }
    }

    /**
     * Appends a new element at index {@code d} of the path and updates the
     * weights {@code pwts[0..d]} in place to account for it.
     *
     * @param fi          feature index of the new element
     * @param d           index of the new element (the current unique depth)
     * @param zeroFrac    zero fraction of the new element
     * @param oneFrac     one fraction of the new element
     * @param featureIdxs path feature indices ({@code path[0]}); updated in place
     * @param zeroFracs   path zero fractions ({@code path[1]}); updated in place
     * @param oneFracs    path one fractions ({@code path[2]}); updated in place
     * @param pwts        path weights ({@code path[3]}); updated in place
     */
    public static void pathExtend(
            int fi,
            int d,
            double zeroFrac,
            double oneFrac,
            double[] featureIdxs,
            double[] zeroFracs,
            double[] oneFracs,
            double[] pwts) {

        double udp1 = d + 1.0, pwi;

        featureIdxs[d] = fi;
        zeroFracs[d] = zeroFrac;
        oneFracs[d] = oneFrac;

        // The first element starts with weight 1. Later elements start at 0
        // and receive weight from their predecessors in the loop below.
        if (d == 0) pwts[d] = 1.0;
        else pwts[d] = 0.0;

        // Walk downward. This way pwts[i + 1] already holds its
        // zero-fraction share when the one-fraction share from the old
        // pwts[i] is added to it.
        for (int i = d - 1; i >= 0; i--) {
            pwi = pwts[i];
            pwts[i + 1] += oneFrac * pwi * (i + 1.0) / udp1;
            pwts[i] = zeroFrac * pwi * (d - i) / udp1;

        }
    }

    /**
     * Removes the element for feature {@code splitIndex} from the path if it
     * is already present among entries {@code 0..udepth}.
     *
     * <p>When the element is found, its effect on {@code pwts[0..udepth-1]}
     * is undone in place. The later feature/zero/one entries are then
     * shifted down by one position. {@code pwts} is not shifted.
     *
     * @param udepth      current unique depth (index of the last element)
     * @param splitIndex  feature index to look for
     * @param featureIdxs path feature indices ({@code path[0]}); may be updated
     * @param zeroFracs   path zero fractions ({@code path[1]}); may be updated
     * @param oneFracs    path one fractions ({@code path[2]}); may be updated
     * @param pwts        path weights ({@code path[3]}); may be updated
     * @return a 3-element array. If the feature was found, it is
     *         {@code {udepth - 1, zeroFrac, oneFrac}}: the new unique depth
     *         (as a double) and the removed element's fractions. Otherwise it
     *         is {@code {udepth, 1.0, 1.0}}.
     */
    public static double[] maybeUnwind(
            int udepth,
            int splitIndex,
            double[] featureIdxs,
            double[] zeroFracs,
            double[] oneFracs,
            double[] pwts) {

        int i, udmi, pathIndex = 0, udp1 = udepth + 1;
        double zeroFrac, oneFrac, nextOne, tmpPwi;

        while (pathIndex <= udepth && featureIdxs[pathIndex] != splitIndex)
            pathIndex++;

        if (pathIndex != udp1) {
            zeroFrac = zeroFracs[pathIndex];
            oneFrac = oneFracs[pathIndex];

            nextOne = pwts[udepth];

            // Invert the weight update that pathExtend applied for this element.
            for (i = udepth - 1; i >= 0; i--) {
                udmi = udepth - i;
                tmpPwi = pwts[i];

                if (oneFrac != 0.0) {
                    pwts[i] = nextOne * udp1 / ((i + 1.0) * oneFrac);
                    nextOne = tmpPwi - pwts[i] * zeroFrac * udmi / udp1;
                }
                else {
                    // NOTE(review): if zeroFrac is also 0, this divides by
                    // zero and yields NaN/Infinity weights.
                    pwts[i] = (tmpPwi * udp1) / (zeroFrac * udmi);
                }
            }

            // Close the gap left by the removed element.
            for (i = pathIndex; i < udepth; i++) {
                featureIdxs[i] = featureIdxs[i + 1];
                zeroFracs[i] = zeroFracs[i + 1];
                oneFracs[i] = oneFracs[i + 1];
            }

            return new double[]{udepth - 1, zeroFrac, oneFrac};
        }
        else {
            return new double[]{udepth, 1.0, 1.0};
        }
    }

    /**
     * Recursively accumulates the SHAP contributions of the subtree rooted at
     * {@code currentNode} into {@code phi}.
     *
     * <p>Multipredicate nodes do not extend the path themselves. If
     * {@code getFirstUnsatisfied(x)} is negative, they forward the incoming
     * split unchanged to their next node. Otherwise they attribute the split
     * to feature {@code uidx}: the one-fraction goes to a phantom leaf that
     * carries this node's objective, and the zero-fraction goes to the next
     * node.
     *
     * <p>Every other node copies the path and extends it with the incoming
     * split. A leaf then adds its contribution via {@link #pathSum}. An
     * internal node first unwinds any earlier split on the same feature. It
     * then recurses into the hot branch (the one {@code x} follows) and the
     * cold branch, skipping any branch whose incoming fractions are both
     * zero. When {@code nextIndex(x)} is {@code -1} (missing or
     * out-of-sample value), the one-fraction goes to a phantom leaf holding
     * this node's objective instead of to either child.
     *
     * @param x            input values used to choose branches
     * @param phi          attribution matrix indexed {@code [feature][column]};
     *                     updated in place
     * @param currentNode  node to process
     * @param uniqueDepth  current unique depth of the path
     * @param parentPath   the parent's path arrays. They are not written
     *                     directly: non-multipredicate nodes work on a copy
     *                     made by {@link #newPath}.
     * @param zeroFraction zero fraction of the incoming split
     * @param oneFraction  one fraction of the incoming split
     * @param featureIndex feature index of the incoming split
     * @throws IllegalStateException if {@code nextIndex(x)} returns a value
     *         other than -1, 0 or 1
     */
    public static void shapForNode(
            double[] x,
            double[][] phi,
            ShapNode currentNode,
            int uniqueDepth,
            double[][] parentPath,
            double zeroFraction,
            double oneFraction,
            int featureIndex) {

        double[][] path;
        double[] result;
        double currentWeight, inZF, inOF, hotOF, hotZF, coldOF, coldZF;
        double hotFraction, coldFraction;
        int splitIndex, nextIndex, nextDepth, si, nd, uidx, depth;
        ShapNode hotNode, coldNode, phantomNode, nextNode;
        MultipredicateNode multi;

        if (currentNode.isMultipredicate) {
            multi = (MultipredicateNode)currentNode;
            nextNode = multi._nextNode;
            uidx = multi.getFirstUnsatisfied(x);

            depth = uniqueDepth;
            path = parentPath;
            inZF = zeroFraction;
            inOF = oneFraction;

            if (uidx < 0) {
                // Forward the incoming split unchanged.
                si = featureIndex;
                shapForNode(x, phi, nextNode, depth, path, inZF, inOF, si);
            }
            else {
                // Attribute the split to feature uidx. The one-fraction goes
                // to a phantom leaf, the zero-fraction to the next node.
                phantomNode = new LeafNode(currentNode.objective, 0);
                shapForNode(x, phi, phantomNode, depth, path, 0.0, inOF, uidx);
                shapForNode(x, phi, nextNode, depth, path, inZF, 0.0, uidx);
            }
        }
        else {
            path = newPath(parentPath, uniqueDepth);
            // extend the unique path
            pathExtend(
                    featureIndex,
                    uniqueDepth,
                    zeroFraction,
                    oneFraction,
                    path[0],
                    path[1],
                    path[2],
                    path[3]);

            if (currentNode.isLeaf) {
                pathSum(currentNode.objective,
                        phi,
                        uniqueDepth,
                        path[0],
                        path[1],
                        path[2],
                        path[3]);
            }
            else {
                splitIndex = currentNode.splitIndex;
                nextIndex = currentNode.nextIndex(x);

                // Missing or out-of-sample value; create a phantom leaf
                // node and treat both branches the same
                if (nextIndex == -1) {
                    phantomNode = new LeafNode(currentNode.objective, 0);
                }
                else {
                    phantomNode = null;
                }

                if (nextIndex == 0 || nextIndex == -1) {
                    hotNode = currentNode.left;
                    coldNode = currentNode.right;
                }
                else if (nextIndex == 1) {
                    hotNode = currentNode.right;
                    coldNode = currentNode.left;
                }
                else {
                    throw new IllegalStateException("Next index is " + nextIndex);
                }

                currentWeight = currentNode.weight;

                // Guard against division by zero for weightless nodes.
                if (currentWeight > 0) {
                    hotFraction = hotNode.weight / currentWeight;
                    coldFraction = coldNode.weight / currentWeight;
                }
                else {
                    hotFraction = coldFraction = 0.0;
                }

                // If this feature was already split on higher up, remove that
                // element. Its fractions then become the incoming fractions
                // for both children.
                result = maybeUnwind(
                        uniqueDepth,
                        splitIndex,
                        path[0],
                        path[1],
                        path[2],
                        path[3]);

                nextDepth = (int)result[0];
                inZF = result[1];
                inOF = result[2];

                si = splitIndex;
                nd = nextDepth + 1;

                coldZF = coldFraction * inZF;
                coldOF = 0.0;
                hotZF = hotFraction * inZF;
                hotOF = inOF;

                // With a missing value, the phantom leaf takes the
                // one-fraction in place of the hot child.
                if (phantomNode != null) {
                    shapForNode(x, phi, phantomNode, nd, path, 0.0, inOF, si);
                    hotOF = 0.0;
                }

                if (hotZF > 0 || hotOF > 0)
                    shapForNode(x, phi, hotNode, nd, path, hotZF, hotOF, si);

                if (coldZF > 0 || coldOF > 0)
                    shapForNode(x, phi, coldNode, nd, path, coldZF, coldOF, si);
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy