 /*
  * Licensed to the Apache Software Foundation (ASF) under one or more
  * contributor license agreements.  See the NOTICE file distributed with
  * this work for additional information regarding copyright ownership.
  * The ASF licenses this file to You under the Apache License, Version 2.0
  * (the "License"); you may not use this file except in compliance with
  * the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 
 package org.apache.mahout.classifier.sgd;
 
 import com.google.common.collect.Lists;
 import com.google.common.io.Resources;
 import org.apache.commons.cli2.CommandLine;
 import org.apache.commons.cli2.Group;
 import org.apache.commons.cli2.Option;
 import org.apache.commons.cli2.builder.ArgumentBuilder;
 import org.apache.commons.cli2.builder.DefaultOptionBuilder;
 import org.apache.commons.cli2.builder.GroupBuilder;
 import org.apache.commons.cli2.commandline.Parser;
 import org.apache.commons.cli2.util.HelpFormatter;
 import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 
 import java.io.BufferedReader;
 import java.io.FileReader;
 import java.io.FileWriter;
 import java.io.IOException;
 import java.io.InputStreamReader;
 import java.io.OutputStreamWriter;
 import java.net.URL;
 import java.util.List;
 
 
 /**
  * Train a logistic regression for the examples from Chapter 13 of Mahout in Action
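  *
  * A typical invocation looks something like this (file names and option values
  * are illustrative only, following the book's donut example):
  *
  *   TrainLogistic --input donut.csv --output ./model \
  *     --target color --categories 2 --predictors x y --types numeric \
  *     --features 20 --passes 100 --rate 50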
  */
 public final class TrainLogistic {
 
   private static String inputFile;
   private static String outputFile;
   private static LogisticModelParameters lmp;
 
   private static int passes;
   private static boolean scores;
   private static OnlineLogisticRegression model;
 
   private TrainLogistic() {
   }
 
   public static void main(String[] args) throws IOException {
     if (parseArgs(args)) {
       double logPEstimate = 0;
       int samples = 0;
 
       CsvRecordFactory csv = lmp.getCsvRecordFactory();
       OnlineLogisticRegression lr = lmp.createRegression();
       for (int pass = 0; pass < passes; pass++) {
         BufferedReader in = InputOpener.open(inputFile);
 
         // read variable names
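         // the header line supplies the column names; the record factory uses them
         // to find the target column and the declared predictor variables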
         csv.firstLine(in.readLine());
 
         String line = in.readLine();
         while (line != null) {
           // for each new line, get target and predictors
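           // predictor values are hash-encoded into a sparse vector with
           // lmp.getNumFeatures() slots (the --features option)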
           Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
           int targetValue = csv.processLine(line, input);
 
           // check performance while this is still news
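           // the running estimate is a plain average over the first 20 samples,
           // then an exponential moving average with decay 0.95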
           double logP = lr.logLikelihood(targetValue, input);
           if (!Double.isInfinite(logP)) {
             if (samples < 20) {
               logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
             } else {
               logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
             }
             samples++;
           }
           double p = lr.classifyScalar(input);
           if (scores) {
             System.out.printf("%10d %2d %10.2f %2.4f %10.4f %10.4f\n", samples, targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
           }
 
           // now update model
           lr.train(targetValue, input);
 
           line = in.readLine();
         }
         in.close();
       }
 
       OutputStreamWriter modelOutput = new FileWriter(outputFile);
       try {
         lmp.saveTo(modelOutput);
       } finally {
         modelOutput.close();
       }
       
       System.out.printf("%d\n", lmp.getNumFeatures());
       System.out.printf("%s ~ ", lmp.getTargetVariable());
       String sep = "";
       for (String v : csv.getPredictors()) {
         double weight = predictorWeight(lr, 0, csv, v);
         if (weight != 0) {
           System.out.printf("%s%.3f*%s", sep, weight, v);
           sep = " + ";
         }
       }
       System.out.printf("\n");
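       // keep references so getModel() and getParameters() can expose the trained state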
       model = lr;
       for (int row = 0; row < lr.getBeta().numRows(); row++) {
         for (String key : csv.getTraceDictionary().keySet()) {
           double weight = predictorWeight(lr, row, csv, key);
           if (weight != 0) {
             System.out.printf("%20s %.5f\n", key, weight);
           }
         }
         for (int column = 0; column < lr.getBeta().numCols(); column++) {
           System.out.printf("%15.9f ", lr.getBeta().get(row, column));
         }
         System.out.println();
       }
     }
   }
 
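   /**
    * Sums the coefficients in one row of the model's beta matrix over every hashed
    * column that the trace dictionary attributes to the given predictor.
    */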
   private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
     double weight = 0;
     for (Integer column : csv.getTraceDictionary().get(predictor)) {
       weight += lr.getBeta().get(row, column);
     }
     return weight;
   }
 
   private static boolean parseArgs(String[] args) {
     DefaultOptionBuilder builder = new DefaultOptionBuilder();
 
     Option help = builder.withLongName("help").withDescription("print this list").create();
 
     Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
     Option scores = builder.withLongName("scores").withDescription("output score diagnostics during training").create();
 
     ArgumentBuilder argumentBuilder = new ArgumentBuilder();
     Option inputFile = builder.withLongName("input")
             .withRequired(true)
             .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
             .withDescription("where to get training data")
             .create();
 
     Option outputFile = builder.withLongName("output")
             .withRequired(true)
             .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
            .withDescription("where to write the trained model")
             .create();
 
     Option predictors = builder.withLongName("predictors")
             .withRequired(true)
             .withArgument(argumentBuilder.withName("p").create())
             .withDescription("a list of predictor variables")
             .create();
 
     Option types = builder.withLongName("types")
             .withRequired(true)
             .withArgument(argumentBuilder.withName("t").create())
             .withDescription("a list of predictor variable types (numeric, word, or text)")
             .create();
 
     Option target = builder.withLongName("target")
             .withRequired(true)
             .withArgument(argumentBuilder.withName("target").withMaximum(1).create())
             .withDescription("the name of the target variable")
             .create();
 
     Option features = builder.withLongName("features")
             .withArgument(
                     argumentBuilder.withName("numFeatures")
                             .withDefault("1000")
                             .withMaximum(1).create())
             .withDescription("the number of internal hashed features to use")
             .create();
 
     Option passes = builder.withLongName("passes")
             .withArgument(
                     argumentBuilder.withName("passes")
                             .withDefault("2")
                             .withMaximum(1).create())
             .withDescription("the number of times to pass over the input data")
             .create();
 
     Option lambda = builder.withLongName("lambda")
             .withArgument(argumentBuilder.withName("lambda").withDefault("1e-4").withMaximum(1).create())
             .withDescription("the amount of coefficient decay to use")
             .create();
 
     Option rate = builder.withLongName("rate")
             .withArgument(argumentBuilder.withName("learningRate").withDefault("1e-3").withMaximum(1).create())
             .withDescription("the learning rate")
             .create();
 
     Option noBias = builder.withLongName("noBias")
             .withDescription("don't include a bias term")
             .create();
 
     Option targetCategories = builder.withLongName("categories")
             .withRequired(true)
             .withArgument(argumentBuilder.withName("number").withMaximum(1).create())
             .withDescription("the number of target categories to be considered")
             .create();
 
     Group normalArgs = new GroupBuilder()
             .withOption(help)
             .withOption(quiet)
             .withOption(inputFile)
             .withOption(outputFile)
             .withOption(target)
             .withOption(targetCategories)
             .withOption(predictors)
             .withOption(types)
             .withOption(passes)
             .withOption(lambda)
             .withOption(rate)
             .withOption(noBias)
             .withOption(features)
             .create();
 
     Parser parser = new Parser();
     parser.setHelpOption(help);
     parser.setHelpTrigger("--help");
     parser.setGroup(normalArgs);
     parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
     CommandLine cmdLine = parser.parseAndHelp(args);
 
     if (cmdLine == null) {
       return false;
     }
 
     TrainLogistic.inputFile = getStringArgument(cmdLine, inputFile);
     TrainLogistic.outputFile = getStringArgument(cmdLine, outputFile);
 
     List<String> typeList = Lists.newArrayList();
     for (Object x : cmdLine.getValues(types)) {
       typeList.add(x.toString());
     }
 
     List<String> predictorList = Lists.newArrayList();
     for (Object x : cmdLine.getValues(predictors)) {
       predictorList.add(x.toString());
     }
 
     lmp = new LogisticModelParameters();
     lmp.setTargetVariable(getStringArgument(cmdLine, target));
     lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories));
     lmp.setNumFeatures(getIntegerArgument(cmdLine, features));
     lmp.setUseBias(!getBooleanArgument(cmdLine, noBias));
     lmp.setTypeMap(predictorList, typeList);
 
     lmp.setLambda(getDoubleArgument(cmdLine, lambda));
     lmp.setLearningRate(getDoubleArgument(cmdLine, rate));
 
     TrainLogistic.scores = getBooleanArgument(cmdLine, scores);
     TrainLogistic.passes = getIntegerArgument(cmdLine, passes);
 
     return true;
   }
 
   private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
     return (String) cmdLine.getValue(inputFile);
   }
 
   private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
     return cmdLine.hasOption(option);
   }
 
   private static int getIntegerArgument(CommandLine cmdLine, Option features) {
     return Integer.parseInt((String) cmdLine.getValue(features));
   }
 
   private static double getDoubleArgument(CommandLine cmdLine, Option op) {
     return Double.parseDouble((String) cmdLine.getValue(op));
   }
 
   public static OnlineLogisticRegression getModel() {
     return model;
   }

   public static LogisticModelParameters getParameters() {
     return lmp;
   }

   public static class InputOpener {
     private InputOpener() {
     }
 
     public static BufferedReader open(String inputFile) throws IOException {
       InputStreamReader s;
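       // try the input as a classpath resource first (convenient for bundled
       // example data); if that fails, fall back to reading an ordinary file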
       try {
         URL resource = Resources.getResource(inputFile);
         s = new InputStreamReader(resource.openStream());
       } catch (IllegalArgumentException e) {
         s = new FileReader(inputFile);
       }
 
       return new BufferedReader(s);
     }
   }
 }