All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.graph.embedding.Param.scala Maven / Gradle / Ivy

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */


package com.tencent.angel.sona.graph.embedding

/**
 * Mutable, serializable holder of training hyper-parameters (judging by the package, used by the
 * graph-embedding models — NOTE(review): confirm the set of consumers).
 *
 * Each `setXxx` method validates its argument with `require` where a check exists, stores the
 * value, and returns `this`, so configuration calls can be chained:
 * {{{
 * val param = new Param().setEmbeddingDim(128).setNegSample(5).setWindowSize(10)
 * }}}
 * Fields without an explicit initializer keep the JVM default (`0`, `0.0f` or `null`) until their
 * setter is called.
 */
class Param extends Serializable {

  var partitionNum: Int = _        // > 0 once set via setPartitionNum
  var windowSize: Int = _          // > 0 once set via setWindowSize
  var embeddingDim: Int = _        // > 0 once set via setEmbeddingDim
  var negSample: Int = _           // > 0 once set via setNegSample
  var learningRate: Float = _      // > 0 once set via setLearningRate
  var batchSize: Int = _           // > 0 once set via setBatchSize
  var numEpoch: Int = _            // > 0 once set via setNumEpoch
  var maxIndex: Int = _            // in (0, Int.MaxValue) once set via setMaxIndex
  var sampleRate: Float = _        // in (0, 1] once set via setSampleRate
  var numPSPart: Int = 1           // defaults to 1; overridden only by a Some(...) in setNumPSPart
  var modelPath: String = _        // non-null, non-empty once set via setModelPath
  // Defaults to Int.MaxValue; NOTE(review): presumably this effectively disables periodic
  // checkpointing until setModelCPInterval is called — confirm against callers.
  var checkpointInterval: Int = Int.MaxValue
  // Defaults to Int.MaxValue; NOTE(review): presumably "never save mid-training" — confirm.
  var saveModelInterval: Int = Int.MaxValue
  var order: Int = _               // 1 or 2 once set via setOrder
  // Defaults to -1; NOTE(review): presumably a "not configured" sentinel — confirm against callers.
  var nodesNumPerRow: Int = -1
  var numRowDataSet: Option[Long] = None // None until setNumRowDataSet is called
  var seed: Int = _
  // Defaults to -1; NOTE(review): presumably a "no limit" sentinel — confirm against callers.
  var maxLength: Int = -1
  var model: String = "cbow"

  /** Sets the model name (not validated here). */
  def setModel(model: String): this.type = {
    this.model = model
    this
  }

  /** Sets `maxLength` (not validated here). */
  def setMaxLength(maxLength: Int): this.type = {
    this.maxLength = maxLength
    this
  }

  /** Records the data-set row count, wrapping it in `Some`. */
  def setNumRowDataSet(numRowDataSet: Long): this.type = {
    this.numRowDataSet = Some(numRowDataSet)
    this
  }

  /** Sets the random seed (not validated here). */
  def setSeed(seed: Int): this.type = {
    this.seed = seed
    this
  }

  /** @throws IllegalArgumentException if `partitionNum <= 0` */
  def setPartitionNum(partitionNum: Int): this.type = {
    require(partitionNum > 0, s"require partitionNum > 0, $partitionNum given")
    this.partitionNum = partitionNum
    this
  }

  /** @throws IllegalArgumentException if `windowSize <= 0` */
  def setWindowSize(windowSize: Int): this.type = {
    require(windowSize > 0, s"require windowSize > 0, $windowSize given")
    this.windowSize = windowSize
    this
  }

  /** @throws IllegalArgumentException if `embeddingDim <= 0` */
  def setEmbeddingDim(embeddingDim: Int): this.type = {
    require(embeddingDim > 0, s"require embedding dimension > 0, $embeddingDim given")
    this.embeddingDim = embeddingDim
    this
  }

  /** @throws IllegalArgumentException if `negSample <= 0` */
  def setNegSample(negSample: Int): this.type = {
    require(negSample > 0, s"require num of negative sample > 0, $negSample given")
    this.negSample = negSample
    this
  }

  /** @throws IllegalArgumentException if `learningRate` is not > 0 (NaN is rejected too) */
  def setLearningRate(learningRate: Float): this.type = {
    require(learningRate > 0, s"require learning rate > 0, $learningRate given")
    this.learningRate = learningRate
    this
  }

  /** @throws IllegalArgumentException if `batchSize <= 0` */
  def setBatchSize(batchSize: Int): this.type = {
    require(batchSize > 0, s"require batch size > 0, $batchSize given")
    this.batchSize = batchSize
    this
  }

  /** @throws IllegalArgumentException if `numEpoch <= 0` */
  def setNumEpoch(numEpoch: Int): this.type = {
    require(numEpoch > 0, s"require num of epoch > 0, $numEpoch given")
    this.numEpoch = numEpoch
    this
  }

  /**
   * Stores `maxIndex` as an `Int`.
   *
   * Accepts a `Long` so callers can pass a count computed as `Long`; the range check guarantees the
   * `toInt` narrowing below is lossless.
   *
   * @throws IllegalArgumentException unless `0 < maxIndex < Int.MaxValue`
   */
  def setMaxIndex(maxIndex: Long): this.type = {
    require(maxIndex > 0 && maxIndex < Int.MaxValue,
      s"require maxIndex > 0 && maxIndex < Int.MaxValue, $maxIndex given")
    this.maxIndex = maxIndex.toInt
    this
  }

  /**
   * Sets the sample rate.
   *
   * The upper bound makes the check agree with the documented rate range; 0 is rejected, so the
   * accepted interval is (0, 1]. NaN fails both comparisons and is rejected.
   *
   * @throws IllegalArgumentException unless `0 < sampleRate <= 1`
   */
  def setSampleRate(sampleRate: Float): this.type = {
    require(sampleRate > 0 && sampleRate <= 1, s"require sample rate in (0, 1], $sampleRate given")
    this.sampleRate = sampleRate
    this
  }

  /** @throws IllegalArgumentException if `modelPath` is null or empty */
  def setModelPath(modelPath: String): this.type = {
    require(null != modelPath && modelPath.nonEmpty, s"require non empty path to save model, $modelPath given")
    this.modelPath = modelPath
    this
  }

  /** Sets `checkpointInterval`. @throws IllegalArgumentException if `modelCPInterval <= 0` */
  def setModelCPInterval(modelCPInterval: Int): this.type = {
    require(modelCPInterval > 0, s"model checkpoint interval > 0, $modelCPInterval given")
    this.checkpointInterval = modelCPInterval
    this
  }

  /** Sets `saveModelInterval`. @throws IllegalArgumentException if `modelSaveInterval <= 0` */
  def setModelSaveInterval(modelSaveInterval: Int): this.type = {
    require(modelSaveInterval > 0, s"model save interval > 0, $modelSaveInterval given")
    this.saveModelInterval = modelSaveInterval
    this
  }

  /** @throws IllegalArgumentException unless `order` is 1 or 2 */
  def setOrder(order: Int): this.type = {
    require(order == 1 || order == 2, s"order equals 1 or 2, $order given")
    this.order = order
    this
  }

  /**
   * Overrides `numPSPart` when a value is supplied; `None` leaves the current value (default 1).
   *
   * @throws IllegalArgumentException if a supplied value is `<= 0`
   */
  def setNumPSPart(numPSPart: Option[Int]): this.type = {
    require(numPSPart.fold(true)(_ > 0), s"require num of PS part > 0, $numPSPart given")
    numPSPart.foreach(this.numPSPart = _)
    this
  }

  /** Overrides `nodesNumPerRow` when a value is supplied (not validated); `None` is a no-op. */
  def setNodesNumPerRow(nodesNumPerRow: Option[Int]): this.type = {
    nodesNumPerRow.foreach(this.nodesNumPerRow = _)
    this
  }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy