All downloads are free. Search and download functionality uses the official Maven repository.

water.bindings.examples.retrofit.GBM_Example Maven / Gradle / Ivy

There is a newer version: 3.46.0.5
Show newest version
package water.bindings.examples.retrofit;

import water.bindings.H2oApi;
import water.bindings.pojos.*;

import java.io.File;
import java.io.IOException;
import java.util.UUID;

/**
 * End-to-end example of driving an H2O server through the {@link H2oApi} bindings: import and
 * parse a dataset, split it into train/test frames, train a GBM, fetch the resulting model and
 * run a prediction.
 */
public class GBM_Example {

    /**
     * Runs the GBM example flow against an H2O server.
     *
     * <p>Steps:
     * <ol>
     *   <li>Open a session.</li>
     *   <li>Import the gzipped arrhythmia CSV from S3.</li>
     *   <li>Guess the parse setup and parse the data (blocking) into {@code arrhythmia.hex}.</li>
     *   <li>Split it roughly 75/25 into {@code train} and {@code test} frames with a Rapids
     *       expression.</li>
     *   <li>Train a GBM on {@code train}, validated on {@code test}, with response column
     *       {@code C1}.</li>
     *   <li>Wait for the training job and fetch the model.</li>
     *   <li>Predict on the training frame into {@code predictions}.</li>
     *   <li>Prepare (but do not send) a binary model-export request.</li>
     * </ol>
     * The session is ended even if one of the steps throws.
     *
     * @param url base URL of the H2O server, or {@code null} to use the server chosen by the
     *     no-argument {@link H2oApi} constructor
     * @throws IOException propagated from the {@link H2oApi} calls or from temp-file creation
     * @throws IllegalStateException if the server returns no model for the finished job
     */
    public static void gbmExampleFlow(String url) throws IOException {
        H2oApi h2o = url != null ? new H2oApi(url) : new H2oApi();

        // STEP 0: init a session
        String sessionId = h2o.newSession().sessionKey;
        try {
            // STEP 1: import raw file
            ImportFilesV3 importBody = h2o.importFiles(
                "http://s3.amazonaws.com/h2o-public-test-data/smalldata/flow_examples/arrhythmia.csv.gz", null
            );
            System.out.println("import: " + importBody);

            // STEP 2: parse setup
            ParseSetupV3 parseSetupBody = h2o.guessParseSetup(
                H2oApi.stringArrayToKeyArray(importBody.destinationFrames, FrameKeyV3.class));
            System.out.println("parseSetupBody: " + parseSetupBody);

            // STEP 3: parse into columnar Frame
            ParseV3 parseParms = new ParseV3();
            H2oApi.copyFields(parseParms, parseSetupBody);
            parseParms.destinationFrame = H2oApi.stringToFrameKey("arrhythmia.hex");
            // Blocking parse, so no explicit job wait is needed before the frame is used below.
            // Alternatively, run it non-blocking and poll the parse job (not the parse-setup
            // result) with h2o.waitForJobCompletion.
            parseParms.blocking = true;

            ParseV3 parseBody = h2o.parse(parseParms);
            System.out.println("parseBody: " + parseBody);

            // STEP 4: split into train (rows where the random value <= 0.75) and test (> 0.75).
            // The h2o.runif output is stored in a uniquely named temp vector, which is
            // removed at the end of the expression.
            String tmpVec = "tmp_" + UUID.randomUUID().toString();
            String splitExpr =
                "(, " +
                "  (tmp= " + tmpVec + " (h2o.runif arrhythmia.hex 906317))" +
                "  (assign train " +
                "    (rows arrhythmia.hex (<= " + tmpVec + " 0.75)))" +
                "  (assign test " +
                "    (rows arrhythmia.hex (> " + tmpVec + " 0.75)))" +
                "  (rm " + tmpVec + "))";
            RapidsSchemaV3 rapidsParms = new RapidsSchemaV3();
            rapidsParms.sessionId = sessionId;
            rapidsParms.ast = splitExpr;
            h2o.rapidsExec(rapidsParms);

            // STEP 5: train the model. No wait for the parse job is needed here because the
            // parse in STEP 3 ran with blocking = true.
            GBMParametersV3 gbmParms = new GBMParametersV3();
            gbmParms.trainingFrame = H2oApi.stringToFrameKey("train");
            gbmParms.validationFrame = H2oApi.stringToFrameKey("test");

            ColSpecifierV3 responseColumn = new ColSpecifierV3();
            responseColumn.columnName = "C1";
            gbmParms.responseColumn = responseColumn;

            System.out.println("About to train GBM. . .");
            GBMV3 gbmBody = h2o.train_gbm(gbmParms);
            System.out.println("gbmBody: " + gbmBody);

            // STEP 6: poll for completion of the training job
            JobV3 job = h2o.waitForJobCompletion(gbmBody.job.key);
            System.out.println("GBM build done.");

            // STEP 7: fetch the model produced by the job
            ModelKeyV3 modelKey = (ModelKeyV3) job.dest;
            ModelsV3 models = h2o.model(modelKey);
            System.out.println("models: " + models);
            if (models.models == null || models.models.length == 0) {
                throw new IllegalStateException("No model returned for key: " + modelKey);
            }
            GBMModelV3 model = (GBMModelV3) models.models[0];
            System.out.println("new GBM model: " + model);
            assert model.getClass() == GBMModelV3.class;
            assert model.output.getClass() == GBMModelOutputV3.class;
            assert model.parameters.getClass() == GBMParametersV3.class;

            // STEP 8: predict on the training frame, writing results to the "predictions" frame
            ModelMetricsListSchemaV3 predictParams = new ModelMetricsListSchemaV3();
            predictParams.model = modelKey;
            predictParams.frame = gbmParms.trainingFrame;
            predictParams.predictionsFrame = H2oApi.stringToFrameKey("predictions");

            ModelMetricsListSchemaV3 predictions = h2o.predict(predictParams);
            System.out.println("predictions: " + predictions);

            // STEP 9 (optional): prepare a binary model export to a local temp path.
            // NOTE(review): this only builds the request object; no export call is issued here.
            ModelExportV3 modelExport = new ModelExportV3();
            modelExport.modelId = modelKey;

            File binaryModelFile = File.createTempFile("model", ".h2o");
            binaryModelFile.deleteOnExit();
            // Reuse the temp file registered for deletion instead of creating a second one.
            modelExport.dir = binaryModelFile.getPath();
        } finally {
            // STEP 99: end the session, even if an earlier step threw
            h2o.endSession();
        }
    }

    /**
     * Runs {@link #gbmExampleFlow(String)} against the server selected by the no-argument
     * {@link H2oApi} constructor.
     *
     * @throws IOException propagated from {@link #gbmExampleFlow(String)}
     */
    public static void gbmExampleFlow() throws IOException {
        gbmExampleFlow(null);
    }

    /**
     * Command-line entry point.
     *
     * @param args optional; if present, {@code args[0]} is the H2O server URL. With no
     *     arguments the default server is used.
     * @throws IOException propagated from {@link #gbmExampleFlow(String)}
     */
    public static void main(String[] args) throws IOException {
        gbmExampleFlow(args != null && args.length > 0 ? args[0] : null);
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy