# Summary
This text classification example uses pre-trained GloVe embeddings to convert words to vectors,
and trains a CNN, LSTM or GRU `TextClassifier` model on the 20 Newsgroup dataset.
The approach was first described in: https://blog.keras.io/using-pre-trained-word-embeddings-in-a-keras-model.html

A CNN `TextClassifier` model can achieve around 85% accuracy after 20 epochs of training.
LSTM and GRU models are somewhat harder to train and need more epochs to achieve comparable results.


## Download Analytics Zoo
You can download the Analytics Zoo prebuilt release or nightly build package from [here](https://analytics-zoo.github.io/master/#release-download/) and extract it.


## Data Preparation
The data used in this example are:
- [20 Newsgroup dataset](http://qwone.com/~jason/20Newsgroups/20news-18828.tar.gz), which contains 18828 texts in 20 categories.
- [GloVe word embeddings](http://nlp.stanford.edu/data/glove.6B.zip): embeddings of 400k words pre-trained on a 2014 dump of English Wikipedia.

You need to prepare the data yourself before running the example.

The following scripts download and extract the data for you:
```bash
bash ${ANALYTICS_ZOO_HOME}/bin/data/news20/get_news20.sh dir
bash ${ANALYTICS_ZOO_HOME}/bin/data/glove/get_glove.sh dir
```
Remarks:
- `ANALYTICS_ZOO_HOME` is the folder where you extracted the downloaded package, and `dir` is the directory in which to store the downloaded data.
- If `dir` is not specified, the data will be downloaded to the current working directory.
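Before launching a training run, it can help to confirm that both download steps actually produced data. The following is a minimal sketch, not part of Analytics Zoo: `require_nonempty_dir` is a hypothetical helper, and `news20Path` and `glovePath` are assumed to hold the directories you passed to the scripts above.

```bash
# Hypothetical helper (not shipped with Analytics Zoo):
# succeeds only if the given path is an existing, non-empty directory.
require_nonempty_dir() {
    [ -d "$1" ] && [ -n "$(ls -A "$1" 2>/dev/null)" ]
}

# Report the state of each data directory; unset variables count as missing.
for d in "${news20Path:-}" "${glovePath:-}"; do
    if require_nonempty_dir "$d"; then
        echo "ok: $d"
    else
        echo "missing or empty: '$d'" >&2
    fi
done
```

The check only looks at whether the directory has any content; it does not verify the layout that the example expects inside it.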


## Run this example
Run the following command for Spark local mode (`MASTER=local[*]`) or cluster mode:

```bash
export SPARK_HOME=the root directory of Spark
export ANALYTICS_ZOO_HOME=the folder where you extract the downloaded Analytics Zoo zip package
MASTER=...
news20Path=the directory containing News20 dataset
glovePath=the directory containing GloVe embeddings

${ANALYTICS_ZOO_HOME}/bin/spark-shell-with-zoo.sh \
    --master ${MASTER} \
    --driver-memory 2g \
    --executor-memory 2g \
    --class com.intel.analytics.zoo.examples.textclassification.TextClassification \
    --dataPath ${news20Path} \
    --embeddingPath ${glovePath}
```
See [Options](#options) for the full list of configurable options for this example.


## Options
* `--dataPath` This option is __required__. The path where the News20 dataset is located.
* `--embeddingPath` This option is __required__. The path where the GloVe embeddings are located.
* `--outputPath` If specified, the trained model `text_classifier.model` and word dictionary file `word_index.txt` will be saved under this path. It can be either a local or distributed file system path.
* `--classNum` The number of classes to classify. Default is 20 for the News20 dataset.
* `--partitionNum` The number of partitions to cut the dataset into. Default is 4.
* `--tokenLength` The size of each word vector. GloVe supports tokenLength 50, 100, 200 and 300. Default is 200.
* `--sequenceLength` The length of a sequence. Default is 500.
* `--maxWordsNum` The maximum number of words, ranked by frequency, to take into consideration. Default is 5000.
* `--encoder` The encoder for the input sequence. One of 'cnn', 'lstm' or 'gru'. Default is 'cnn'.
* `--encoderOutputDim` The output dimension of the encoder. Default is 256.
* `--trainingSplit` The fraction of the data used for training. Default is 0.8.
* `-b` `--batchSize` The number of samples per gradient update. Default is 128.
* `-e` `--nbEpoch` The number of epochs to train the model. Default is 20.
* `-l` `--learningRate` The learning rate for the TextClassifier model. Default is 0.01.
* `-m` `--model` Specify this option only if you want to load an existing TextClassifier model; its value should be the path to that model.
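As an illustration of how the options combine, the sketch below prints an option list for a non-default run, for example a GRU encoder trained for more epochs on 100-dimensional GloVe vectors. `build_example_args` is a hypothetical helper and all values are illustrative; only the flag names are taken from the list above. It is a dry run that prints the options and does not start Spark.

```bash
# Hypothetical helper: print the example's options, one token per line.
# Arguments: encoder ('cnn', 'lstm' or 'gru'), number of epochs, token length.
build_example_args() {
    printf '%s\n' \
        --dataPath "${news20Path:-/path/to/news20}" \
        --embeddingPath "${glovePath:-/path/to/glove}" \
        --encoder "$1" \
        --nbEpoch "$2" \
        --tokenLength "$3"
}

# Dry run: options for a GRU model, 30 epochs, 100-dimensional word vectors.
build_example_args gru 30 100
```

The printed options would take the place of the `--dataPath` and `--embeddingPath` lines at the end of the launch command shown in the "Run this example" section.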


## Results
You can find the accuracy information in the log during the training process:
```
INFO DistriOptimizer$: [Epoch 20 15120/15044][Iteration 2700][Wall Clock 643.613424167s] Validate model...
INFO DistriOptimizer$: [Epoch 20 15120/15044][Iteration 2700][Wall Clock 643.613424167s] Top1Accuracy is Accuracy(correct: 3207, count: 3784, accuracy: 0.8475158562367865)
```
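If you redirect the training output to a file, the accuracy values can be pulled out of it with standard tools. This is a minimal sketch that assumes the log lines have the same shape as the sample above; `extract_top1_accuracy` is a hypothetical helper name.

```bash
# Hypothetical helper: read log lines on stdin and print the accuracy value
# from every line that reports Top1Accuracy.
extract_top1_accuracy() {
    sed -n 's/.*Top1Accuracy is Accuracy(.*accuracy: \([0-9.]*\)).*/\1/p'
}

# Demo on the sample log line shown above.
printf '%s\n' 'INFO DistriOptimizer$: [Epoch 20 15120/15044][Iteration 2700][Wall Clock 643.613424167s] Top1Accuracy is Accuracy(correct: 3207, count: 3784, accuracy: 0.8475158562367865)' \
    | extract_top1_accuracy
# prints 0.8475158562367865
```

With a saved log (the file name `train.log` is an assumption), `extract_top1_accuracy < train.log | tail -n 1` would print the most recently reported accuracy.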



