All downloads are free. Search and download functionality uses the official Maven repository.

hex.tree.xgboost.remote.RemoteXGBoostHandler Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree.xgboost.remote;

import hex.genmodel.utils.IOUtils;
import hex.schemas.XGBoostExecReqV3;
import hex.schemas.XGBoostExecRespV3;
import hex.tree.xgboost.EvalMetric;
import hex.tree.xgboost.exec.LocalXGBoostExecutor;
import hex.tree.xgboost.exec.XGBoostExecReq;
import org.apache.log4j.Logger;
import water.BootstrapFreezable;
import water.H2O;
import water.Iced;
import water.TypeMap;
import water.api.Handler;
import water.api.StreamingSchema;

import java.io.ByteArrayInputStream;
import java.io.IOException;

import static hex.tree.xgboost.remote.XGBoostExecutorRegistry.*;

public class RemoteXGBoostHandler extends Handler {

    private static final Logger LOG = Logger.getLogger(RemoteXGBoostHandler.class);

    private XGBoostExecRespV3 makeResponse(LocalXGBoostExecutor exec) {
        return new XGBoostExecRespV3(exec.modelKey);
    }

    @SuppressWarnings("unused")
    public XGBoostExecRespV3 init(int ignored, XGBoostExecReqV3 req) {
        XGBoostExecReq.Init init = req.readData();
        LocalXGBoostExecutor exec = new LocalXGBoostExecutor(req.key.key(), init);
        storeExecutor(exec);
        return new XGBoostExecRespV3(exec.modelKey, collectNodes());
    }

    public static class RemoteExecutors extends Iced implements BootstrapFreezable {
        public final String[] _nodes;
        public final String[] _typeMap;
        public RemoteExecutors(String[] nodes) {
            _nodes = nodes;
            _typeMap = TypeMap.bootstrapClasses();
        }
    }

    private RemoteExecutors collectNodes() {
        String[] nodes = new String[H2O.CLOUD.size()];
        for (int i = 0; i < nodes.length; i++) {
            nodes[i] = H2O.CLOUD.members()[i].getIpPortString();
        }
        return new RemoteExecutors(nodes);
    }

    @SuppressWarnings("unused")
    public StreamingSchema setup(int ignored, XGBoostExecReqV3 req) {
        LocalXGBoostExecutor exec = getExecutor(req);
        byte[] booster = exec.setup();
        return streamBytes(booster);
    }

    @SuppressWarnings("unused")
    public XGBoostExecRespV3 update(int ignored, XGBoostExecReqV3 req) {
        LocalXGBoostExecutor exec = getExecutor(req);
        XGBoostExecReq.Update update = req.readData();
        exec.update(update.treeId);
        return makeResponse(exec);
    }

    @SuppressWarnings("unused")
    public XGBoostExecRespV3 getEvalMetric(int ignored, XGBoostExecReqV3 req) {
        LocalXGBoostExecutor exec = getExecutor(req);
        EvalMetric evalMetric = exec.getEvalMetric();
        return new XGBoostExecRespV3(exec.modelKey, evalMetric);
    }
    
    @SuppressWarnings("unused")
    public StreamingSchema getBooster(int ignored, XGBoostExecReqV3 req) {
        LocalXGBoostExecutor exec = getExecutor(req);
        byte[] booster = exec.updateBooster();
        return streamBytes(booster);
    }

    @SuppressWarnings("unused")
    public XGBoostExecRespV3 cleanup(int ignored, XGBoostExecReqV3 req) {
        LocalXGBoostExecutor exec = getExecutor(req);
        exec.close();
        removeExecutor(exec);
        return makeResponse(exec);
    }

    private StreamingSchema streamBytes(byte[] data) {
        final byte[] dataToSend;
        if (data == null) dataToSend = new byte[0];
        else dataToSend = data;
        return new StreamingSchema((os, options) -> {
            try {
                IOUtils.copyStream(new ByteArrayInputStream(dataToSend), os);
            } catch (IOException e) {
                LOG.error("Failed writing data to response.", e);
                throw new RuntimeException("Failed writing data to response.", e);
            }
        });
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy