All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.ml.xgboost.XgbModel Maven / Gradle / Ivy

/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.ml.xgboost;

import ai.djl.BaseModel;
import ai.djl.Model;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;

import ml.dmlc.xgboost4j.java.JniUtils;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;

/** {@code XgbModel} is the XGBoost implementation of {@link Model}. */
public class XgbModel extends BaseModel {

    /**
     * Constructs a new {@code XgbModel} backed by the given manager.
     *
     * <p>The model data type is set to {@link DataType#FLOAT32}, and the manager is renamed to
     * {@code "XgbModel"}.
     *
     * @param modelName the model name
     * @param manager the {@link NDManager} that holds this model's NDArrays
     */
    XgbModel(String modelName, NDManager manager) {
        super(modelName);
        dataType = DataType.FLOAT32;
        this.manager = manager;
        manager.setName("XgbModel");
    }

    /**
     * {@inheritDoc}
     *
     * <p>Recognized {@code options} (all optional):
     *
     * <ul>
     *   <li>{@code "Mode"}: name of an {@code XgbSymbolBlock.Mode} constant, matched
     *       case-insensitively
     *   <li>{@code "TreeLimit"}: integer tree limit, e.g. {@code "10"} or {@code 10}
     * </ul>
     *
     * <p>Option values are converted with {@code toString()}, so both {@link String} and other
     * value types such as {@link Integer} are accepted. Keys mapped to {@code null} are ignored.
     */
    @Override
    public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
        setModelDir(modelPath);
        wasLoaded = true;
        if (block != null) {
            throw new UnsupportedOperationException("XGBoost does not support dynamic blocks");
        }
        Path modelFile = findModelFile(prefix);
        if (modelFile == null) {
            // Fall back to a model file named after the model directory itself.
            modelFile = findModelFile(modelDir.toFile().getName());
            if (modelFile == null) {
                throw new FileNotFoundException(".json file not found in: " + modelPath);
            }
        }
        block = JniUtils.loadModel((XgbNDManager) manager, modelFile.toAbsolutePath().toString());
        applyOptions(options);
    }

    /**
     * Applies the {@code "Mode"} and {@code "TreeLimit"} load options to the loaded block.
     *
     * @param options the load options, may be {@code null}
     * @throws IllegalArgumentException if {@code "Mode"} names no {@code XgbSymbolBlock.Mode}
     *     constant
     * @throws NumberFormatException if {@code "TreeLimit"} is not an integer
     */
    private void applyOptions(Map<String, ?> options) {
        if (options == null) {
            return;
        }
        Object mode = options.get("Mode");
        if (mode != null) {
            ((XgbSymbolBlock) block)
                    .setMode(
                            XgbSymbolBlock.Mode.valueOf(
                                    mode.toString().toUpperCase(Locale.ROOT)));
        }
        Object treeLimit = options.get("TreeLimit");
        if (treeLimit != null) {
            ((XgbSymbolBlock) block).setTreeLimit(Integer.parseInt(treeLimit.toString()));
        }
    }

    /**
     * Locates the model file to load.
     *
     * <p>If {@code modelDir} is itself a regular file, that file is returned. In that case
     * {@code modelDir} is reset to its parent directory, and {@code modelName} is set to the file
     * name without a trailing {@code .json} or {@code .xgb}.
     *
     * <p>Otherwise the following names are tried inside {@code modelDir}, in order:
     *
     * <ol>
     *   <li>the name as given
     *   <li>the name plus {@code ".json"}
     *   <li>the name plus {@code ".xgb"}
     *   <li>{@code "model.xgb"}
     * </ol>
     *
     * <p>The last three are skipped when the name already ends in {@code .json} or
     * {@code .xgb}.
     *
     * @param prefix the file or base name to look for; {@code null} means {@code modelName}
     * @return the model file, or {@code null} if none was found
     */
    private Path findModelFile(String prefix) {
        if (Files.isRegularFile(modelDir)) {
            Path file = modelDir;
            modelDir = modelDir.getParent();
            modelName = stripModelExtension(file.toFile().getName());
            return file;
        }
        String name = prefix == null ? modelName : prefix;
        Path modelFile = modelDir.resolve(name);
        if (Files.isRegularFile(modelFile)) {
            return modelFile;
        }
        if (name.endsWith(".json") || name.endsWith(".xgb")) {
            // An explicit extension was requested; do not guess alternatives.
            return null;
        }
        String[] candidates = {name + ".json", name + ".xgb", "model.xgb"};
        for (String candidate : candidates) {
            Path path = modelDir.resolve(candidate);
            if (Files.isRegularFile(path)) {
                return path;
            }
        }
        return null;
    }

    /**
     * Removes a trailing {@code .json} or {@code .xgb} extension from a file name.
     *
     * @param fileName the file name
     * @return {@code fileName} without its model extension, or unchanged if it has none
     */
    private static String stripModelExtension(String fileName) {
        if (fileName.endsWith(".json")) {
            return fileName.substring(0, fileName.length() - ".json".length());
        }
        if (fileName.endsWith(".xgb")) {
            return fileName.substring(0, fileName.length() - ".xgb".length());
        }
        return fileName;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Closes the loaded {@code XgbSymbolBlock}, if any, and clears it. It then releases the
     * base model's resources.
     */
    @Override
    public void close() {
        if (block != null) {
            ((XgbSymbolBlock) block).close();
            block = null;
        }
        super.close();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy