water.bindings.examples.retrofit.GBM_Example Maven / Gradle / Ivy
package water.bindings.examples.retrofit;
import water.bindings.H2oApi;
import water.bindings.pojos.*;
import java.io.File;
import java.io.IOException;
import java.util.UUID;
public class GBM_Example {
public static void gbmExampleFlow(String url) throws IOException {
H2oApi h2o = url != null ? new H2oApi(url) : new H2oApi();
// Utility var:
JobV3 job = null;
// STEP 0: init a session
String sessionId = h2o.newSession().sessionKey;
// STEP 1: import raw file
ImportFilesV3 importBody = h2o.importFiles(
"http://s3.amazonaws.com/h2o-public-test-data/smalldata/flow_examples/arrhythmia.csv.gz", null
);
System.out.println("import: " + importBody);
// STEP 2: parse setup
ParseSetupV3 parseSetupBody = h2o.guessParseSetup(H2oApi.stringArrayToKeyArray(importBody.destinationFrames, FrameKeyV3.class));
System.out.println("parseSetupBody: " + parseSetupBody);
// STEP 3: parse into columnar Frame
ParseV3 parseParms = new ParseV3();
H2oApi.copyFields(parseParms, parseSetupBody);
parseParms.destinationFrame = H2oApi.stringToFrameKey("arrhythmia.hex");
parseParms.blocking = true; // alternately, call h2o.waitForJobCompletion(parseSetupBody.job)
ParseV3 parseBody = h2o.parse(parseParms);
System.out.println("parseBody: " + parseBody);
// STEP 4: Split into test and train datasets
String tmpVec = "tmp_" + UUID.randomUUID().toString();
String splitExpr =
"(, " +
" (tmp= " + tmpVec + " (h2o.runif arrhythmia.hex 906317))" +
" (assign train " +
" (rows arrhythmia.hex (<= " + tmpVec + " 0.75)))" +
" (assign test " +
" (rows arrhythmia.hex (> " + tmpVec + " 0.75)))" +
" (rm " + tmpVec + "))";
RapidsSchemaV3 rapidsParms = new RapidsSchemaV3();
rapidsParms.sessionId = sessionId;
rapidsParms.ast = splitExpr;
h2o.rapidsExec(rapidsParms);
// STEP 5: Train the model (NOTE: step 4 is polling, which we don't require because we specified blocking for the parse above)
GBMParametersV3 gbmParms = new GBMParametersV3();
// gbmParms.trainingFrame = H2oApi.stringToFrameKey("arrhythmia.hex");
gbmParms.trainingFrame = H2oApi.stringToFrameKey("train");
gbmParms.validationFrame = H2oApi.stringToFrameKey("test");
ColSpecifierV3 responseColumn = new ColSpecifierV3();
responseColumn.columnName = "C1";
gbmParms.responseColumn = responseColumn;
System.out.println("About to train GBM. . .");
GBMV3 gbmBody = h2o.train_gbm(gbmParms);
System.out.println("gbmBody: " + gbmBody);
// STEP 6: poll for completion
job = h2o.waitForJobCompletion(gbmBody.job.key);
System.out.println("GBM build done.");
// STEP 7: fetch the model
ModelKeyV3 model_key = (ModelKeyV3)job.dest;
ModelsV3 models = h2o.model(model_key);
System.out.println("models: " + models);
GBMModelV3 model = (GBMModelV3)models.models[0];
System.out.println("new GBM model: " + model);
// System.out.println("new GBM model: " + models.models[0]);
assert model.getClass() == GBMModelV3.class;
assert model.output.getClass() == GBMModelOutputV3.class;
assert model.parameters.getClass() == GBMParametersV3.class;
// STEP 9 (optional): export model as binary
ModelExportV3 modelExport = new ModelExportV3();
modelExport.modelId = model_key;
File binaryModelFile = File.createTempFile("model", ".h2o");
modelExport.dir = File.createTempFile("model", ".h2o").getPath();
binaryModelFile.deleteOnExit();
// STEP 8: predict!
ModelMetricsListSchemaV3 predict_params = new ModelMetricsListSchemaV3();
predict_params.model = model_key;
predict_params.frame = gbmParms.trainingFrame;
predict_params.predictionsFrame = H2oApi.stringToFrameKey("predictions");
ModelMetricsListSchemaV3 predictions = h2o.predict(predict_params);
System.out.println("predictions: " + predictions);
// STEP 99: end the session
h2o.endSession();
}
public static void gbmExampleFlow() throws IOException {
gbmExampleFlow(null);
}
public static void main (String[] args) throws IOException {
gbmExampleFlow();
}
}