
org.apache.mahout.classifier.sgd.TrainAdaptiveLogistic Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-examples Show documentation
Show all versions of mahout-examples Show documentation
Scalable machine learning library examples
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.sgd;
import com.google.common.io.Resources;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.commons.io.Charsets;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
public final class TrainAdaptiveLogistic {
private static String inputFile;
private static String outputFile;
private static AdaptiveLogisticModelParameters lmp;
private static int passes;
private static boolean showperf;
private static int skipperfnum = 99;
private static AdaptiveLogisticRegression model;
private TrainAdaptiveLogistic() {
}
public static void main(String[] args) throws Exception {
mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
}
static void mainToOutput(String[] args, PrintWriter output) throws Exception {
if (parseArgs(args)) {
CsvRecordFactory csv = lmp.getCsvRecordFactory();
model = lmp.createAdaptiveLogisticRegression();
State best;
CrossFoldLearner learner = null;
int k = 0;
for (int pass = 0; pass < passes; pass++) {
BufferedReader in = open(inputFile);
// read variable names
csv.firstLine(in.readLine());
String line = in.readLine();
while (line != null) {
// for each new line, get target and predictors
Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
int targetValue = csv.processLine(line, input);
// update model
model.train(targetValue, input);
k++;
if (showperf && (k % (skipperfnum + 1) == 0)) {
best = model.getBest();
if (best != null) {
learner = best.getPayload().getLearner();
}
if (learner != null) {
double averageCorrect = learner.percentCorrect();
double averageLL = learner.logLikelihood();
output.printf("%d\t%.3f\t%.2f%n",
k, averageLL, averageCorrect * 100);
} else {
output.printf(Locale.ENGLISH,
"%10d %2d %s%n", k, targetValue,
"AdaptiveLogisticRegression has not found a good model ......");
}
}
line = in.readLine();
}
in.close();
}
best = model.getBest();
if (best != null) {
learner = best.getPayload().getLearner();
}
if (learner == null) {
output.println("AdaptiveLogisticRegression has failed to train a model.");
return;
}
try (OutputStream modelOutput = new FileOutputStream(outputFile)) {
lmp.saveTo(modelOutput);
}
OnlineLogisticRegression lr = learner.getModels().get(0);
output.println(lmp.getNumFeatures());
output.println(lmp.getTargetVariable() + " ~ ");
String sep = "";
for (String v : csv.getTraceDictionary().keySet()) {
double weight = predictorWeight(lr, 0, csv, v);
if (weight != 0) {
output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
sep = " + ";
}
}
output.printf("%n");
for (int row = 0; row < lr.getBeta().numRows(); row++) {
for (String key : csv.getTraceDictionary().keySet()) {
double weight = predictorWeight(lr, row, csv, key);
if (weight != 0) {
output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
}
}
for (int column = 0; column < lr.getBeta().numCols(); column++) {
output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
}
output.println();
}
}
}
private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
double weight = 0;
for (Integer column : csv.getTraceDictionary().get(predictor)) {
weight += lr.getBeta().get(row, column);
}
return weight;
}
private static boolean parseArgs(String[] args) {
DefaultOptionBuilder builder = new DefaultOptionBuilder();
Option help = builder.withLongName("help")
.withDescription("print this list").create();
Option quiet = builder.withLongName("quiet")
.withDescription("be extra quiet").create();
ArgumentBuilder argumentBuilder = new ArgumentBuilder();
Option showperf = builder
.withLongName("showperf")
.withDescription("output performance measures during training")
.create();
Option inputFile = builder
.withLongName("input")
.withRequired(true)
.withArgument(
argumentBuilder.withName("input").withMaximum(1)
.create())
.withDescription("where to get training data").create();
Option outputFile = builder
.withLongName("output")
.withRequired(true)
.withArgument(
argumentBuilder.withName("output").withMaximum(1)
.create())
.withDescription("where to write the model content").create();
Option threads = builder.withLongName("threads")
.withArgument(
argumentBuilder.withName("threads").withDefault("4").create())
.withDescription("the number of threads AdaptiveLogisticRegression uses")
.create();
Option predictors = builder.withLongName("predictors")
.withRequired(true)
.withArgument(argumentBuilder.withName("predictors").create())
.withDescription("a list of predictor variables").create();
Option types = builder
.withLongName("types")
.withRequired(true)
.withArgument(argumentBuilder.withName("types").create())
.withDescription(
"a list of predictor variable types (numeric, word, or text)")
.create();
Option target = builder
.withLongName("target")
.withDescription("the name of the target variable")
.withRequired(true)
.withArgument(
argumentBuilder.withName("target").withMaximum(1)
.create())
.create();
Option targetCategories = builder
.withLongName("categories")
.withDescription("the number of target categories to be considered")
.withRequired(true)
.withArgument(argumentBuilder.withName("categories").withMaximum(1).create())
.create();
Option features = builder
.withLongName("features")
.withDescription("the number of internal hashed features to use")
.withArgument(
argumentBuilder.withName("numFeatures")
.withDefault("1000").withMaximum(1).create())
.create();
Option passes = builder
.withLongName("passes")
.withDescription("the number of times to pass over the input data")
.withArgument(
argumentBuilder.withName("passes").withDefault("2")
.withMaximum(1).create())
.create();
Option interval = builder.withLongName("interval")
.withArgument(
argumentBuilder.withName("interval").withDefault("500").create())
.withDescription("the interval property of AdaptiveLogisticRegression")
.create();
Option window = builder.withLongName("window")
.withArgument(
argumentBuilder.withName("window").withDefault("800").create())
.withDescription("the average propery of AdaptiveLogisticRegression")
.create();
Option skipperfnum = builder.withLongName("skipperfnum")
.withArgument(
argumentBuilder.withName("skipperfnum").withDefault("99").create())
.withDescription("show performance measures every (skipperfnum + 1) rows")
.create();
Option prior = builder.withLongName("prior")
.withArgument(
argumentBuilder.withName("prior").withDefault("L1").create())
.withDescription("the prior algorithm to use: L1, L2, ebp, tp, up")
.create();
Option priorOption = builder.withLongName("prioroption")
.withArgument(
argumentBuilder.withName("prioroption").create())
.withDescription("constructor parameter for ElasticBandPrior and TPrior")
.create();
Option auc = builder.withLongName("auc")
.withArgument(
argumentBuilder.withName("auc").withDefault("global").create())
.withDescription("the auc to use: global or grouped")
.create();
Group normalArgs = new GroupBuilder().withOption(help)
.withOption(quiet).withOption(inputFile).withOption(outputFile)
.withOption(target).withOption(targetCategories)
.withOption(predictors).withOption(types).withOption(passes)
.withOption(interval).withOption(window).withOption(threads)
.withOption(prior).withOption(features).withOption(showperf)
.withOption(skipperfnum).withOption(priorOption).withOption(auc)
.create();
Parser parser = new Parser();
parser.setHelpOption(help);
parser.setHelpTrigger("--help");
parser.setGroup(normalArgs);
parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
CommandLine cmdLine = parser.parseAndHelp(args);
if (cmdLine == null) {
return false;
}
TrainAdaptiveLogistic.inputFile = getStringArgument(cmdLine, inputFile);
TrainAdaptiveLogistic.outputFile = getStringArgument(cmdLine,
outputFile);
List typeList = new ArrayList<>();
for (Object x : cmdLine.getValues(types)) {
typeList.add(x.toString());
}
List predictorList = new ArrayList<>();
for (Object x : cmdLine.getValues(predictors)) {
predictorList.add(x.toString());
}
lmp = new AdaptiveLogisticModelParameters();
lmp.setTargetVariable(getStringArgument(cmdLine, target));
lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories));
lmp.setNumFeatures(getIntegerArgument(cmdLine, features));
lmp.setInterval(getIntegerArgument(cmdLine, interval));
lmp.setAverageWindow(getIntegerArgument(cmdLine, window));
lmp.setThreads(getIntegerArgument(cmdLine, threads));
lmp.setAuc(getStringArgument(cmdLine, auc));
lmp.setPrior(getStringArgument(cmdLine, prior));
if (cmdLine.getValue(priorOption) != null) {
lmp.setPriorOption(getDoubleArgument(cmdLine, priorOption));
}
lmp.setTypeMap(predictorList, typeList);
TrainAdaptiveLogistic.showperf = getBooleanArgument(cmdLine, showperf);
TrainAdaptiveLogistic.skipperfnum = getIntegerArgument(cmdLine, skipperfnum);
TrainAdaptiveLogistic.passes = getIntegerArgument(cmdLine, passes);
lmp.checkParameters();
return true;
}
private static String getStringArgument(CommandLine cmdLine,
Option inputFile) {
return (String) cmdLine.getValue(inputFile);
}
private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
return cmdLine.hasOption(option);
}
private static int getIntegerArgument(CommandLine cmdLine, Option features) {
return Integer.parseInt((String) cmdLine.getValue(features));
}
private static double getDoubleArgument(CommandLine cmdLine, Option op) {
return Double.parseDouble((String) cmdLine.getValue(op));
}
public static AdaptiveLogisticRegression getModel() {
return model;
}
public static LogisticModelParameters getParameters() {
return lmp;
}
static BufferedReader open(String inputFile) throws IOException {
InputStream in;
try {
in = Resources.getResource(inputFile).openStream();
} catch (IllegalArgumentException e) {
in = new FileInputStream(new File(inputFile));
}
return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8));
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy