All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.genmodel.tools.PredictCsv Maven / Gradle / Ivy

There is a newer version: 3.46.0.5
Show newest version
package hex.genmodel.tools;

import au.com.bytecode.opencsv.CSVReader;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.glrm.GlrmMojoModel;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.*;
import hex.genmodel.utils.ArrayUtils;

import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Simple driver program for reading a CSV file and making predictions.  Added support for separators that are
 * not commas. User needs to add the --separator separator_string to the input call.  Do not escape
 * the special Java characters, I will do it for you.
 *
 * This driver program is used as a test harness by several tests in the testdir_javapredict directory.
 * 

* See the top-of-tree master version of this file here on github. */ public class PredictCsv { private final String inputCSVFileName; private final String outputCSVFileName; private final boolean useDecimalOutput; private final char separator; private final boolean setInvNumNA; private final boolean getTreePath; private final boolean predictContributions; private final boolean predictCalibrated; private final boolean returnGLRMReconstruct; private final int glrmIterNumber; private final boolean outputHeader; // Model instance private EasyPredictModelWrapper modelWrapper; private PredictCsv( String inputCSVFileName, String outputCSVFileName, boolean useDecimalOutput, char separator, boolean setInvNumNA, boolean getTreePath, boolean predictContributions, boolean predictCalibrated, boolean returnGLRMReconstruct, int glrmIterNumber, boolean outputHeader) { this.inputCSVFileName = inputCSVFileName; this.outputCSVFileName = outputCSVFileName; this.useDecimalOutput = useDecimalOutput; this.separator = separator; this.setInvNumNA = setInvNumNA; this.getTreePath = getTreePath; this.predictContributions = predictContributions; this.predictCalibrated = predictCalibrated; this.returnGLRMReconstruct = returnGLRMReconstruct; this.glrmIterNumber = glrmIterNumber; this.outputHeader = outputHeader; } public static void main(String[] args) { PredictCsvCollection predictors = buildPredictCsv(args); PredictCsv main = predictors.main; // Run the main program try { main.run(); } catch (Exception e) { System.out.println("Predict error: " + e.getMessage()); System.out.println(); e.printStackTrace(); System.exit(1); } if (predictors.concurrent.length > 0) { try { ExecutorService executor = Executors.newFixedThreadPool(predictors.concurrent.length); List callables = new ArrayList<>(predictors.concurrent.length); for (int i = 0; i < predictors.concurrent.length; i++) { callables.add(new PredictCsvCallable(predictors.concurrent[i])); } int numExceptions = 0; for (Future future : executor.invokeAll(callables)) { Exception e = future.get(); if (e != null) { e.printStackTrace(); numExceptions++; } } if (numExceptions > 0) { throw new Exception("Some predictors failed (#failed=" + numExceptions + ")"); } } catch (Exception e) { System.out.println("Concurrent predict error: " + e.getMessage()); System.out.println(); e.printStackTrace(); System.exit(1); } } // Predictions were successfully generated. System.exit(0); } // Only meant to be used in tests public static PredictCsv make(String[] args, GenModel model) { final PredictCsvCollection predictorCollection = buildPredictCsv(args); if (predictorCollection.concurrent.length != 0) { throw new UnsupportedOperationException("Predicting with concurrent predictors is not supported in programmatic mode."); } final PredictCsv predictor = predictorCollection.main; if (model != null) { try { predictor.setModelWrapper(model); } catch (IOException e) { throw new RuntimeException(e); } } return predictor; } private static RowData formatDataRow(String[] splitLine, String[] inputColumnNames) { // Assemble the input values for the row. RowData row = new RowData(); int maxI = Math.min(inputColumnNames.length, splitLine.length); for (int i = 0; i < maxI; i++) { String columnName = inputColumnNames[i]; String cellData = splitLine[i]; switch (cellData) { case "": case "NA": case "N/A": case "-": continue; default: row.put(columnName, cellData); } } return row; } private String myDoubleToString(double d) { if (Double.isNaN(d)) { return "NA"; } return useDecimalOutput? Double.toString(d) : Double.toHexString(d); } private void writeTreePathNames(BufferedWriter output) throws Exception { String[] columnNames = ((SharedTreeMojoModel) modelWrapper.m).getDecisionPathNames(); writeColumnNames(output, columnNames); } private void writeCalibratedOutputNames(BufferedWriter output) throws Exception { String[] outputNames = modelWrapper.m.getOutputNames(); String[] calibOutputNames = new String[outputNames.length - 1]; for (int i = 0; i < calibOutputNames.length; i++) { calibOutputNames[i] = "cal_" + outputNames[i + 1]; } writeColumnNames(output, ArrayUtils.append(outputNames, calibOutputNames)); } private void writeContributionNames(BufferedWriter output) throws Exception { writeColumnNames(output, modelWrapper.getContributionNames()); } private void writeColumnNames(BufferedWriter output, String[] columnNames) throws Exception { int lastIndex = columnNames.length-1; for (int index = 0; index < lastIndex; index++) { output.write(columnNames[index]); output.write(","); } output.write(columnNames[lastIndex]); } public void run() throws Exception { ModelCategory category = modelWrapper.getModelCategory(); CSVReader reader = new CSVReader(new FileReader(inputCSVFileName), separator); BufferedWriter output = new BufferedWriter(new FileWriter(outputCSVFileName)); // Emit outputCSV column names. if (outputHeader) { switch (category) { case Binomial: case Multinomial: case Regression: if (getTreePath) { writeTreePathNames(output); } else if (predictContributions) { writeContributionNames(output); } else if (predictCalibrated) { writeCalibratedOutputNames(output); } else writeHeader(modelWrapper.m.getOutputNames(), output); break; case DimReduction: // will write factor or the predicted value depending on what the user wants if (returnGLRMReconstruct) { int datawidth; String[] colnames = this.modelWrapper.m.getNames(); datawidth = ((GlrmMojoModel) modelWrapper.m)._permutation.length; int lastData = datawidth - 1; for (int index = 0; index < datawidth; index++) { // add the numerical column names output.write("reconstr_" + colnames[index]); if (index < lastData) output.write(','); } } else writeHeader(modelWrapper.m.getOutputNames(), output); break; default: writeHeader(modelWrapper.m.getOutputNames(), output); } output.write("\n"); } // Loop over inputCSV one row at a time. // int lineNum=1; // count number of lines of input dataset file parsed try { String[] inputColumnNames; String[] splitLine; //Reader in the column names here. if ((splitLine = reader.readNext()) != null) { inputColumnNames = splitLine; checkMissingColumns(inputColumnNames); } else // file empty, throw an error throw new Exception("Input dataset file is empty!"); while ((splitLine = reader.readNext()) != null) { // Parse the CSV line. Don't handle quoted commas. This isn't a parser test. RowData row = formatDataRow(splitLine, inputColumnNames); // Do the prediction. // Emit the result to the output file. String offsetColumn = modelWrapper.m.getOffsetName(); double offset = offsetColumn==null ? 0 : Double.parseDouble((String) row.get(offsetColumn)); switch (category) { case AutoEncoder: { // write the expanded predictions out AutoEncoderModelPrediction p = modelWrapper.predictAutoEncoder(row); for (int i=0; i < p.reconstructed.length; i++) { output.write(myDoubleToString(p.reconstructed[i])); if (i < p.reconstructed.length-1) output.write(','); } break; } case Binomial: { BinomialModelPrediction p = modelWrapper.predictBinomial(row, offset); if (getTreePath) { writeTreePaths(p.leafNodeAssignments, output); } else if (predictContributions) { writeContributions(p.contributions, output); } else { output.write(p.label); output.write(","); for (int i = 0; i < p.classProbabilities.length; i++) { if (i > 0) { output.write(","); } output.write(myDoubleToString(p.classProbabilities[i])); } if (predictCalibrated) { for (int i = 0; i < p.classProbabilities.length; i++) { output.write(","); double calibProb = p.calibratedClassProbabilities != null ? p.calibratedClassProbabilities[i] : Double.NaN; output.write(myDoubleToString(calibProb)); } } } break; } case Multinomial: { MultinomialModelPrediction p = modelWrapper.predictMultinomial(row); if (getTreePath) { writeTreePaths(p.leafNodeAssignments, output); } else { output.write(p.label); output.write(","); for (int i = 0; i < p.classProbabilities.length; i++) { if (i > 0) { output.write(","); } output.write(myDoubleToString(p.classProbabilities[i])); } } break; } case Ordinal: { OrdinalModelPrediction p = modelWrapper.predictOrdinal(row, offset); output.write(p.label); output.write(","); for (int i = 0; i < p.classProbabilities.length; i++) { if (i > 0) { output.write(","); } output.write(myDoubleToString(p.classProbabilities[i])); } break; } case Clustering: { ClusteringModelPrediction p = modelWrapper.predictClustering(row); output.write(myDoubleToString(p.cluster)); break; } case Regression: { RegressionModelPrediction p = modelWrapper.predictRegression(row, offset); if (getTreePath) { writeTreePaths(p.leafNodeAssignments, output); } else if (predictContributions) { writeContributions(p.contributions, output); } else output.write(myDoubleToString(p.value)); break; } case CoxPH: { CoxPHModelPrediction p = modelWrapper.predictCoxPH(row, offset); output.write(myDoubleToString(p.value)); break; } case DimReduction: { DimReductionModelPrediction p = modelWrapper.predictDimReduction(row); double[] out; if (returnGLRMReconstruct) { out = p.reconstructed; // reconstructed A } else { out = p.dimensions; // x factors } int lastOne = out.length-1; for (int i=0; i < out.length; i++) { output.write(myDoubleToString(out[i])); if (i < lastOne) output.write(','); } break; } case AnomalyDetection: { AnomalyDetectionPrediction p = modelWrapper.predictAnomalyDetection(row); double[] rawPreds = p.toPreds(); for (int i = 0; i < rawPreds.length - 1; i++) { output.write(myDoubleToString(rawPreds[i])); output.write(','); } output.write(myDoubleToString(rawPreds[rawPreds.length - 1])); break; } default: throw new Exception("Unknown model category " + category); } output.write("\n"); lineNum++; } } catch (Exception e) { throw new Exception("Prediction failed on line " + lineNum, e); } finally { // Clean up. output.close(); reader.close(); } } private void writeHeader(String[] colNames, BufferedWriter output) throws Exception { output.write(colNames[0]); for (int i = 1; i < colNames.length; i++) { output.write(","); output.write(colNames[i]); } } private void writeTreePaths(String[] treePaths, BufferedWriter output) throws Exception { int len = treePaths.length-1; for (int index=0; index 0) { output.write(","); } output.write(myDoubleToString(contributions[i])); } } private void setModelWrapper(GenModel genModel) throws IOException { EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config() .setModel(genModel) .setConvertUnknownCategoricalLevelsToNa(true) .setConvertInvalidNumbersToNa(setInvNumNA); if (getTreePath) config.setEnableLeafAssignment(true); if (predictContributions) config.setEnableContributions(true); if (returnGLRMReconstruct) config.setEnableGLRMReconstrut(true); if (glrmIterNumber > 0) // set GLRM Mojo iteration number config.setGLRMIterNumber(glrmIterNumber); setModelWrapper(new EasyPredictModelWrapper(config)); } private void setModelWrapper(EasyPredictModelWrapper modelWrapper) { this.modelWrapper = modelWrapper; } private static void usage() { System.out.println(); System.out.println("Usage: java [...java args...] hex.genmodel.tools.PredictCsv --mojo mojoName"); System.out.println(" --pojo pojoName --input inputFile --output outputFile --separator sepStr --decimal --setConvertInvalidNum"); System.out.println(); System.out.println(" --mojo Name of the zip file containing model's MOJO."); System.out.println(" --pojo Name of the java class containing the model's POJO. Either this "); System.out.println(" parameter or --model must be specified."); System.out.println(" --input text file containing the test data set to score."); System.out.println(" --output Name of the output CSV file with computed predictions."); System.out.println(" --separator Separator to be used in input file containing test data set."); System.out.println(" --decimal Use decimal numbers in the output (default is to use hexademical)."); System.out.println(" --setConvertInvalidNum Will call .setConvertInvalidNumbersToNa(true) when loading models."); System.out.println(" --leafNodeAssignment will show the leaf node assignment for tree based models instead of" + " prediction results"); System.out.println(" --predictContributions will output prediction contributions (Shapley values) for tree based" + " models instead of regular model predictions"); System.out.println(" --glrmReconstruct will return the reconstructed dataset for GLRM mojo instead of X factor derived from the dataset."); System.out.println(" --glrmIterNumber integer indicating number of iterations to go through when constructing X factor derived from the dataset."); System.out.println(" --testConcurrent integer (for testing) number of concurrent threads that will be making predictions."); System.out.println(); System.exit(1); } private void checkMissingColumns(final String[] parsedColumnNamesArr) { final String[] modelColumnNames = modelWrapper.m._names; final Set parsedColumnNames = new HashSet<>(parsedColumnNamesArr.length); Collections.addAll(parsedColumnNames, parsedColumnNamesArr); List missingColumns = new ArrayList<>(); for (String columnName : modelColumnNames) { if (!parsedColumnNames.contains(columnName) && !columnName.equals(modelWrapper.m._responseColumn)) { missingColumns.add(columnName); } else { parsedColumnNames.remove(columnName); } } if(missingColumns.size() > 0){ final StringBuilder stringBuilder = new StringBuilder("There were "); stringBuilder.append(missingColumns.size()); stringBuilder.append(" missing columns found in the input data set: {"); for (int i = 0; i < missingColumns.size(); i++) { stringBuilder.append(missingColumns.get(i)); if(i != missingColumns.size() - 1) stringBuilder.append(","); } stringBuilder.append('}'); System.out.println(stringBuilder); } if(parsedColumnNames.size() > 0){ final StringBuilder stringBuilder = new StringBuilder("Detected "); stringBuilder.append(parsedColumnNames.size()); stringBuilder.append(" unused columns in the input data set: {"); final Iterator iterator = parsedColumnNames.iterator(); while (iterator.hasNext()){ stringBuilder.append(iterator.next()); if(iterator.hasNext()) stringBuilder.append(","); } stringBuilder.append('}'); System.out.println(stringBuilder); } } private static class PredictCsvCollection { private final PredictCsv main; private final PredictCsv[] concurrent; private PredictCsvCollection(PredictCsv main, PredictCsv[] concurrent) { this.main = main; this.concurrent = concurrent; } } private static PredictCsvCollection buildPredictCsv(String[] args) { try { PredictCsvBuilder builder = new PredictCsvBuilder(); builder.parseArgs(args); final GenModel genModel; switch (builder.loadType) { case -1: genModel = null; break; case 0: genModel = loadPojo(builder.pojoMojoModelNames); break; case 1: genModel = loadMojo(builder.pojoMojoModelNames); break; case 2: genModel = loadModel(builder.pojoMojoModelNames); break; default: throw new IllegalStateException("Unexpected value of loadType = " + builder.loadType); } PredictCsv mainPredictCsv = builder.newPredictCsv(); if (genModel != null) { mainPredictCsv.setModelWrapper(genModel); } PredictCsv[] concurrentPredictCsvs = new PredictCsv[builder.testConcurrent]; for (int id = 0; id < concurrentPredictCsvs.length; id++) { PredictCsv concurrentPredictCsv = builder.newConcurrentPredictCsv(id); concurrentPredictCsv.setModelWrapper(mainPredictCsv.modelWrapper); // re-use both the wrapper and the MOJO concurrentPredictCsvs[id] = concurrentPredictCsv; } return new PredictCsvCollection(mainPredictCsv, concurrentPredictCsvs); } catch (Exception e) { e.printStackTrace(); usage(); throw new IllegalStateException("Should not be reachable"); } } private static GenModel loadPojo(String className) throws Exception { return (GenModel) Class.forName(className).newInstance(); } private static GenModel loadMojo(String modelName) throws IOException { return MojoModel.load(modelName); } private static GenModel loadModel(String modelName) throws Exception { try { return loadMojo(modelName); } catch (IOException e) { return loadPojo(modelName); // may throw an exception too } } private static class PredictCsvBuilder { // For PredictCsv private String inputCSVFileName; private String outputCSVFileName; private boolean useDecimalOutput; private char separator = ','; // separator used to delimite input datasets private boolean setInvNumNA; // enable .setConvertInvalidNumbersToNa(true) private boolean getTreePath; // enable tree models to obtain the leaf-assignment information private boolean predictContributions; // enable tree models to predict contributions instead of regular predictions private boolean predictCalibrated; // output also calibrated probabilities private boolean returnGLRMReconstruct; // for GLRM, return x factor by default unless set this to true private int glrmIterNumber = -1; // for GLRM, default to 100. private boolean outputHeader = true; // should we write-out header to output files? // For Model Loading private int loadType = 0; // 0: load pojo, 1: load mojo, 2: load model, -1: special value when PredictCsv is used embedded and instance of Model is passed directly private String pojoMojoModelNames = ""; // store Pojo/Mojo/Model names private int testConcurrent = 0; private PredictCsv newPredictCsv() { return new PredictCsv(inputCSVFileName, outputCSVFileName, useDecimalOutput, separator, setInvNumNA, getTreePath, predictContributions, predictCalibrated, returnGLRMReconstruct, glrmIterNumber, outputHeader); } private PredictCsv newConcurrentPredictCsv(int id) { return new PredictCsv(inputCSVFileName, outputCSVFileName + "." + id, useDecimalOutput, separator, setInvNumNA, getTreePath, predictContributions, predictCalibrated, returnGLRMReconstruct, glrmIterNumber, outputHeader); } private void parseArgs(String[] args) { for (int i = 0; i < args.length; i++) { String s = args[i]; if (s.equals("--header")) continue; if (s.equals("--decimal")) useDecimalOutput = true; else if (s.equals("--glrmReconstruct")) returnGLRMReconstruct = true; else if (s.equals("--setConvertInvalidNum")) setInvNumNA = true; else if (s.equals("--leafNodeAssignment")) getTreePath = true; else if (s.equals("--predictContributions")) { predictContributions = true; } else if (s.equals("--predictCalibrated")) { predictCalibrated = true; } else if (s.equals("--embedded")) { loadType = -1; } else { i++; if (i >= args.length) usage(); String sarg = args[i]; switch (s) { case "--model": pojoMojoModelNames = sarg; loadType = 2; break; case "--mojo": pojoMojoModelNames = sarg; loadType = 1; break; case "--pojo": pojoMojoModelNames = sarg; loadType = 0; break; case "--input": inputCSVFileName = sarg; break; case "--output": outputCSVFileName = sarg; break; case "--separator": separator = sarg.charAt(sarg.length() - 1); break; case "--glrmIterNumber": glrmIterNumber = Integer.parseInt(sarg); break; case "--testConcurrent": testConcurrent = Integer.parseInt(sarg); break; case "--outputHeader": outputHeader = Boolean.parseBoolean(sarg); break; default: System.out.println("ERROR: Unknown command line argument: " + s); usage(); } } } } } private static class PredictCsvCallable implements Callable { private final PredictCsv predictCsv; private PredictCsvCallable(PredictCsv predictCsv) { this.predictCsv = predictCsv; } @Override public Exception call() throws Exception { try { predictCsv.run(); } catch (Exception e) { return e; } return null; } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy