
com.flipkart.fdp.ml.adapter.DecisionTreeRegressionModelInfoAdapter Maven / Gradle / Ivy
The newest version!
package com.flipkart.fdp.ml.adapter;
import com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo;
import com.flipkart.fdp.ml.utils.DecisionNodeAdapterUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.tree.Node;
import org.apache.spark.sql.DataFrame;
import java.util.LinkedHashSet;
import java.util.Set;
/**
 * Adapts Spark ML's {@link org.apache.spark.ml.regression.DecisionTreeRegressionModel} into a
 * {@link com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo} object that can be exported through
 * {@link com.flipkart.fdp.ml.export.ModelExporter}.
 */
@Slf4j
public class DecisionTreeRegressionModelInfoAdapter
        extends AbstractModelInfoAdapter {

    /**
     * Builds a {@link DecisionTreeModelInfo} describing the given trained regression tree.
     *
     * <p>The tree structure is converted starting from the model's {@code rootNode()} via
     * {@link DecisionNodeAdapterUtils#adaptNode}. The input keys are the model's features column
     * followed by its label column. The single output key is the model's prediction column.
     * Key insertion order is preserved.
     *
     * @param decisionTreeModel the trained Spark regression tree to adapt; must not be {@code null}
     * @param df                not used by this implementation
     * @return a new {@link DecisionTreeModelInfo} with root node, input keys and output keys set
     * @throws NullPointerException if {@code decisionTreeModel} is {@code null}
     */
    public DecisionTreeModelInfo getModelInfo(final DecisionTreeRegressionModel decisionTreeModel, final DataFrame df) {
        final DecisionTreeModelInfo treeInfo = new DecisionTreeModelInfo();

        final Node rootNode = decisionTreeModel.rootNode();
        treeInfo.setRoot(DecisionNodeAdapterUtils.adaptNode(rootNode));

        // LinkedHashSet keeps the key order stable: features first, then label.
        // NOTE(review): the label column is registered as an input key even though prediction
        // presumably needs only the features; confirm no consumer relies on it before dropping it.
        final Set<String> inputKeys = new LinkedHashSet<String>();
        inputKeys.add(decisionTreeModel.getFeaturesCol());
        inputKeys.add(decisionTreeModel.getLabelCol());
        treeInfo.setInputKeys(inputKeys);

        final Set<String> outputKeys = new LinkedHashSet<String>();
        outputKeys.add(decisionTreeModel.getPredictionCol());
        treeInfo.setOutputKeys(outputKeys);

        return treeInfo;
    }

    /**
     * Returns the Spark model type this adapter consumes.
     *
     * @return {@code DecisionTreeRegressionModel.class}
     */
    @Override
    public Class<DecisionTreeRegressionModel> getSource() {
        return DecisionTreeRegressionModel.class;
    }

    /**
     * Returns the model-info type this adapter produces.
     *
     * @return {@code DecisionTreeModelInfo.class}
     */
    @Override
    public Class<DecisionTreeModelInfo> getTarget() {
        return DecisionTreeModelInfo.class;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy