All downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.mahout.classifier.mlp.TrainMultilayerPerceptron Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mahout.classifier.mlp;

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.google.common.base.Preconditions;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.math.Arrays;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Command-line driver that trains a {@link MultilayerPerceptron}.
 *
 * <p>The training file is read line by line. Each line is split on commas; every token except the
 * last is parsed as a numeric feature and the last token is the label name. The label is expanded
 * into a one-hot block appended after the features, using the order of the labels given on the
 * command line.
 *
 * @deprecated as of 0.10.0.
 */
@Deprecated
public final class TrainMultilayerPerceptron {

  private static final Logger log = LoggerFactory.getLogger(TrainMultilayerPerceptron.class);

  /** The parameters used by MLP, filled in by {@link #parseArgs(String[], Parameters)}. */
  static class Parameters {
    double learningRate;
    double momemtumWeight;
    double regularizationWeight;

    /** Path of the training dataset. */
    String inputFilePath;
    /** Whether the first line of the training dataset is discarded. */
    boolean skipHeader;
    /** Maps each label name to its position inside the one-hot label block. */
    Map<String, Integer> labelsIndex = new HashMap<>();

    /** Path the trained model is written to (and read from when updating). */
    String modelFilePath;
    /** Whether an already existing model is trained further instead of being replaced. */
    boolean updateModel;
    /** Size of each layer, in the order given on the command line. */
    List<Integer> layerSizeList = new ArrayList<>();
    String squashingFunctionName;
  }

  /**
   * Parses the arguments, builds or loads the model, trains it on the input file and writes it back.
   *
   * @param args the command-line arguments
   * @throws Exception if argument parsing, file-system access or training fails
   */
  public static void main(String[] args) throws Exception {
    Parameters parameters = new Parameters();
    if (!parseArgs(args, parameters)) {
      return;
    }

    log.info("Validate model...");
    MultilayerPerceptron mlp = buildModel(parameters);
    try {
      // set the parameters (applies to both a fresh and a reloaded model)
      mlp.setLearningRate(parameters.learningRate)
          .setMomentumWeight(parameters.momemtumWeight)
          .setRegularizationWeight(parameters.regularizationWeight);

      train(mlp, parameters);

      // write model back
      log.info("Write trained model to {}", parameters.modelFilePath);
      mlp.writeModelToFile();
    } finally {
      // close the model even when training or writing fails
      mlp.close();
    }
  }

  /**
   * Loads the existing model when it exists and an update was requested; otherwise deletes any
   * existing model file and creates a new network from the configured layer sizes.
   */
  private static MultilayerPerceptron buildModel(Parameters parameters) throws Exception {
    // check whether the model already exists
    Path modelPath = new Path(parameters.modelFilePath);
    FileSystem modelFs = modelPath.getFileSystem(new Configuration());
    boolean modelExists = modelFs.exists(modelPath);

    if (modelExists && parameters.updateModel) {
      // incrementally update existing model
      log.info("Build model from existing model...");
      return new MultilayerPerceptron(parameters.modelFilePath);
    }

    if (modelExists) {
      modelFs.delete(modelPath, true); // delete the existing file
    }
    log.info("Build model from scratch...");
    MultilayerPerceptron mlp = new MultilayerPerceptron();
    int lastLayer = parameters.layerSizeList.size() - 1;
    for (int i = 0; i <= lastLayer; ++i) {
      // only the last configured layer is flagged as the final layer
      mlp.addLayer(parameters.layerSizeList.get(i), i == lastLayer, parameters.squashingFunctionName);
    }
    // NOTE(review): previously re-applied on every loop iteration; assumed to be a plain setter
    // that does not influence addLayer - confirm against NeuralNetwork.
    mlp.setCostFunction("Minus_Squared");
    mlp.setModelPath(parameters.modelFilePath);
    return mlp;
  }

  /** Streams the training file and feeds every non-blank line to the model as one instance. */
  private static void train(MultilayerPerceptron mlp, Parameters parameters) throws Exception {
    Path trainingDataPath = new Path(parameters.inputFilePath);
    FileSystem dataFs = trainingDataPath.getFileSystem(new Configuration());

    Preconditions.checkArgument(dataFs.exists(trainingDataPath), "Training dataset %s cannot be found!",
        parameters.inputFilePath);

    log.info("Read data and train model...");

    // decode explicitly as UTF-8 instead of relying on the platform default charset
    try (BufferedReader reader = new BufferedReader(
        new InputStreamReader(dataFs.open(trainingDataPath), StandardCharsets.UTF_8))) {
      if (parameters.skipHeader) {
        reader.readLine();
      }

      // read training data line by line
      String line;
      while ((line = reader.readLine()) != null) {
        if (line.trim().isEmpty()) {
          continue; // tolerate blank lines, e.g. a trailing newline at the end of the file
        }
        double[] encoded = encodeInstance(line, parameters.labelsIndex);
        Vector instance = new DenseVector(encoded).viewPart(0, encoded.length);
        mlp.trainOnline(instance);
      }
    }
  }

  /**
   * Converts one comma-separated line into the features followed by a one-hot encoded label.
   *
   * @throws IllegalArgumentException if the line has no tokens or its label is not in {@code labelsIndex}
   * @throws NumberFormatException if a feature token is not a parsable double
   */
  private static double[] encodeInstance(String line, Map<String, Integer> labelsIndex) {
    String[] tokens = line.split(",");
    Preconditions.checkArgument(tokens.length > 0, "Training line contains no values: %s", line);

    int featureCount = tokens.length - 1;
    String label = tokens[featureCount];
    Integer labelIndex = labelsIndex.get(label);
    Preconditions.checkArgument(labelIndex != null, "Unknown label '%s' in training line: %s", label, line);

    // a new array is zero-filled, so only the features and the matching label slot need writing
    double[] encoded = new double[featureCount + labelsIndex.size()];
    for (int i = 0; i < featureCount; ++i) {
      encoded[i] = Double.parseDouble(tokens[i]);
    }
    // set the corresponding dimension
    encoded[featureCount + labelIndex] = 1;
    return encoded;
  }

  /**
   * Parse the input arguments.
   *
   * @param args The input arguments
   * @param parameters The parameters parsed.
   * @return Whether the input arguments are valid; {@code false} when the parser returned no command line.
   * @throws Exception if the arguments cannot be parsed or converted
   */
  private static boolean parseArgs(String[] args, Parameters parameters) throws Exception {
    // build the options
    log.info("Validate and parse arguments...");
    DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
    GroupBuilder groupBuilder = new GroupBuilder();
    ArgumentBuilder argumentBuilder = new ArgumentBuilder();

    // whether skip the first row of the input file
    Option skipHeaderOption = optionBuilder.withLongName("skipHeader")
        .withShortName("sh").create();

    Group skipHeaderGroup = groupBuilder.withOption(skipHeaderOption).create();

    Option inputOption = optionBuilder
        .withLongName("input")
        .withShortName("i")
        .withRequired(true)
        .withChildren(skipHeaderGroup)
        .withArgument(argumentBuilder.withName("path").withMinimum(1).withMaximum(1)
            .create()).withDescription("the file path of training dataset")
        .create();

    Option labelsOption = optionBuilder
        .withLongName("labels")
        .withShortName("labels")
        .withRequired(true)
        .withArgument(argumentBuilder.withName("label-name").withMinimum(2).create())
        .withDescription("label names").create();

    Option updateOption = optionBuilder
        .withLongName("update")
        .withShortName("u")
        .withDescription("whether to incrementally update model if the model exists")
        .create();

    Group modelUpdateGroup = groupBuilder.withOption(updateOption).create();

    Option modelOption = optionBuilder
        .withLongName("model")
        .withShortName("mo")
        .withRequired(true)
        .withArgument(argumentBuilder.withName("model-path").withMinimum(1).withMaximum(1).create())
        .withDescription("the path to store the trained model")
        .withChildren(modelUpdateGroup).create();

    Option layerSizeOption = optionBuilder
        .withLongName("layerSize")
        .withShortName("ls")
        .withRequired(true)
        .withArgument(argumentBuilder.withName("size of layer").withMinimum(2).withMaximum(5).create())
        .withDescription("the size of each layer").create();

    Option squashingFunctionOption = optionBuilder
        .withLongName("squashingFunction")
        .withShortName("sf")
        .withArgument(argumentBuilder.withName("squashing function").withMinimum(1).withMaximum(1)
            .withDefault("Sigmoid").create())
        .withDescription("the name of squashing function (currently only supports Sigmoid)")
        .create();

    Option learningRateOption = optionBuilder
        .withLongName("learningRate")
        .withShortName("l")
        .withArgument(argumentBuilder.withName("learning rate").withMaximum(1)
            .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_LEARNING_RATE).create())
        .withDescription("learning rate").create();

    Option momemtumOption = optionBuilder
        .withLongName("momemtumWeight")
        .withShortName("m")
        .withArgument(argumentBuilder.withName("momemtum weight").withMaximum(1)
            .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_MOMENTUM_WEIGHT).create())
        .withDescription("momemtum weight").create();

    Option regularizationOption = optionBuilder
        .withLongName("regularizationWeight")
        .withShortName("r")
        .withArgument(argumentBuilder.withName("regularization weight").withMaximum(1)
            .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_REGULARIZATION_WEIGHT).create())
        .withDescription("regularization weight").create();

    // parse the input
    Parser parser = new Parser();
    Group normalOptions = groupBuilder.withOption(inputOption)
        .withOption(skipHeaderOption).withOption(updateOption)
        .withOption(labelsOption).withOption(modelOption)
        .withOption(layerSizeOption).withOption(squashingFunctionOption)
        .withOption(learningRateOption).withOption(momemtumOption)
        .withOption(regularizationOption).create();

    parser.setGroup(normalOptions);

    CommandLine commandLine = parser.parseAndHelp(args);
    if (commandLine == null) {
      return false;
    }

    parameters.learningRate = getDouble(commandLine, learningRateOption);
    parameters.momemtumWeight = getDouble(commandLine, momemtumOption);
    parameters.regularizationWeight = getDouble(commandLine, regularizationOption);

    parameters.inputFilePath = getString(commandLine, inputOption);
    parameters.skipHeader = commandLine.hasOption(skipHeaderOption);

    // Index labels by first occurrence. Using the current map size keeps the indices dense even
    // when a label is repeated, so every index stays inside the one-hot block of labelsIndex.size().
    for (String label : getStringList(commandLine, labelsOption)) {
      if (!parameters.labelsIndex.containsKey(label)) {
        parameters.labelsIndex.put(label, parameters.labelsIndex.size());
      }
    }

    parameters.modelFilePath = getString(commandLine, modelOption);
    parameters.updateModel = commandLine.hasOption(updateOption);

    parameters.layerSizeList = getIntegerList(commandLine, layerSizeOption);

    parameters.squashingFunctionName = getString(commandLine, squashingFunctionOption);

    System.out.printf("Input: %s, Model: %s, Update: %s, Layer size: %s, Squashing function: %s, Learning rate: %f," +
            " Momemtum weight: %f, Regularization Weight: %f\n", parameters.inputFilePath, parameters.modelFilePath,
        parameters.updateModel, Arrays.toString(parameters.layerSizeList.toArray()),
        parameters.squashingFunctionName, parameters.learningRate, parameters.momemtumWeight,
        parameters.regularizationWeight);

    return true;
  }

  /**
   * Reads the value of {@code option} as a double.
   *
   * @return the parsed value, or {@code null} when the command line holds no value for the option
   * @throws NumberFormatException if the value is not a parsable double
   */
  static Double getDouble(CommandLine commandLine, Option option) {
    Object val = commandLine.getValue(option);
    if (val != null) {
      return Double.parseDouble(val.toString());
    }
    return null;
  }

  /**
   * Reads the value of {@code option} as a string.
   *
   * @return the value's string form, or {@code null} when the command line holds no value for the option
   */
  static String getString(CommandLine commandLine, Option option) {
    Object val = commandLine.getValue(option);
    if (val != null) {
      return val.toString();
    }
    return null;
  }

  /**
   * Reads all values of {@code option} and parses each one as an integer.
   *
   * @throws NumberFormatException if any value is not a parsable integer
   */
  static List<Integer> getIntegerList(CommandLine commandLine, Option option) {
    List<Integer> valList = new ArrayList<>();
    // iterate as Object and convert via toString() so no unchecked cast of the returned list is needed
    for (Object value : commandLine.getValues(option)) {
      valList.add(Integer.parseInt(value.toString()));
    }
    return valList;
  }

  /** Reads all values of {@code option} as strings, returned in a new list. */
  static List<String> getStringList(CommandLine commandLine, Option option) {
    List<String> valList = new ArrayList<>();
    // copy element by element so no unchecked cast of the returned list is needed
    for (Object value : commandLine.getValues(option)) {
      valList.add(value.toString());
    }
    return valList;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy