All downloads are free. The search and download features use the official Maven repository.

com.tencent.angel.sona.examples.online_learning.FTRLExample.scala Maven / Gradle / Ivy

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */
package com.tencent.angel.sona.examples.online_learning

import com.tencent.angel.conf.AngelConf
import com.tencent.angel.ml.math2.utils.{LabeledData, RowType}
import com.tencent.angel.ml.math2.vector.LongFloatVector
import com.tencent.angel.ps.storage.partitioner.ColumnRangePartitioner
import com.tencent.angel.sona.context.PSContext
import com.tencent.angel.sona.graph.utils.{DataLoader, LoadBalancePartitioner}
import com.tencent.angel.sona.online_learning.FTRL
import org.apache.hadoop.fs.Path
import org.apache.spark.angel.ml.metric.AUC
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.SparkUtil
import org.apache.spark.{SparkConf, SparkContext}


/**
 * Command-line example for the FTRL optimizer running on Spark with Angel
 * parameter servers.
 *
 * The `actionType` parameter selects the mode: `train` and `incTrain` both run
 * [[train]]; any other value runs [[predict]].
 */
object FTRLExample {

  /** Creates a SparkContext from a default SparkConf and obtains the PS context for it. */
  def start(): Unit = {
    val sc = new SparkContext(new SparkConf())
    PSContext.getOrCreate(sc)
  }

  /** Stops the PS context first, then the SparkContext returned by `SparkContext.getOrCreate()`. */
  def stop(): Unit = {
    PSContext.stop()
    SparkContext.getOrCreate().stop()
  }

  /**
   * Entry point: parses the arguments with `SparkUtil.parse`, runs the selected
   * action and then calls [[stop]].
   *
   * [[stop]] is reached only when the action returns normally; an exception
   * thrown by the action propagates to the caller as-is.
   *
   * @param args command-line arguments handed to `SparkUtil.parse`
   */
  def main(args: Array[String]): Unit = {
    val params = SparkUtil.parse(args)
    params.getOrElse("actionType", "train") match {
      case "train" | "incTrain" => train(params)
      case _ => predict(params)
    }
    stop()
  }

  /** Fails with a descriptive error for a `dataType` this example cannot parse. */
  private def unsupportedDataType(dataType: String): Nothing =
    throw new IllegalArgumentException(
      s"Unsupported dataType '$dataType', expected one of: libsvm, dummy")

  /**
   * Computes the smallest and largest feature index over `data` in a single pass.
   *
   * Rows whose index array is empty are skipped so they cannot trigger an
   * `empty.min` / `empty.max` failure. The final `reduce` still throws if no
   * row contributes any index.
   *
   * @return `(minIndex, maxIndex)`
   */
  private def indexRange(data: org.apache.spark.rdd.RDD[LabeledData]): (Long, Long) =
    data
      .map(_.getX.asInstanceOf[LongFloatVector].getStorage().getIndices)
      .filter(_.nonEmpty)
      .map(indices => (indices.min, indices.max))
      .reduce((a, b) => (math.min(a._1, b._1), math.max(a._2, b._2)))

  /**
   * Trains an FTRL model and optionally saves it.
   *
   * Recognised keys (defaults in parentheses): `alpha` (2.0), `beta` (1.0),
   * `lambda1` (0.1), `lambda2` (5.0), `dim` (-1), `input`
   * (data/census/census_148d_train.libsvm), `dataType` (libsvm; `dummy` is the
   * other accepted value), `batchSize` (100), `numEpoch` (3), `modelPath`
   * (empty), `model` (empty), `balance` (false), `possion` (1.0f), `bits` (20),
   * `numPartitions` (100).
   *
   * Note the naming: `modelPath` is the OUTPUT directory (written only when
   * non-empty), while `model` is the directory of a previous run whose `/back`
   * sub-directory is loaded before training (only when non-empty).
   *
   * @param params parsed command-line parameters
   * @throws IllegalArgumentException if `dataType` is neither `libsvm` nor `dummy`
   */
  def train(params: Map[String, String]): Unit = {

    val alpha = params.getOrElse("alpha", "2.0").toDouble
    val beta = params.getOrElse("beta", "1.0").toDouble
    val lambda1 = params.getOrElse("lambda1", "0.1").toDouble
    val lambda2 = params.getOrElse("lambda2", "5.0").toDouble
    val dim = params.getOrElse("dim", "-1").toLong
    val input = params.getOrElse("input", "data/census/census_148d_train.libsvm")
    val dataType = params.getOrElse("dataType", "libsvm")
    val batchSize = params.getOrElse("batchSize", "100").toInt
    val numEpoch = params.getOrElse("numEpoch", "3").toInt
    val output = params.getOrElse("modelPath", "")
    val modelPath = params.getOrElse("model", "")
    val withBalancePartition = params.getOrElse("balance", "false").toBoolean
    val possionRate = params.getOrElse("possion", "1.0f").toFloat
    val bits = params.getOrElse("bits", "20").toInt
    val numPartitions = params.getOrElse("numPartitions", "100").toInt

    // Pick the feature parser once; an unknown dataType is rejected with a clear
    // message instead of a bare MatchError.
    val parseFeatures: String => LabeledData = dataType match {
      case "libsvm" => (line: String) => DataLoader.parseLongFloat(line, dim)
      case "dummy" => (line: String) => DataLoader.parseLongDummy(line, dim)
      case other => unsupportedDataType(other)
    }

    val opt = new FTRL(lambda1, lambda2, alpha, beta)
    val rowType = RowType.T_FLOAT_SPARSE_LONGKEY

    val conf = new SparkConf()
    if (modelPath.length > 0)
      conf.set(AngelConf.ANGEL_LOAD_MODEL_PATH, modelPath + "/back")
    val sc = new SparkContext(conf)
    PSContext.getOrCreate(sc)

    // Parse the features, then overwrite the label with the one from parseLabel.
    val data = sc.textFile(input).map { line =>
      val point = parseFeatures(line)
      point.setY(DataLoader.parseLabel(line, false))
      point
    }

    data.persist(StorageLevel.DISK_ONLY)
    val size = data.count()

    val (min, max) = indexRange(data)

    println(s"num examples = ${size} min_index=$min max_index=$max")

    // The model covers the index range [min, max + 1).
    if (withBalancePartition)
      opt.init(min, max + 1, rowType, data.map(f => f.getX),
        new LoadBalancePartitioner(bits, numPartitions))
    else
      opt.init(min, max + 1, -1, rowType, new ColumnRangePartitioner())

    opt.setPossionRate(possionRate)

    if (modelPath.length > 0) {
      opt.load(modelPath + "/back")
    }

    for (epoch <- 1 to numEpoch) {
      // One optimize call per mini-batch; per-partition losses are summed on the driver.
      val totalLoss = data.mapPartitions { iterator =>
        val loss = iterator
          .grouped(batchSize)
          .map((batch: Seq[LabeledData]) => opt.optimize(batch.toArray)).sum
        Iterator.single(loss)
      }.sum()

      // AUC is evaluated on a fixed-seed 1% sample of the training data.
      val scores = data.sample(withReplacement = false, fraction = 0.01, seed = 42).mapPartitions {
        iterator =>
          iterator.grouped(batchSize)
            .flatMap(f => opt.predict(f.toArray))
      }
      val auc = new AUC().calculate(scores)

      println(s"epoch=$epoch loss=${totalLoss / size} auc=$auc")
    }

    if (output.length > 0) {
      println(s"saving model to path $output")
      // NOTE(review): the result is discarded; presumably this is called for a side
      // effect that prepares the weights before saving. Confirm against FTRL.weight.
      opt.weight
      opt.saveWeight(output + "/weight")
      opt.save(output + "/back")
      println(s"saving z n and w finish")
    }
  }

  /**
   * Scores the input with a saved FTRL model and writes the scores as text.
   *
   * Recognised keys (defaults in parentheses): `dim` (149), `input`
   * (data/census/census_148d_train.libsvm), `dataType` (libsvm; `dummy` is the
   * other accepted value), `isTraining` (false), `hasLabel` (true), `modelPath`
   * (empty), `predict` (required output path).
   *
   * When `modelPath` is non-empty its `/weight` sub-directory is configured as
   * the load path and loaded after initialisation. Any existing content at the
   * `predict` path is deleted recursively before the scores are written.
   *
   * @param params parsed command-line parameters
   * @throws IllegalArgumentException if `predict` is missing/empty or `dataType`
   *                                  is neither `libsvm` nor `dummy`
   */
  def predict(params: Map[String, String]): Unit = {

    val dim = params.getOrElse("dim", "149").toLong
    val input = params.getOrElse("input", "data/census/census_148d_train.libsvm")
    val dataType = params.getOrElse("dataType", "libsvm")
    val isTraining = params.getOrElse("isTraining", "false").toBoolean
    val hasLabel = params.getOrElse("hasLabel", "true").toBoolean
    val modelPath = params.getOrElse("modelPath", "")
    val predictPath = params.getOrElse("predict", "")

    // Fail before any cluster work: an empty path cannot be turned into a Hadoop Path.
    require(predictPath.nonEmpty,
      "missing required parameter 'predict' (output path for the scores)")

    val parseLine: String => LabeledData = dataType match {
      case "libsvm" =>
        (line: String) => DataLoader.parseLongFloat(line, dim, isTraining, hasLabel)
      case "dummy" =>
        (line: String) => DataLoader.parseLongDummy(line, dim, isTraining, hasLabel)
      case other => unsupportedDataType(other)
    }

    val opt = new FTRL()
    val conf = new SparkConf()
    // Only configure a load path when a model directory was actually given,
    // mirroring the guard used in train and the guarded opt.load below.
    if (modelPath.nonEmpty)
      conf.set(AngelConf.ANGEL_LOAD_MODEL_PATH, modelPath + "/weight")

    val sc = new SparkContext(conf)
    PSContext.getOrCreate(sc)

    val data = sc.textFile(input).map(parseLine)

    // The model covers the index range [min, max + 1).
    val (min, max) = indexRange(data)
    opt.init(min, max + 1, -1, RowType.T_FLOAT_SPARSE_LONGKEY, new ColumnRangePartitioner())

    if (modelPath.nonEmpty) {
      opt.load(modelPath + "/weight")
    }

    // Each partition is materialised and scored in a single predict call.
    val scores = data.mapPartitions { iterator =>
      opt.predict(iterator.toArray, false).iterator
    }

    // saveAsTextFile refuses to overwrite, so clear any previous output first.
    val path = new Path(predictPath)
    val fs = path.getFileSystem(sc.hadoopConfiguration)
    if (fs.exists(path)) {
      fs.delete(path, true)
    }

    scores.saveAsTextFile(predictPath)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy