All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.mahout.classifier.df.mapreduce.Classifier Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.df.mapreduce;

import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.classifier.df.DFUtils;
import org.apache.mahout.classifier.df.DecisionForest;
import org.apache.mahout.classifier.df.data.DataConverter;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * MapReduce implementation that classifies the input data using a previously built decision forest.
 *
 * <p>A map-only job is run over the input path. Input files are never split, so each file is handled
 * by exactly one mapper. Once the job has finished, the predictions of each mapper are copied to
 * {@code <outputPath>/<input file name>.out} and the (label, prediction) pairs are kept in memory and
 * exposed through {@link #getResults()}.
 */
public class Classifier {

  private static final Logger log = LoggerFactory.getLogger(Classifier.class);

  private final Path forestPath;
  private final Path inputPath;
  private final Path datasetPath;
  private final Configuration conf;
  private final Path outputPath; // path that will contain the final output of the classifier
  private final Path mappersOutputPath; // mappers will output here
  private double[][] results;

  /**
   * Returns the pairs collected by the last successful call to {@link #run()}.
   *
   * @return one row per classified instance, where {@code row[0]} is the label read from the data and
   *         {@code row[1]} is the forest's prediction; {@code null} if no run has completed yet. The
   *         internal array is returned as is, without a defensive copy.
   */
  public double[][] getResults() {
    return results;
  }

  /**
   * @param forestPath  location of the serialized decision forest
   * @param inputPath   data to classify
   * @param datasetPath location of the dataset description used to parse the input lines
   * @param outputPath  directory that will receive the prediction files; must not exist when
   *                    {@link #run()} is called
   * @param conf        Hadoop configuration used to build the job
   */
  public Classifier(Path forestPath,
                    Path inputPath,
                    Path datasetPath,
                    Path outputPath,
                    Configuration conf) {
    this.forestPath = forestPath;
    this.inputPath = inputPath;
    this.datasetPath = datasetPath;
    this.outputPath = outputPath;
    this.conf = conf;

    mappersOutputPath = new Path(outputPath, "mappers");
  }

  /**
   * Sets up a map-only job: unsplit text input, {@link CMapper} as the mapper and a sequence file of
   * (DoubleWritable, Text) records written below {@code mappersOutputPath}.
   */
  private void configureJob(Job job) throws IOException {

    job.setJarByClass(Classifier.class);

    FileInputFormat.setInputPaths(job, inputPath);
    FileOutputFormat.setOutputPath(job, mappersOutputPath);

    job.setOutputKeyClass(DoubleWritable.class);
    job.setOutputValueClass(Text.class);

    job.setMapperClass(CMapper.class);
    job.setNumReduceTasks(0); // no reducers

    job.setInputFormatClass(CTextInputFormat.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);

  }

  /**
   * Runs the classification job, writes the per-input-file prediction files and stores the results.
   * The intermediate mappers output is deleted once it has been parsed.
   *
   * @throws IOException           if the output path already exists or an I/O error occurs
   * @throws IllegalStateException if the job does not complete successfully
   */
  public void run() throws IOException, ClassNotFoundException, InterruptedException {
    // Resolve the file system from the path itself (as parseOutput does) so that an output path
    // located on a file system other than the default one is checked correctly.
    FileSystem fs = outputPath.getFileSystem(conf);

    // check the output
    if (fs.exists(outputPath)) {
      throw new IOException("Output path already exists : " + outputPath);
    }

    // The cache files are registered on conf before the Job is built from it. The order matters:
    // CMapper.setup() expects the dataset first and the forest second.
    log.info("Adding the dataset to the DistributedCache");
    DistributedCache.addCacheFile(datasetPath.toUri(), conf);

    log.info("Adding the decision forest to the DistributedCache");
    DistributedCache.addCacheFile(forestPath.toUri(), conf);

    Job job = new Job(conf, "decision forest classifier");

    log.info("Configuring the job...");
    configureJob(job);

    log.info("Running the job...");
    if (!job.waitForCompletion(true)) {
      throw new IllegalStateException("Job failed!");
    }

    parseOutput(job);

    HadoopUtil.delete(conf, mappersOutputPath);
  }

  /**
   * Extracts the predictions of each mapper and writes them to the corresponding output file.
   * The name of the output file is the name of the corresponding input file followed by
   * {@code .out}. Every (label, prediction) pair is also collected into {@code results}.
   */
  private void parseOutput(JobContext job) throws IOException {
    Configuration jobConf = job.getConfiguration();
    FileSystem fs = mappersOutputPath.getFileSystem(jobConf);

    Path[] outfiles = DFUtils.listOutputFiles(fs, mappersOutputPath);

    // read all the output
    List<double[]> resList = new ArrayList<>();
    for (Path path : outfiles) {
      // created lazily, once the first record has told us the name of the input file
      FSDataOutputStream ofile = null;
      try {
        for (Pair<DoubleWritable, Text> record
            : new SequenceFileIterable<DoubleWritable, Text>(path, true, jobConf)) {
          double key = record.getFirst().get();
          String value = record.getSecond().toString();
          if (ofile == null) {
            // this is the first value, it contains the name of the input file
            ofile = fs.create(new Path(outputPath, value).suffix(".out"));
          } else {
            // The key contains the correct label of the data. The value contains a prediction
            // NOTE(review): writeChars emits two bytes per character, so the file is not plain
            // single-byte text; kept unchanged so existing consumers keep working.
            ofile.writeChars(value); // write the prediction
            ofile.writeChar('\n');

            resList.add(new double[] {key, Double.parseDouble(value)});
          }
        }
      } finally {
        Closeables.close(ofile, false);
      }
    }
    results = resList.toArray(new double[resList.size()][]);
  }

  /**
   * TextInputFormat that does not split the input files. This ensures that each input file is processed by one single
   * mapper.
   */
  private static class CTextInputFormat extends TextInputFormat {
    @Override
    protected boolean isSplitable(JobContext jobContext, Path path) {
      return false;
    }
  }

  /**
   * Classifies every non-empty input line with the decision forest.
   *
   * <p>The first record written by each mapper is a header whose value is the name of the input file;
   * every following record has the instance's label as key and the prediction as value.
   */
  public static class CMapper extends Mapper<LongWritable, Text, DoubleWritable, Text> {

    /** used to convert input values to data instances */
    private DataConverter converter;
    private DecisionForest forest;
    private final Random rng = RandomUtils.getRandom();
    /** true until the header record holding the input file name has been written */
    private boolean first = true;
    private final Text lvalue = new Text();
    private Dataset dataset;
    private final DoubleWritable lkey = new DoubleWritable();

    /**
     * Loads the dataset (first cached file) and the decision forest (second cached file).
     *
     * @throws IOException          if fewer than two cached files are available
     * @throws InterruptedException if the forest could not be loaded
     */
    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
      super.setup(context);

      Configuration conf = context.getConfiguration();

      Path[] files = HadoopUtil.getCachedFiles(conf);

      if (files.length < 2) {
        throw new IOException("not enough paths in the DistributedCache");
      }
      dataset = Dataset.load(conf, files[0]);
      converter = new DataConverter(dataset);

      forest = DecisionForest.load(conf, files[1]);
      if (forest == null) {
        // NOTE(review): InterruptedException is an unusual choice for a missing resource; it is
        // kept so that the exception type seen by existing callers does not change.
        throw new InterruptedException("DecisionForest not found!");
      }
    }

    @Override
    protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
      if (first) {
        // Header record: the value carries the input file name. Its key (the input key of the
        // first line) is not used by Classifier.parseOutput.
        FileSplit split = (FileSplit) context.getInputSplit();
        Path path = split.getPath(); // current split path
        lvalue.set(path.getName());
        lkey.set(key.get());
        context.write(lkey, lvalue);

        first = false;
      }

      String line = value.toString();
      if (!line.isEmpty()) {
        Instance instance = converter.convert(line);
        double prediction = forest.classify(dataset, rng, instance);
        lkey.set(dataset.getLabel(instance));
        lvalue.set(Double.toString(prediction));
        context.write(lkey, lvalue);
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy