ai.h2o.sparkling.ml.params.H2OWord2VecParams.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.h2o.sparkling.ml.params
import hex.word2vec.Word2VecModel.Word2VecParameters
import ai.h2o.sparkling.H2OFrame
import hex.word2vec.Word2Vec.NormModel
import hex.word2vec.Word2Vec.WordModel
/**
 * Spark ML parameters of the H2O Word2Vec algorithm.
 *
 * Every Sparkling Water parameter declared here has an H2O counterpart: the name mapping is
 * provided by [[getSWtoH2OParamNameMap]] and the current values are exported, keyed by the
 * H2O parameter names, by [[getH2OWord2VecParams]].
 */
trait H2OWord2VecParams
  extends H2OAlgoParamsBase {

  // Class tag of the backing H2O parameter class (`Word2VecParameters`).
  // NOTE(review): presumably consumed by H2OAlgoParamsBase; confirm against the base trait.
  protected def paramTag = reflect.classTag[Word2VecParameters]

  //
  // Parameter definitions
  //
  protected val vecSize = intParam(
    name = "vecSize",
    doc = """Set size of word vectors.""")

  protected val windowSize = intParam(
    name = "windowSize",
    doc = """Set max skip length between words.""")

  // The literal below spans two lines; its second line intentionally starts at column 0 so the
  // text of the documentation string is exactly what it was before.
  protected val sentSampleRate = floatParam(
    name = "sentSampleRate",
    doc = """Set threshold for occurrence of words. Those that appear with higher frequency in the training data
will be randomly down-sampled; useful range is (0, 1e-5).""")

  protected val normModel = stringParam(
    name = "normModel",
    doc = """Use Hierarchical Softmax. Possible values are ``"HSM"``.""")

  protected val epochs = intParam(
    name = "epochs",
    doc = """Number of training iterations to run.""")

  // The threshold was missing from the original wording ("appear less than times").
  protected val minWordFreq = intParam(
    name = "minWordFreq",
    doc = """This will discard words that appear less than minWordFreq times.""")

  protected val initLearningRate = floatParam(
    name = "initLearningRate",
    doc = """Set the starting learning rate.""")

  protected val wordModel = stringParam(
    name = "wordModel",
    doc = """The word model to use (SkipGram or CBOW). Possible values are ``"SkipGram"``, ``"CBOW"``.""")

  protected val modelId = nullableStringParam(
    name = "modelId",
    doc = """Destination id for this model; auto-generated if not specified.""")

  protected val maxRuntimeSecs = doubleParam(
    name = "maxRuntimeSecs",
    doc = """Maximum allowed runtime in seconds for model training. Use 0 to disable.""")

  protected val exportCheckpointsDir = nullableStringParam(
    name = "exportCheckpointsDir",
    doc = """Automatically export generated models to this directory.""")

  //
  // Default values
  //
  // The enum-backed string params default to the names of the corresponding H2O enum constants.
  // The nullable params (modelId, exportCheckpointsDir) default to null.
  setDefault(
    vecSize -> 100,
    windowSize -> 5,
    sentSampleRate -> 0.001f,
    normModel -> NormModel.HSM.name(),
    epochs -> 5,
    minWordFreq -> 5,
    initLearningRate -> 0.025f,
    wordModel -> WordModel.SkipGram.name(),
    modelId -> null,
    maxRuntimeSecs -> 0.0,
    exportCheckpointsDir -> null)

  //
  // Getters
  //
  def getVecSize(): Int = $(vecSize)

  def getWindowSize(): Int = $(windowSize)

  def getSentSampleRate(): Float = $(sentSampleRate)

  def getNormModel(): String = $(normModel)

  def getEpochs(): Int = $(epochs)

  def getMinWordFreq(): Int = $(minWordFreq)

  def getInitLearningRate(): Float = $(initLearningRate)

  def getWordModel(): String = $(wordModel)

  /** May return null: the default for this nullable parameter is null. */
  def getModelId(): String = $(modelId)

  def getMaxRuntimeSecs(): Double = $(maxRuntimeSecs)

  /** May return null: the default for this nullable parameter is null. */
  def getExportCheckpointsDir(): String = $(exportCheckpointsDir)

  //
  // Setters
  //
  def setVecSize(value: Int): this.type = {
    set(vecSize, value)
  }

  def setWindowSize(value: Int): this.type = {
    set(windowSize, value)
  }

  def setSentSampleRate(value: Float): this.type = {
    set(sentSampleRate, value)
  }

  /**
   * Sets the norm model.
   *
   * The value is validated against the H2O `NormModel` enum by `EnumParamValidator` before it is
   * stored. The validated result, not the raw input, is what ends up in the param.
   */
  def setNormModel(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[NormModel](value)
    set(normModel, validated)
  }

  def setEpochs(value: Int): this.type = {
    set(epochs, value)
  }

  def setMinWordFreq(value: Int): this.type = {
    set(minWordFreq, value)
  }

  def setInitLearningRate(value: Float): this.type = {
    set(initLearningRate, value)
  }

  /**
   * Sets the word model.
   *
   * The value is validated against the H2O `WordModel` enum by `EnumParamValidator` before it is
   * stored. The validated result, not the raw input, is what ends up in the param.
   */
  def setWordModel(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[WordModel](value)
    set(wordModel, validated)
  }

  def setModelId(value: String): this.type = {
    set(modelId, value)
  }

  def setMaxRuntimeSecs(value: Double): this.type = {
    set(maxRuntimeSecs, value)
  }

  def setExportCheckpointsDir(value: String): this.type = {
    set(exportCheckpointsDir, value)
  }

  /**
   * Combines the parameters contributed by the parent trait with the Word2Vec-specific ones.
   *
   * On duplicate keys, the Word2Vec entries win, because they are on the right-hand side of `++`.
   */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] = {
    super.getH2OAlgorithmParams(trainingFrame) ++ getH2OWord2VecParams(trainingFrame)
  }

  /**
   * Current Word2Vec parameter values, keyed by their H2O (snake_case) names.
   *
   * @param trainingFrame accepted for interface symmetry; not read by this method
   * @return map from H2O parameter name to the current value; nullable entries may hold null
   */
  private[sparkling] def getH2OWord2VecParams(trainingFrame: H2OFrame): Map[String, Any] = {
    Map(
      "vec_size" -> getVecSize(),
      "window_size" -> getWindowSize(),
      "sent_sample_rate" -> getSentSampleRate(),
      "norm_model" -> getNormModel(),
      "epochs" -> getEpochs(),
      "min_word_freq" -> getMinWordFreq(),
      "init_learning_rate" -> getInitLearningRate(),
      "word_model" -> getWordModel(),
      "model_id" -> getModelId(),
      "max_runtime_secs" -> getMaxRuntimeSecs(),
      "export_checkpoints_dir" -> getExportCheckpointsDir())
  }

  /**
   * Maps each Sparkling Water parameter name (camelCase) to its H2O name (snake_case).
   *
   * The result extends the mapping inherited from the parent trait.
   */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    super.getSWtoH2OParamNameMap() ++
      Map(
        "vecSize" -> "vec_size",
        "windowSize" -> "window_size",
        "sentSampleRate" -> "sent_sample_rate",
        "normModel" -> "norm_model",
        "epochs" -> "epochs",
        "minWordFreq" -> "min_word_freq",
        "initLearningRate" -> "init_learning_rate",
        "wordModel" -> "word_model",
        "modelId" -> "model_id",
        "maxRuntimeSecs" -> "max_runtime_secs",
        "exportCheckpointsDir" -> "export_checkpoints_dir")
  }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy