All downloads are FREE. Search and download functionality uses the official Maven repository.

com.mayabot.nlp.fasttext.args.InputArgs.kt Maven / Gradle / Ivy

package com.mayabot.nlp.fasttext.args

import com.mayabot.nlp.fasttext.loss.LossName
import java.io.File


/**
 * Optional, user-supplied training arguments.
 *
 * Options are nullable: `null` means "not set by the caller". [parse] overlays every
 * non-null option onto an [Args] built for the requested model and records each of
 * them as manually set via `Args.setManual`.
 */
class InputArgs {

    // The following arguments for training are optional:
    /**
     * Learning rate. When unset, [parse] uses 0.1 for supervised models.
     */
    var lr: Double? = null

    /**
     * Change the rate of updates for the learning rate (default 100).
     */
    var lrUpdateRate: Int? = null

    /**
     * Size of word vectors (default 100).
     */
    var dim: Int? = null

    /**
     * Size of the context window (default 5).
     */
    var ws: Int? = null

    /**
     * Number of epochs (default 5).
     */
    var epoch: Int? = null

    /**
     * Number of negatives sampled (default 5).
     */
    var neg: Int? = null

    /**
     * Loss function: ns, hs or softmax. When unset, [parse] uses softmax for supervised models.
     */
    var loss: LossName? = null

    /**
     * Number of threads (default 12).
     *
     * NOTE(review): [parse] does not copy this into [Args]; confirm it is consumed elsewhere.
     */
    var thread: Int? = null

    /**
     * Pretrained word vectors for supervised learning.
     *
     * Alternate spelling of [preTrainedVectors]; [parse] uses this value only when
     * [preTrainedVectors] is null.
     */
    var pretrainedVectors: File? = null


    // The following arguments for the dictionary are optional.
    // Their names appear to mirror the fastText command-line options of the same name;
    // null leaves the corresponding Args default in place.
    var minCount: Int? = null
    var minCountLabel: Int? = null
    var wordNgrams: Int? = null

    /** Forced to 0 by [parse] when the final wordNgrams <= 1 and maxn == 0. */
    var bucket: Int? = null
    var minn: Int? = null
    var maxn: Int? = null
    var t: Double? = null

    /** NOTE(review): [parse] does not copy this into [Args]; confirm it is consumed elsewhere. */
    var label: String? = null

    var maxVocabSize: Int? = null
    var seed: Int? = null

    /** Pretrained vectors file; takes precedence over [pretrainedVectors] in [parse]. */
    var preTrainedVectors: File? = null

    /**
     * Keep the pre-trained vocabulary in the dictionary (dict).
     * Always copied into [Args] by [parse], even when false.
     */
    var keepPreTrainedVector: Boolean = false

    /**
     * Validation file to be used for evaluation.
     */
    var autotuneValidationFile: File? = null

    /**
     * Metric objective: f1, or f1:labelname.
     */
    var autotuneMetric: String? = null

    /**
     * Number of predictions used for evaluation.
     */
    var autotunePredictions: Int? = null

    /**
     * Maximum duration in seconds (default 5 minutes).
     */
    var autotuneDuration: Int? = null

    /**
     * Constraint on the model file size.
     */
    var autotuneModelSize: String? = null


    /**
     * Builds the [Args] for [model] from these options.
     *
     * The steps are applied in this order:
     * 1. Start from `Args(model = model)`.
     * 2. For [ModelName.sup], apply the supervised defaults: softmax loss, minCount 1,
     *    minn 0, maxn 0 and lr 0.1.
     * 3. Each non-null option overrides its counterpart and its name is marked manual.
     * 4. `bucket` is forced to 0 when the resulting `wordNgrams <= 1` and `maxn == 0`.
     *
     * @param model the model type to build arguments for
     * @return the merged arguments, with every explicitly set option marked manual
     */
    fun parse(model: ModelName): Args {
        var temp = Args(model = model)

        if (model == ModelName.sup) {
            temp = temp.copy(
                    loss = LossName.softmax,
                    minCount = 1,
                    minn = 0,
                    maxn = 0,
                    lr = 0.1
            )
        }

        // Names of the options the caller set explicitly, in application order.
        val manual = mutableListOf<String>()

        // Applies `update` to `temp` when `value` is non-null and remembers `name` as manual.
        fun <T : Any> ifSet(name: String, value: T?, update: Args.(T) -> Args) {
            if (value != null) {
                manual += name
                temp = temp.update(value)
            }
        }

        ifSet("maxVocabSize", maxVocabSize) { copy(maxVocabSize = it) }
        ifSet("seed", seed) { copy(seed = it) }
        // The documented `pretrainedVectors` property was previously ignored; honour it
        // as a fallback while keeping `preTrainedVectors` authoritative.
        ifSet("preTrainedVectors", preTrainedVectors ?: pretrainedVectors) {
            copy(preTrainedVectors = it)
        }

        temp = temp.copy(keepPreTrainedVector = keepPreTrainedVector)

        ifSet("autotuneValidationFile", autotuneValidationFile) { copy(autotuneValidationFile = it) }
        ifSet("autotuneMetric", autotuneMetric) { copy(autotuneMetric = it) }
        ifSet("autotunePredictions", autotunePredictions) { copy(autotunePredictions = it) }
        ifSet("autotuneDuration", autotuneDuration) { copy(autotuneDuration = it) }
        ifSet("autotuneModelSize", autotuneModelSize) { copy(autotuneModelSize = it) }

        ifSet("lr", lr) { copy(lr = it) }
        ifSet("minCount", minCount) { copy(minCount = it) }
        ifSet("minCountLabel", minCountLabel) { copy(minCountLabel = it) }
        ifSet("wordNgrams", wordNgrams) { copy(wordNgrams = it) }
        ifSet("bucket", bucket) { copy(bucket = it) }
        ifSet("minn", minn) { copy(minn = it) }
        ifSet("maxn", maxn) { copy(maxn = it) }
        ifSet("t", t) { copy(t = it) }
        ifSet("lrUpdateRate", lrUpdateRate) { copy(lrUpdateRate = it) }
        ifSet("dim", dim) { copy(dim = it) }
        ifSet("ws", ws) { copy(ws = it) }
        ifSet("epoch", epoch) { copy(epoch = it) }
        ifSet("neg", neg) { copy(neg = it) }
        ifSet("loss", loss) { copy(loss = it) }

        // NOTE(review): upstream fastText skips this reset when autotune is enabled;
        // confirm whether that exception is wanted here.
        if (temp.wordNgrams <= 1 && temp.maxn == 0) {
            temp = temp.copy(bucket = 0)
        }

        // Mark manual options on the instance actually returned. A data-class copy()
        // carries only primary-constructor properties. Marks made on an intermediate
        // instance could therefore be dropped by a later copy, depending on how Args
        // stores them.
        manual.forEach { temp.setManual(it) }
        return temp
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy