All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.classifier.sgd.TrainASFEmail Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sgd;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.io.Text;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.vectorizer.encoders.Dictionary;

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public final class TrainASFEmail extends AbstractJob {

  private TrainASFEmail() {
  }

  @Override
  public int run(String[] args) throws Exception {
    addInputOption();
    addOutputOption();
    addOption("categories", "nc", "The number of categories to train on", true);
    addOption("cardinality", "c", "The size of the vectors to use", "100000");
    addOption("threads", "t", "The number of threads to use in the learner", "20");
    addOption("poolSize", "p", "The number of CrossFoldLearners to use in the AdaptiveLogisticRegression. "
                               + "Higher values require more memory.", "5");
    if (parseArguments(args) == null) {
      return -1;
    }

    File base = new File(getInputPath().toString());

    Multiset overallCounts = HashMultiset.create();
    File output = new File(getOutputPath().toString());
    output.mkdirs();
    int numCats = Integer.parseInt(getOption("categories"));
    int cardinality = Integer.parseInt(getOption("cardinality", "100000"));
    int threadCount = Integer.parseInt(getOption("threads", "20"));
    int poolSize = Integer.parseInt(getOption("poolSize", "5"));
    Dictionary asfDictionary = new Dictionary();
    AdaptiveLogisticRegression learningAlgorithm =
        new AdaptiveLogisticRegression(numCats, cardinality, new L1(), threadCount, poolSize);
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    //We ran seq2encoded and split input already, so let's just build up the dictionary
    Configuration conf = new Configuration();
    PathFilter trainFilter = new PathFilter() {
      @Override
      public boolean accept(Path path) {
        return path.getName().contains("training");
      }
    };
    SequenceFileDirIterator iter =
        new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter, null, true, conf);
    long numItems = 0;
    while (iter.hasNext()) {
      Pair next = iter.next();
      asfDictionary.intern(next.getFirst().toString());
      numItems++;
    }

    System.out.println(numItems + " training files");

    SGDInfo info = new SGDInfo();

    iter = new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter,
            null, true, conf);
    int k = 0;
    while (iter.hasNext()) {
      Pair next = iter.next();
      String ng = next.getFirst().toString();
      int actual = asfDictionary.intern(ng);
      //we already have encoded
      learningAlgorithm.train(actual, next.getSecond().get());
      k++;
      State best = learningAlgorithm.getBest();

      SGDHelper.analyzeState(info, 0, k, best);
    }
    learningAlgorithm.close();
    //TODO: how to dissection since we aren't processing the files here
    //SGDHelper.dissect(leakType, asfDictionary, learningAlgorithm, files, overallCounts);
    System.out.println("exiting main, writing model to " + output);

    ModelSerializer.writeBinary(output + "/asf.model",
            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));

    List counts = new ArrayList<>();
    System.out.println("Word counts");
    for (String count : overallCounts.elementSet()) {
      counts.add(overallCounts.count(count));
    }
    Collections.sort(counts, Ordering.natural().reverse());
    k = 0;
    for (Integer count : counts) {
      System.out.println(k + "\t" + count);
      k++;
      if (k > 1000) {
        break;
      }
    }
    return 0;
  }

  public static void main(String[] args) throws Exception {
    TrainASFEmail trainer = new TrainASFEmail();
    trainer.run(args);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy