All downloads are free. Search and download functionality uses the official Maven repository.

com.microsoft.ml.spark.cntk.CNTKModel.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.cntk

import com.microsoft.CNTK.CNTKExtensions._
import com.microsoft.CNTK.CNTKUtils._
import com.microsoft.CNTK.{CNTKExtensions, DataType => CNTKDataType, SerializableFunction => CNTKFunction, _}
import com.microsoft.ml.spark.cntk.ConversionUtils.GVV
import com.microsoft.ml.spark.core.contracts.Wrappable
import com.microsoft.ml.spark.core.env.InternalWrapper
import com.microsoft.ml.spark.core.schema.DatasetExtensions.findUnusedColumnName
import com.microsoft.ml.spark.stages.{FixedMiniBatchTransformer, FlattenBatch, HasMiniBatcher}
import org.apache.spark.SparkContext
import org.apache.spark.broadcast._
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Model}
import org.apache.spark.ml.linalg.{SQLDataTypes, Vectors, Vector => SVector}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import spray.json.DefaultJsonProtocol._

import scala.collection.JavaConversions._

/** Helpers that evaluate a CNTK function over Spark rows; invoked per partition by CNTKModel. */
private object CNTKModelUtils extends java.io.Serializable {

  /** Evaluates `model` once and copies the requested outputs back into JVM-side buffers.
    *
    * The native CNTK `Value`s created for the inputs, and those fetched for the outputs, are
    * deleted before returning, including when evaluation or copying throws.
    *
    * @param model      the CNTK function to evaluate
    * @param feedDict   input variables and the float (`Left`) or double (`Right`) data fed to them
    * @param outputVars output variables to fetch; results are returned in this order
    * @param device     device on which input values are created and the model is evaluated
    * @return one buffer per element of `outputVars`: `Left` for float outputs, `Right` for double
    */
  def applyCNTKFunction(model: CNTKFunction,
                        feedDict: Map[Variable, GVV],
                        outputVars: List[Variable],
                        device: DeviceDescriptor): List[GVV] = {

    val valueMap = feedDict.map { case (v, gvv) =>
      gvv match {
        case Left(fvv) => v -> Value.createDenseFloat(v.getShape, fvv, device)
        case Right(dvv) => v -> Value.createDenseDouble(v.getShape, dvv, device)
      }
    }

    try {
      val inputDataMap = new UnorderedMapVariableValuePtr()
      valueMap.foreach { case (vr, vl) => inputDataMap.add(vr, vl) }

      // Outputs start as null placeholders; `evaluate` fills them in and they are read back below.
      val outputDataMap = new UnorderedMapVariableValuePtr()
      outputVars.foreach(ov => outputDataMap.add(ov, null)) //scalastyle:ignore null
      model.evaluate(inputDataMap, outputDataMap, device)

      outputVars.map { ov: Variable =>
        val value = outputDataMap.getitem(ov)
        try {
          ov.getDataType match {
            case CNTKDataType.Float =>
              val fvv = new FloatVectorVector() //TODO try re-using
              value.copyVariableValueToFloat(ov, fvv)
              Left(fvv)
            case CNTKDataType.Double =>
              val dvv = new DoubleVectorVector() //TODO try re-using
              value.copyVariableValueToDouble(ov, dvv)
              Right(dvv)
          }
        } finally {
          // Release the native output value even if the copy fails.
          value.delete()
        }
      }
    } finally {
      // Release the native input values whether or not evaluation succeeded.
      valueMap.values.foreach(_.delete())
    }
  }

  /** Builds, for each input variable, a function that reads its column out of a row.
    *
    * The column is read as nested float or double sequences, chosen by the variable's CNTK
    * data type. Any other data type fails with a MatchError here.
    *
    * @param inputMapVar column index -> CNTK input variable read from that column
    */
  private def makeInputExtractors(inputMapVar: Map[Int, Variable]) = {
    inputMapVar.map {
      case (colnum, variable) => variable -> {
        variable.getDataType match {
          case CNTKDataType.Float =>
            r: Row => Left(r.getAs[Seq[Seq[Float]]](colnum))
          case CNTKDataType.Double =>
            r: Row => Right(r.getAs[Seq[Seq[Double]]](colnum))
        }
      }
    }
  }

  /** Evaluates the broadcast model on every row of one partition.
    *
    * @param inputMap         CNTK input variable name -> index of the row column that feeds it
    * @param broadcastedModel the CNTK function shared with executors
    * @param outputMap        output column name -> CNTK output variable name
    * @param inputRows        the rows of the partition
    * @return each input row with one converted value per output variable appended
    */
  def applyModel(inputMap: Map[String, Int],
                 broadcastedModel: Broadcast[CNTKFunction],
                 outputMap: Map[String, String])
                (inputRows: Iterator[Row]): Iterator[Row] = {

    if (!inputRows.hasNext) {
      Iterator() // Quickly skip empty partitions
    } else {
      val device = DeviceDescriptor.useDefaultDevice
      //CNTKLib.SetMaxNumCPUThreads(1)
      // Clone with shared parameters so this partition gets its own function object.
      val m = CNTKExtensions.fromSerializable(broadcastedModel.value).clone(ParameterCloningMethod.Share)

      val inputMapVar = inputMap.map { case (k, v) => v -> m.getInputVar(k) }
      val outputMapVar = outputMap.map { case (k, v) => m.getOutputVar(v) -> k }
      val inputExtractors = makeInputExtractors(inputMapVar)

      // One input buffer per variable, allocated once per partition and handed to toGVV for every row.
      val inputGVVs = inputMapVar.map {
        case (_, variable) => variable -> {
          variable.getDataType match {
            case CNTKDataType.Float =>
              Left(new FloatVectorVector())
            case CNTKDataType.Double =>
              Right(new DoubleVectorVector())
          }
        }
      }

      // WARNING: DO NOT simplify this to mapValues. Map.mapValues returns a lazy view that
      // re-applies the function on every access, so the conversion would run repeatedly.
      val preprocessFunction: Row => Map[Variable, GVV] = {
        { row: Row =>
          inputExtractors.map { case (k, f) =>
            k -> ConversionUtils.toGVV(f(row), inputGVVs(k))
          }
        }
      }

      // NOTE(review): appended output fields follow this map's iteration order; the caller's
      // encoder must list the output columns in the same order — confirm.
      val outputVars = outputMapVar.keys.toList

      val outputVarVector = new VariableVector()
      outputVars.foreach(outputVarVector.add)
      val of = CNTKLib.Combine(outputVarVector)

      inputRows.map { row =>
        val feedDict = preprocessFunction(row)
        val outputGVVs = applyCNTKFunction(of, feedDict, outputVars, device)
        val resultRow = Row(outputGVVs.map(ConversionUtils.convertGVV): _*)
        val outputRow = Row.merge(row, resultRow)
        outputGVVs.foreach(ConversionUtils.deleteGVV)
        outputRow
      }
    }
  }

}

/** Companion of [[CNTKModel]]; mixes in ComplexParamsReadable so saved models can be read back
  * (presumably via the usual MLReadable `read`/`load` entry points — confirm in ComplexParamsReadable).
  */
object CNTKModel extends ComplexParamsReadable[CNTKModel] {}

/** Spark ML model that evaluates a serialized CNTK function over a DataFrame.
  *
  * `feedDict` maps CNTK input variable names to input columns. `fetchDict` maps output column
  * names to CNTK output variable names. When `batchInput` is true, rows are grouped by the
  * mini-batcher before evaluation and flattened back afterwards.
  */
@InternalWrapper
class CNTKModel(override val uid: String) extends Model[CNTKModel] with ComplexParamsWritable
  with HasMiniBatcher with Wrappable {

  def this() = this(Identifiable.randomUID("CNTKModel"))

  /** Array of bytes containing the serialized CNTK Function
    *
    * @group param
    */
  val model: CNTKFunctionParam =
    new CNTKFunctionParam(this, "model", "Array of bytes containing the serialized CNTKModel")

  // Broadcast of the current model: cleared by setModel, created on demand in transform.
  private var broadcastedModelOption: Option[Broadcast[CNTKFunction]] = None

  /** Sets the CNTK function to evaluate and unpersists any broadcast of the previous one.
    *
    * @group setParam
    */
  def setModel(value: CNTKFunction): this.type = {
    // Free up memory used by the previous model
    // TODO: investigate using destroy()
    broadcastedModelOption.foreach(_.unpersist())
    broadcastedModelOption = None
    set(model, value)
  }

  /** @group getParam */
  def getModel: CNTKFunction = $(model)

  /** Reads the first file found at `path` via `SparkContext.binaryFiles` and sets it as the model.
    *
    * @group setParam
    */
  def setModelLocation(path: String): this.type = {
    val modelBytes = SparkContext.getOrCreate().binaryFiles(path).first()._2.toArray
    setModel(CNTKFunction.loadModelFromBytes(modelBytes))
  }

  /** Whether `transform` batches rows with the mini-batcher before evaluation (default true). */
  val batchInput = new BooleanParam(this, "batchInput",
    "whether to use a batcher")

  setDefault(batchInput -> true)

  def setBatchInput(v: Boolean): this.type = set(batchInput, v)

  def getBatchInput: Boolean = $(batchInput)

  // NOTE(review): shapeOutput is not read anywhere in this class — confirm it is used elsewhere.
  val shapeOutput = new BooleanParam(this, "shapeOutput",
    "whether to shape the output")

  setDefault(shapeOutput -> false)

  def setShapeOutput(v: Boolean): this.type = set(shapeOutput, v)

  def getShapeOutput: Boolean = $(shapeOutput)

  /** Whether float/double array outputs are converted to dense ML vectors (default true). */
  val convertOutputToDenseVector = new BooleanParam(this, "convertOutputToDenseVector",
    "whether to convert the output to dense vectors")

  setDefault(convertOutputToDenseVector -> true)

  def setConvertOutputToDenseVector(v: Boolean): this.type = set(convertOutputToDenseVector, v)

  def getConvertOutputToDenseVector: Boolean = $(convertOutputToDenseVector)

  /** CNTK input variable name -> input column name. */
  val feedDict: MapParam[String, String] = new MapParam[String, String](this, "feedDict",
    " Map of CNTK Variable names (keys) and Column Names (values)")

  setDefault(feedDict -> Map((ArgumentPrefix + 0) -> (ArgumentPrefix + 0)))

  def setFeedDict(value: Map[String, String]): this.type = set(feedDict, value)

  /** Replaces the whole feed dict with the single mapping `k -> v`. */
  def setFeedDict(k: String, v: String): this.type = set(feedDict, Map(k -> v))

  def getFeedDict: Map[String, String] = $(feedDict)

  /** Output column name -> CNTK output variable name. */
  val fetchDict: MapParam[String, String] = new MapParam[String, String](this, "fetchDict",
    " Map of Column Names (keys) and CNTK Variable names (values)")
  setDefault(fetchDict -> Map((OutputPrefix + 0) -> (OutputPrefix + 0)))

  def setFetchDict(value: Map[String, String]): this.type = set(fetchDict, value)

  /** Replaces the whole fetch dict with the single mapping `k -> v`. */
  def setFetchDict(k: String, v: String): this.type = set(fetchDict, Map(k -> v))

  def getFetchDict: Map[String, String] = $(fetchDict)

  // Alternative Input APIs: convenience accessors for a single-entry feed dict.

  /** Feeds input variable `ArgumentPrefix + value`.
    *
    * Keeps the current input column, or uses the variable name as the column when the feed
    * dict is empty.
    */
  def setInputNodeIndex(value: Int): this.type = {
    val fd = getFeedDict
    if (fd.isEmpty) {
      setFeedDict(ArgumentPrefix + value, ArgumentPrefix + value)
    } else if (fd.size == 1) {
      setFeedDict(ArgumentPrefix + value, fd.values.head)
    } else {
      throw new IllegalArgumentException("existing feed dict has too many elements," +
        " consider using the more expressive feedDict param directly")
    }
  }

  /** Index parsed from the single input variable name, which must start with `ArgumentPrefix`. */
  def getInputNodeIndex: Int = {
    val fd = getFeedDict
    if (fd.size == 1) {
      fd.keys.head match {
        case node if node.startsWith(ArgumentPrefix) =>
          node.stripPrefix(ArgumentPrefix).toInt
        case _ => throw new RuntimeException("Feed dict did not have the proper structure")
      }
    } else {
      throw new IllegalArgumentException("existing feed dict has too many elements," +
        " consider using the more expressive feedDict param directly")
    }
  }

  /** Feeds input variable `value`, keeping the current input column if there is exactly one. */
  def setInputNode(value: String): this.type = {
    val fd = getFeedDict
    if (fd.isEmpty) {
      setFeedDict(value, value)
    } else if (fd.size == 1) {
      setFeedDict(value, fd.values.head)
    } else {
      throw new IllegalArgumentException("existing feed dict has too many elements," +
        " consider using the more expressive feedDict param directly")
    }
  }

  /** Name of the single input variable. */
  def getInputNode: String = {
    val fd = getFeedDict
    if (fd.size == 1) {
      fd.keys.head
    } else {
      throw new IllegalArgumentException("existing feed dict has too many elements," +
        " consider using the more expressive feedDict param directly")
    }
  }

  /** Reads input from column `value`, keeping the current input variable if there is exactly one. */
  def setInputCol(value: String): this.type = {
    val fd = getFeedDict
    if (fd.isEmpty) {
      setFeedDict(value, value)
    } else if (fd.size == 1) {
      setFeedDict(fd.keys.head, value)
    } else {
      throw new IllegalArgumentException("existing feed dict has too many elements," +
        " consider using the more expressive feedDict param directly")
    }
  }

  /** Name of the single input column. */
  def getInputCol: String = {
    val fd = getFeedDict
    if (fd.size == 1) {
      fd.values.head
    } else {
      throw new IllegalArgumentException("existing feed dict has too many elements," +
        " consider using the more expressive feedDict param directly")
    }
  }

  // Alternative Output APIs: convenience accessors for a single-entry fetch dict.

  /** Fetches output variable `OutputPrefix + value`, keeping the current output column. */
  def setOutputNodeIndex(value: Int): this.type = {
    val fd = getFetchDict
    if (fd.isEmpty) {
      setFetchDict(OutputPrefix + value, OutputPrefix + value)
    } else if (fd.size == 1) {
      setFetchDict(fd.keys.head, OutputPrefix + value)
    } else {
      throw new IllegalArgumentException("existing fetch dict has too many elements," +
        " consider using the more expressive fetchDict param directly")
    }
  }

  /** Index parsed from the single output variable name, which must start with `OutputPrefix`. */
  def getOutputNodeIndex: Int = {
    val fd = getFetchDict
    if (fd.size == 1) {
      fd.values.head match {
        case node if node.startsWith(OutputPrefix) =>
          node.stripPrefix(OutputPrefix).toInt
        case _ => throw new RuntimeException("Fetch dict did not have the proper structure")
      }
    } else {
      throw new IllegalArgumentException("existing fetch dict has too many elements," +
        " consider using the more expressive fetchDict param directly")
    }
  }

  /** Fetches output variable `value`, keeping the current output column if there is exactly one. */
  def setOutputNode(value: String): this.type = {
    val fd = getFetchDict
    if (fd.isEmpty) {
      setFetchDict(value, value)
    } else if (fd.size == 1) {
      setFetchDict(fd.keys.head, value)
    } else {
      throw new IllegalArgumentException("existing fetch dict has too many elements," +
        " consider using the more expressive fetchDict param directly")
    }
  }

  /** Name of the single output variable. */
  def getOutputNode: String = {
    val fd = getFetchDict
    if (fd.size == 1) {
      fd.values.head
    } else {
      throw new IllegalArgumentException("existing fetch dict has too many elements," +
        " consider using the more expressive fetchDict param directly")
    }
  }

  /** Writes output to column `value`, keeping the current output variable if there is exactly one. */
  def setOutputCol(value: String): this.type = {
    val fd = getFetchDict
    if (fd.isEmpty) {
      setFetchDict(value, value)
    } else if (fd.size == 1) {
      setFetchDict(value, fd.values.head)
    } else {
      throw new IllegalArgumentException("existing fetch dict has too many elements," +
        " consider using the more expressive fetchDict param directly")
    }
  }

  /** Name of the single output column. */
  def getOutputCol: String = {
    val fd = getFetchDict
    if (fd.size == 1) {
      fd.keys.head
    } else {
      throw new IllegalArgumentException("existing fetch dict has too many elements," +
        " consider using the more expressive fetchDict param directly")
    }
  }

  /** Returns the dimensions of the required input */
  def getInputShapes: List[Array[Int]] = {
    getModel.getArguments.toList.map(_.getShape.getDimensions.map(_.toInt))
  }

  setDefault(miniBatcher -> new FixedMiniBatchTransformer().setBatchSize(10)) //scalastyle:ignore magic.number

  // Innermost non-array element type of a (possibly nested) array type.
  private def getElementType(t: DataType): DataType = {
    t match {
      case ArrayType(et, _) => getElementType(et)
      case et => et
    }
  }

  /** Checks the input columns named in `feedDict` and appends the output columns.
    *
    * @throws AssertionError if an input column's type is not one of the supported types
    */
  def transformSchema(schema: StructType): StructType = {
    // Supported per-row input types. Without batchInput the column must already be batched,
    // i.e. wrapped in one extra array level.
    val innerTypes = Set(VectorType,
      ArrayType(DoubleType, true),
      ArrayType(FloatType, true),
      ArrayType(DoubleType, false),
      ArrayType(FloatType, false))
    val allowedTypes = if (getBatchInput) {
      innerTypes
    } else {
      innerTypes.map(ArrayType(_))
    }

    getFeedDict.foreach { case (_, colName) =>
      val colType = schema(colName).dataType
      assert(allowedTypes.contains(colType),
        s"Column $colName has unsupported type $colType; expected one of ${allowedTypes.mkString(", ")}")
    }

    if (getConvertOutputToDenseVector) {
      getFetchDict.toList.sorted
        .foldLeft(schema) { case (st, (_, colname)) => st.add(colname, VectorType) }
    } else {
      getModel.getOutputSchema(getFetchDict)
        .foldLeft(schema) { case (st, sf) => st.add(sf) }
    }

  }

  override def copy(extra: ParamMap): this.type = defaultCopy(extra)

  /** Broadcasts the current model, replacing the reference to any previous broadcast. */
  def rebroadcastCNTKModel(spark: SparkSession): Unit = {
    broadcastedModelOption = Some(spark.sparkContext.broadcast(getModel))
  }

  // Prefix for temporary columns holding type-coerced inputs; they are dropped after evaluation.
  private val coercionPrefix = s"coerced_$uid"

  /** Builds a UDF that converts a batched input column to nested arrays of `targetElementType`.
    *
    * The column must be an array of vectors or of float/double arrays.
    *
    * @return the UDF and an unused column name for its result, or None if no conversion is needed
    */
  private def coerceType(schema: StructType, colname: String, targetElementType: DataType):
  (Option[(UserDefinedFunction, String)]) = {
    val colType = schema(colname).dataType match {
      case ArrayType(dt, _) => dt
    }

    val funcOpt = (colType, targetElementType) match {
      case (VectorType, DoubleType) =>
        Some({ av: Seq[SVector] => av.map(_.toArray) })
      case (VectorType, FloatType) =>
        Some({ av: Seq[SVector] => av.map(_.toArray.map(_.toFloat)) })
      case (ArrayType(FloatType, _), DoubleType) =>
        Some({ av: Seq[Seq[Float]] => av.map(_.map(_.toDouble)) })
      case (ArrayType(DoubleType, _), FloatType) =>
        Some({ av: Seq[Seq[Double]] => av.map(_.map(_.toFloat)) })
      case (ArrayType(DoubleType, _), DoubleType) =>
        None
      case (ArrayType(FloatType, _), FloatType) =>
        None
    }

    funcOpt.map { f =>
      (udf(f, ArrayType(ArrayType(targetElementType))),
        findUnusedColumnName(coercionPrefix, schema))
    }
  }

  /** Adds coerced copies of input columns whose element type differs from the CNTK variable's.
    *
    * @return the extended DataFrame and the feed dict rewritten to point at the coerced columns
    */
  private def coerceDFAndFeedDict(df: DataFrame,
                                  feedDict: Map[String, String]
                                 ): (DataFrame, Map[String, String]) = {
    feedDict.foldLeft((df, feedDict)) {
      case ((dfInternal, fdInternal), (varname, colname)) =>
        val elementType = variableToElementType(getModel.getInputVar(varname))
        coerceType(dfInternal.schema, colname, elementType) match {
          case Some((udfVal, newColName)) =>
            (
              dfInternal.withColumn(newColName, udfVal(col(colname))),
              fdInternal.updated(varname, newColName)
            )
          case None =>
            // Keep the accumulated feed dict so rewrites made for earlier variables are not lost.
            (dfInternal, fdInternal)
        }
    }
  }

  /** Converts float/double array output columns to dense vectors when convertOutputToDenseVector is set. */
  private def coerceOutputDF(unbatchedDF: DataFrame): DataFrame = {
    val floatToDV = udf({ v: Seq[Float] => Vectors.dense(v.map(_.toDouble).toArray) }, SQLDataTypes.VectorType)
    val doubleToDV = udf({ v: Seq[Double] => Vectors.dense(v.toArray) }, SQLDataTypes.VectorType)

    if (getConvertOutputToDenseVector) {
      val outputSchema = getModel.getOutputSchema(getFetchDict)
      val outputColumnNames = outputSchema.map(_.name).toSet
      val colsToSelect = unbatchedDF.schema.map {
        case sf if outputColumnNames(sf.name) =>
          sf match {
            case StructField(name, ArrayType(FloatType, _), _, _) =>
              floatToDV(col(name)).alias(name)
            case StructField(name, ArrayType(DoubleType, _), _, _) =>
              doubleToDV(col(name)).alias(name)
          }
        case sf => col(sf.name)
      }
      unbatchedDF.select(colsToSelect: _*)
    } else {
      unbatchedDF
    }
  }

  /** Evaluate the model
    *
    * @param dataset the dataset to featurize
    * @return featurized dataset
    */
  def transform(dataset: Dataset[_]): DataFrame = {
    val spark = dataset.sparkSession
    val df = dataset.toDF()

    transformSchema(df.schema) // Check if the schema is correct

    val batchedDF = if (getBatchInput) {
      getMiniBatcher.transform(df)
    } else {
      df
    }

    val (preprocessedDF, coercedFeedDict) =
      coerceDFAndFeedDict(batchedDF, getFeedDict)

    // CNTK input variable name -> index of the column that feeds it
    val varToColumnIndex = coercedFeedDict.map { case (varName, colName) =>
      varName -> preprocessedDF.schema.fieldIndex(colName)
    }

    if (broadcastedModelOption.isEmpty) rebroadcastCNTKModel(spark)

    // Copy into locals so the mapPartitions closure captures only these values. Referencing the
    // members directly would capture `this` and ship the whole transformer (model included) with
    // every task, bypassing the broadcast.
    val broadcastedModel = broadcastedModelOption.get
    val fetchDict = getFetchDict

    val encoder = RowEncoder(getModel.getOutputSchema(fetchDict)
      .foldLeft(preprocessedDF.schema) { case (st, sf) => st.add(sf.name, ArrayType(sf.dataType)) }
    )

    val outputDF = preprocessedDF.mapPartitions { it =>
      CNTKModelUtils.applyModel(
        varToColumnIndex,
        broadcastedModel,
        fetchDict)(it)
    }(encoder)

    val droppedDF = outputDF.drop(outputDF.columns.filter(_.startsWith(coercionPrefix)): _*)

    val unbatchedDF = if (getBatchInput) {
      new FlattenBatch().transform(droppedDF)
    } else {
      droppedDF
    }
    coerceOutputDF(unbatchedDF)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy