All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.flipkart.fdp.ml.adapter.DecisionTreeModelInfoAdapter Maven / Gradle / Ivy

There is a newer version: 0.1.2
Show newest version
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo;
import com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo.DecisionNode;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.Node;
import org.apache.spark.mllib.tree.model.Split;
import org.apache.spark.sql.DataFrame;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Stack;

/**
 * Transforms Spark's {@link DecisionTreeModel} in MLlib into a {@link DecisionTreeModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}.
 *
 * <p>The tree is flattened into id-keyed maps. Per-node attributes go into the tree info's
 * node-info map, and parent-to-child links go into its left/right child maps. Traversal uses an
 * explicit stack instead of recursion, so tree depth does not translate into call-stack depth.
 */
@Slf4j
public class DecisionTreeModelInfoAdapter
        extends AbstractModelInfoAdapter<DecisionTreeModel, DecisionTreeModelInfo> {

    /**
     * Records a single node in {@code treeInfo} and schedules its children for visiting.
     *
     * <p>The right child is pushed before the left child. With LIFO popping, the left subtree is
     * therefore visited before the right one (pre-order, left-to-right).
     *
     * @param node         the node to record
     * @param nodesToVisit pending nodes; any children of {@code node} are pushed onto it
     * @param treeInfo     accumulator receiving the node's attributes and its child links
     */
    private void visit(final Node node,
                       final Deque<Node> nodesToVisit,
                       final DecisionTreeModelInfo treeInfo) {
        final DecisionNode nodeInfo = new DecisionNode();
        nodeInfo.setId(node.id());
        nodeInfo.setLeaf(node.isLeaf());
        // split is a Scala Option: only present on internal nodes.
        if (node.split().nonEmpty()) {
            final Split split = node.split().get();
            nodeInfo.setFeature(split.feature());
            // NOTE(review): only the threshold is exported. For categorical splits Spark stores the
            // split values in split.categories(), which is not captured here — confirm intended.
            nodeInfo.setThreshold(split.threshold());
            nodeInfo.setFeatureType(split.featureType().toString());
        }
        nodeInfo.setPredict(node.predict().predict());
        nodeInfo.setProbability(node.predict().prob());
        treeInfo.getNodeInfo().put(node.id(), nodeInfo);
        if (node.rightNode().nonEmpty()) {
            final Node right = node.rightNode().get();
            treeInfo.getRightChildMap().put(node.id(), right.id());
            nodesToVisit.push(right);
        }
        if (node.leftNode().nonEmpty()) {
            final Node left = node.leftNode().get();
            treeInfo.getLeftChildMap().put(node.id(), left.id());
            nodesToVisit.push(left);
        }
    }

    /**
     * Flattens {@code decisionTreeModel} into a new {@link DecisionTreeModelInfo}.
     *
     * @param decisionTreeModel the trained MLlib decision tree to convert
     * @param df                not used by this adapter
     * @return a new {@link DecisionTreeModelInfo} whose root id is the model's top node id and which
     *         contains every node reachable from it
     */
    public DecisionTreeModelInfo getModelInfo(final DecisionTreeModel decisionTreeModel, DataFrame df) {
        final DecisionTreeModelInfo treeInfo = new DecisionTreeModelInfo();
        final Node root = decisionTreeModel.topNode();
        treeInfo.setRoot(root.id());
        // ArrayDeque used as a typed LIFO stack: same visiting order as java.util.Stack,
        // without Stack's per-call synchronization.
        final Deque<Node> nodesToVisit = new ArrayDeque<>();
        nodesToVisit.push(root);
        while (!nodesToVisit.isEmpty()) {
            visit(nodesToVisit.pop(), nodesToVisit, treeInfo);
        }
        return treeInfo;
    }

    /**
     * Returns the model type this adapter converts from.
     *
     * @return {@code DecisionTreeModel.class}
     */
    @Override
    public Class<DecisionTreeModel> getSource() {
        return DecisionTreeModel.class;
    }

    /**
     * Returns the model-info type this adapter produces.
     *
     * @return {@code DecisionTreeModelInfo.class}
     */
    @Override
    public Class<DecisionTreeModelInfo> getTarget() {
        return DecisionTreeModelInfo.class;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy