All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.cli.JMLLCLI Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.cli;

import com.expleague.ml.cli.modes.impl.*;
import com.expleague.ml.cli.modes.AbstractMode;
import org.apache.commons.cli.*;

import java.io.IOException;
import java.io.PrintWriter;

/**
 * Command-line entry point: parses the options declared below with Apache Commons CLI,
 * takes the first positional argument as the mode name and delegates execution to the
 * matching {@link AbstractMode} implementation.
 *
 * <p>The {@code *_OPTION} constants are the short option names; they are public so that
 * other classes can look the parsed values up on the {@link CommandLine}.
 *
 * User: qdeee
 * Date: 03.09.14
 */
@SuppressWarnings("UnusedDeclaration,AccessStaticViaInstance")
public class JMLLCLI {
  // Default values; DEFAULT_TARGET and DEFAULT_OPTIMIZATION_SCHEME are also shown in the option descriptions below.
  public static final String DEFAULT_TARGET = "L2";
  public static final String DEFAULT_MODELS_COMPARISION_CV_OUTPUT_FILE = "results";
  public static final String DEFAULT_OPTIMIZATION_SCHEME = "GradientBoosting(local=SatL2, weak=GreedyObliviousTree(depth=6), step=0.02, iterations=1000)";

  // Input data options.
  public static final String LEARN_OPTION = "f";
  public static final String JSON_FORMAT = "j";
  public static final String CD_FILE = "cd";
  public static final String DELIMITER = "delimiter";
  public static final String HAS_HEADER = "hasHeader";
  public static final String TEST_OPTION = "t";
  public static final String CROSS_VALIDATION_OPTION = "X";
  public static final String INTERPRET_MODE_OPTION = "I";

  // Optimization target and evaluation metrics.
  public static final String TARGET_OPTION = "T";
  public static final String METRICS_OPTION = "M";

  // Binarization grid and optimization scheme options.
  public static final String BIN_FOLDS_COUNT_OPTION = "x";
  public static final String GRID_OPTION = "g";
  public static final String OPTIMIZATION_OPTION = "O";
  public static final String LOAD_OPTIMIZATION_SCHEMES_FROM_FILE_OPTION = "s";
  public static final String CROSS_VALIDATION_RESULT_OPTION = "R";

  // Output and run-control options.
  public static final String VERBOSE_OPTION = "v";
  public static final String PRINT_PERIOD = "i";
  public static final String FAST_OPTION = "fast";
  public static final String SKIP_FINAL_EVAL_OPTION = "fastfinal";
  public static final String HIST_OPTION = "h";
  public static final String OUTPUT_OPTION = "o";
  public static final String WRITE_BIN_FORMULA = "mxbin";

  public static final String MODEL_OPTION = "m";
  public static final String COUNTER_OPTION = "n";

  public static final String RANGES_OPTION = "r";
  public static final String RANDOM_SEED_OPTION = "seed";

  /** Width used for the usage line when the COLUMNS environment variable is absent or unusable. */
  private static final int DEFAULT_USAGE_WIDTH = 80;

  /** All options understood by the tool; populated once in the static initializer. */
  private static final Options options = new Options();
  static {
    options.addOption(OptionBuilder.withLongOpt("learn").withDescription("features.txt format file used as learn").hasArg().create(LEARN_OPTION));
    options.addOption(OptionBuilder.withLongOpt("cd").withDescription("features.cd file (catboost pool format)").hasArg().create(CD_FILE));
    options.addOption(OptionBuilder.withLongOpt("json-format").withDescription("alternative format for features.txt").hasArg(false).create(JSON_FORMAT));
    options.addOption(OptionBuilder.withLongOpt("test").withDescription("test part in features.txt format").hasArg().create(TEST_OPTION));
    options.addOption(OptionBuilder.withLongOpt("cross-validation").withDescription("k folds CV").hasArg().create(CROSS_VALIDATION_OPTION));

    // supported in catboost pools only
    // "\\t" keeps the backslash visible in the help text; a bare "\t" would emit a real (invisible) tab.
    options.addOption(OptionBuilder.withLongOpt("delimiter").withDescription("Delimiter on file, default \\t").hasArg().create(DELIMITER));
    options.addOption(OptionBuilder.withLongOpt("has-header").withDescription("Pool with header flag").hasArg(false).create(HAS_HEADER));


    options.addOption(OptionBuilder.withLongOpt("target").withDescription("target function to optimize format Global/Weak/Cursor (" + DEFAULT_TARGET + ")").hasArg().create(TARGET_OPTION));
    options.addOption(OptionBuilder.withLongOpt("metrics").withDescription("metrics to test, by default contains global optimization target").hasArgs().create(METRICS_OPTION));

    options.addOption(OptionBuilder.withLongOpt("bin-folds-count").withDescription("binarization precision: how many binary features inferred from real one").hasArg().create(BIN_FOLDS_COUNT_OPTION));
    options.addOption(OptionBuilder.withLongOpt("grid").withDescription("file with already precomputed grid").hasArg().create(GRID_OPTION));
    options.addOption(OptionBuilder.withLongOpt("optimization").withDescription("optimization scheme: Strong/Weak or just Strong (" + DEFAULT_OPTIMIZATION_SCHEME + ")").hasArg().create(OPTIMIZATION_OPTION));
    options.addOption(OptionBuilder.withLongOpt("file-based-optimization").withDescription("optimization schemes in file: Strong/Weak or just Strong").hasArg().create(LOAD_OPTIMIZATION_SCHEMES_FROM_FILE_OPTION));
    options.addOption(OptionBuilder.withLongOpt("file-with-results").withDescription("result file for cross validation").hasArg().create(CROSS_VALIDATION_RESULT_OPTION));

    options.addOption(OptionBuilder.withLongOpt("out").withDescription("output file name").hasArg().create(OUTPUT_OPTION));
    options.addOption(OptionBuilder.withLongOpt("matrixnetbin").withDescription("write model in matrix-net bin format").hasArg(false).create(WRITE_BIN_FORMULA));

    options.addOption(OptionBuilder.withLongOpt("verbose").withDescription("verbose output").create(VERBOSE_OPTION));
    options.addOption(OptionBuilder.withLongOpt("print-period").withDescription("number of iterations to evaluate and print scores").hasArg().create(PRINT_PERIOD));
    options.addOption(OptionBuilder.withLongOpt("fast-run").withDescription("fast run without model evaluation").create(FAST_OPTION));
    options.addOption(OptionBuilder.withLongOpt("skip-final-eval").withDescription("skip model evaluation on last step (faster)").create(SKIP_FINAL_EVAL_OPTION));
    options.addOption(OptionBuilder.withLongOpt("histogram").withDescription("histogram for dynamic grid").hasArg(false).create(HIST_OPTION));

    options.addOption(OptionBuilder.withLongOpt("model").withDescription("model file").hasArg().create(MODEL_OPTION));

    options.addOption(OptionBuilder.withLongOpt("ranges").withDescription("parameters ranges").hasArg().create(RANGES_OPTION));
    options.addOption(OptionBuilder.withLongOpt("seed").withDescription("random seed").hasArg().create(RANDOM_SEED_OPTION));
    options.addOption(OptionBuilder.withLongOpt("view").withDescription("Comma separated interpret views. Possible values are: histogram for by feature histograms, linear for list of linear components of the ensemble, splits(k) for top k influencing splits").hasArg().create(INTERPRET_MODE_OPTION));
    options.addOption(OptionBuilder.withLongOpt("count").withDescription("Counter option for number of iterations or something like this").hasArg().create(COUNTER_OPTION));
  }

  /**
   * Parses the command line and runs the requested mode.
   *
   * <p>The first positional (non-option) argument is the mode name. If option parsing fails,
   * the parser message and a usage line are written to stderr and the full help is printed
   * via {@link HelpFormatter#printHelp(String, Options)}. Any other exception (including a
   * missing or unknown mode) is reported on stderr with its stack trace; it is not rethrown.
   *
   * @param args raw command-line arguments
   * @throws IOException declared for compatibility with existing callers
   */
  public static void main(final String[] args) throws IOException {
    final CommandLineParser parser = new GnuParser();
    try {
      final CommandLine command = parser.parse(options, args);
      final String[] positional = command.getArgs();
      if (positional.length == 0) {
        throw new RuntimeException("Please provide mode to run");
      }

      final AbstractMode mode = getMode(positional[0]);
      mode.run(command);

    } catch (ParseException e) {
      System.err.println(e.getLocalizedMessage());
      final HelpFormatter formatter = new HelpFormatter();
      final PrintWriter err = new PrintWriter(System.err);
      formatter.printUsage(err, usageWidth(), "jmll", options);
      // PrintWriter over System.err is buffered and not auto-flushing: without this the usage line is lost.
      err.flush();
      formatter.printHelp("JMLLCLI", options);

    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getLocalizedMessage());
    }
  }

  /**
   * Determines the width of the usage line from the COLUMNS environment variable.
   *
   * @return the parsed positive value of COLUMNS, or {@link #DEFAULT_USAGE_WIDTH} when the
   *         variable is unset, not a number, or not positive
   */
  private static int usageWidth() {
    final String columns = System.getenv("COLUMNS");
    if (columns == null) {
      return DEFAULT_USAGE_WIDTH;
    }
    try {
      final int width = Integer.parseInt(columns.trim());
      return width > 0 ? width : DEFAULT_USAGE_WIDTH;
    } catch (NumberFormatException ignored) {
      // A malformed COLUMNS value must not prevent the help text from being shown.
      return DEFAULT_USAGE_WIDTH;
    }
  }

  /**
   * Creates the mode implementation registered under the given name.
   *
   * @param mode mode name as typed on the command line (matched exactly, case-sensitive)
   * @return a new instance of the corresponding mode
   * @throws RuntimeException if the name does not match any known mode
   */
  private static AbstractMode getMode(final String mode) {
    switch (mode) {
      case "fit":                       return new Fit();
      case "apply":                     return new Apply();
      case "convert-pool":              return new ConvertPool();
      case "convert-pool-json2classic": return new ConvertPoolJson2Classsic();
      // NOTE(review): the mode is named "libfm" but maps to ConvertPoolLibSvm - confirm this is intended.
      case "convert-pool-libfm":        return new ConvertPoolLibSvm();
      case "validate-model":            return new ValidateModel();
      case "validate-pool":             return new ValidatePool();
      case "split-json-pool":           return new SplitJsonPool();
      case "print-pool-info":           return new PrintPoolInfo();
      case "grid-search":               return new GridSearch();
      case "cross-validation":          return new CrossValidation();
      case "interpret":                 return new InterpretModel();
      case "eval-model":                return new EvaluateModel();
      case "dict":                      return new CreateDictionary();

      default:
        throw new RuntimeException("Mode " + mode + " is not recognized");
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy