All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.vespa.rankingexpression.importer.xgboost.XGBoostParser Maven / Gradle / Ivy

// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.xgboost;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import com.yahoo.json.Jackson;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;

/**
 * @author grace-lam
 */
class XGBoostParser {

    private final List xgboostTrees;

    /**
     * Constructor stores parsed JSON trees.
     *
     * @param filePath XGBoost JSON intput file.
     * @throws JsonProcessingException Fails JSON parsing.
     * @throws IOException             Fails file reading.
     */
    XGBoostParser(String filePath) throws JsonProcessingException, IOException {
        this.xgboostTrees = new ArrayList<>();
        var mapper = Jackson.mapper();
        JsonNode forestNode = mapper.readTree(new File(filePath));
        for (JsonNode treeNode : forestNode) {
            this.xgboostTrees.add(mapper.treeToValue(treeNode, XGBoostTree.class));
        }
    }

    /**
     * Converts parsed JSON trees to Vespa ranking expressions.
     *
     * @return Vespa ranking expressions.
     */
    String toRankingExpression() {
        StringBuilder ret = new StringBuilder();
        for (int i = 0; i < xgboostTrees.size(); i++) {
            ret.append(treeToRankExp(xgboostTrees.get(i)));
            if (i != xgboostTrees.size() - 1) {
                ret.append(" + \n");
            }
        }
        return ret.toString();
    }

    /**
     * Recursive helper function for toRankingExpression().
     *
     * @param node XGBoost tree node to convert.
     * @return Vespa ranking expression for input node.
     */
    private String treeToRankExp(XGBoostTree node) {
        if (node.isLeaf()) {
            return Double.toString(node.getLeaf());
        } else {
            assert node.getChildren().size() == 2;
            String trueExp;
            String falseExp;
            if (node.getYes() == node.getChildren().get(0).getNodeid()) {
                trueExp = treeToRankExp(node.getChildren().get(0));
                falseExp = treeToRankExp(node.getChildren().get(1));
            } else {
                trueExp = treeToRankExp(node.getChildren().get(1));
                falseExp = treeToRankExp(node.getChildren().get(0));
            }
            String condition;
            if (node.getMissing() == node.getYes()) {
                // Note: this is for handling missing features, as the backend handles comparison with NaN as false.
                condition = "!(" + node.getSplit() + " >= " + node.getSplit_condition() + ")";
            } else {
                condition = node.getSplit() + " < " + node.getSplit_condition();
            }
            return "if (" + condition + ", " + trueExp + ", " + falseExp + ")";
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy