com.microsoft.ml.spark.train.TrainClassifier.scala

// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.train

import java.util.UUID

import com.microsoft.ml.spark.core.utils.CastUtilities._
import com.microsoft.ml.spark.core.env.InternalWrapper
import com.microsoft.ml.spark.core.schema.{CategoricalUtilities, SchemaConstants, SparkSchema}
import com.microsoft.ml.spark.core.serialize.ConstructorReadable
import com.microsoft.ml.spark.featurize.{Featurize, FeaturizeUtilities, ValueIndexer, ValueIndexerModel}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.classification._
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml._
import org.apache.spark.sql._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

import scala.reflect.runtime.universe.{TypeTag, typeTag}

/** Trains a classification model.  Featurizes the given data into a vector of doubles.
  *
  * Note the behavior of the reindex and labels parameters; they interact as follows:
  *
  *   reindex -> false
  *   labels -> false (Empty)
  * Assume all values are doubles, do not use metadata, and assume natural ordering.
  *
  *   reindex -> true
  *   labels -> false (Empty)
  * Index the column, using the natural ordering of the string indexer.
  *
  *   reindex -> false
  *   labels -> true (Specified)
  * Assume the user knows the indexing and apply the given label values. Currently only the string type is supported.
  *
  *   reindex -> true
  *   labels -> true (Specified)
  * Validate that the labels match the column type, try to recast them to the label type, and reindex the label column.
  *
  * The currently supported classifiers are:
  * Logistic Regression Classifier
  * Decision Tree Classifier
  * Random Forest Classifier
  * Gradient Boosted Trees Classifier
  * Naive Bayes Classifier
  * Multilayer Perceptron Classifier
  * as well as any generic learner that inherits from Predictor.
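  *
  * A minimal usage sketch (illustrative only, not part of the original source; the DataFrame
  * `trainData` and the "income" label column are assumptions):
  * {{{
  *   val trained = new TrainClassifier()
  *     .setModel(new LogisticRegression())
  *     .setLabelCol("income")
  *     .fit(trainData)
  * }}}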
  */
@InternalWrapper
class TrainClassifier(override val uid: String) extends AutoTrainer[TrainedClassifierModel] {

  def this() = this(Identifiable.randomUID("TrainClassifier"))

  /** Doc for model to run.
    */
  override def modelDoc: String = "Classifier to run"

  /** Specifies whether to reindex the given label column.
    * See class documentation for how this parameter interacts with specified labels.
    * @group param
    */
  val reindexLabel = new BooleanParam(this, "reindexLabel", "Re-index the label column")
  setDefault(reindexLabel -> true)
  /** @group getParam */
  def getReindexLabel: Boolean = $(reindexLabel)
  /** @group setParam */
  def setReindexLabel(value: Boolean): this.type = set(reindexLabel, value)

  /** Specifies the labels metadata on the column.
    * See class documentation for how this parameter interacts with reindex labels parameter.
    * @group param
    */
  val labels = new StringArrayParam(this, "labels", "Sorted label values on the labels column")
  /** @group getParam */
  def getLabels: Array[String] = $(labels)
  /** @group setParam */
  def setLabels(value: Array[String]): this.type = set(labels, value)

  /** Optional parameter that specifies the name of the features column passed to the learner.
    * Must have a name distinct from the input columns.
    * By default, set to this estimator's UID followed by "_features".
    * @group param
    */
  setDefault(featuresCol, this.uid + "_features")

  /** Fits the classification model.
    *
    * @param dataset The input dataset to train.
    * @return The trained classification model.
    */
  override def fit(dataset: Dataset[_]): TrainedClassifierModel = {
    val labelValues =
      if (isDefined(labels)) {
        Some(getLabels)
      } else {
        None
      }
    // Convert label column to categorical on train, remove rows with missing labels
    val (convertedLabelDataset, levels) = convertLabel(dataset, getLabelCol, labelValues)

    val (oneHotEncodeCategoricals, modifyInputLayer, numFeatures) = getFeaturizeParams

    var classifier: Estimator[_ <: PipelineStage] = getModel match {
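      // Multiclass label columns are handled by wrapping the logistic regression learner in
      // OneVsRest; binary label columns use the logistic regression learner directly.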
      case logisticRegressionClassifier: LogisticRegression => {
        if (levels.isDefined && levels.get.length > 2) {
          new OneVsRest()
            .setClassifier(
              logisticRegressionClassifier
                .setLabelCol(getLabelCol)
                .setFeaturesCol(getFeaturesCol))
            .setLabelCol(getLabelCol)
            .setFeaturesCol(getFeaturesCol)
        } else {
          logisticRegressionClassifier
        }
      }
      case gradientBoostedTreesClassifier: GBTClassifier => {
        if (levels.isDefined && levels.get.length > 2) {
          throw new Exception("Multiclass Gradient Boosted Tree Classifier not supported yet")
        } else {
          gradientBoostedTreesClassifier
        }
      }
      case default @ defaultType if defaultType.isInstanceOf[Estimator[_ <: PipelineStage]] => {
        default
      }
      case _ => throw new Exception("Unsupported learner type " + getModel.getClass.toString)
    }

    classifier = classifier match {
      case predictor: Predictor[_, _, _] => {
        predictor
          .setLabelCol(getLabelCol)
          .setFeaturesCol(getFeaturesCol).asInstanceOf[Estimator[_ <: PipelineStage]]
      }
      case default @ defaultType if defaultType.isInstanceOf[Estimator[_ <: PipelineStage]] => {
        // assume label col and features col already set
        default
      }
    }

    val featuresToHashTo =
      if (getNumFeatures != 0) {
        getNumFeatures
      } else {
        numFeatures
      }

    val featureColumns = convertedLabelDataset.columns.filter(col => col != getLabelCol).toSeq

    val featurizer = new Featurize()
      .setFeatureColumns(Map(getFeaturesCol -> featureColumns))
      .setOneHotEncodeCategoricals(oneHotEncodeCategoricals)
      .setNumberOfFeatures(featuresToHashTo)
    val featurizedModel = featurizer.fit(convertedLabelDataset)
    val processedData = featurizedModel.transform(convertedLabelDataset)

    processedData.cache()

    // For the multilayer perceptron, resize the input layer to the featurized vector size so training works automatically
    if (modifyInputLayer) {
      val multilayerPerceptronClassifier = classifier.asInstanceOf[MultilayerPerceptronClassifier]
      val row = processedData.take(1)(0)
      val featuresVector = row.get(row.fieldIndex(getFeaturesCol))
      val vectorSize = featuresVector.asInstanceOf[linalg.Vector].size
      multilayerPerceptronClassifier.getLayers(0) = vectorSize
      multilayerPerceptronClassifier.setLayers(multilayerPerceptronClassifier.getLayers)
    }

    // Train the learner
    val fitModel = classifier.fit(processedData)

    processedData.unpersist()

    // Note: both stages are already fitted transformers, so this fit only assembles them into a PipelineModel
    val pipelineModel = new Pipeline().setStages(Array(featurizedModel, fitModel)).fit(convertedLabelDataset)
    new TrainedClassifierModel(uid, getLabelCol, pipelineModel, levels, getFeaturesCol)
  }
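
  // Illustrative note (not part of the original source): the TrainedClassifierModel returned by
  // fit wraps the two fitted stages assembled above, so the underlying learner can be inspected
  // with, e.g., `val Array(featurizeStage, learnerStage) = trainedModel.model.stages`, where
  // `trainedModel` is the value returned by fit.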

  def getFeaturizeParams: (Boolean, Boolean, Int) = {
    var oneHotEncodeCategoricals = true
    var modifyInputLayer = false
    // Create trainer based on the pipeline stage and set the parameters
    val numFeatures: Int = getModel match {
      case _: DecisionTreeClassifier | _: GBTClassifier | _: RandomForestClassifier =>
        oneHotEncodeCategoricals = false
        FeaturizeUtilities.NumFeaturesTreeOrNNBased
      case _: MultilayerPerceptronClassifier =>
        modifyInputLayer = true
        FeaturizeUtilities.NumFeaturesTreeOrNNBased
      case _ =>
        FeaturizeUtilities.NumFeaturesDefault
    }
    (oneHotEncodeCategoricals, modifyInputLayer, numFeatures)
  }
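
  // Illustrative sketch (assumed usage, not in the original source): tree-based learners skip
  // one-hot encoding and use the tree/NN feature budget, e.g.
  //   new TrainClassifier().setModel(new RandomForestClassifier()).getFeaturizeParams
  //   // returns (false, false, FeaturizeUtilities.NumFeaturesTreeOrNNBased)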

  def convertLabel(dataset: Dataset[_],
                   labelColumn: String,
                   labelValues: Option[Array[_]]): (DataFrame, Option[Array[_]]) = {
    var levels: Option[Array[_]] = None
    if (getReindexLabel) {

      val dataframe = dataset.toDF().na.drop(Seq(labelColumn))

      if (labelValues.isDefined) {
        if (SparkSchema.isCategorical(dataframe, labelColumn)) {
          throw new Exception("Column is already categorical, cannot set label values")
        }
        // Reindex is true, and labels are given, set levels, make column categorical given the levels
        val labelDataType = dataset.schema(labelColumn).dataType
        // Cast the labels to the given data type to validate that they match the column type
        labelValues.get.map(value => value.toDataType(labelDataType))
        levels = labelValues
        // Reindex the column to be categorical with given metadata
        val reindexedDF = new ValueIndexerModel()
          .setLevels(levels.get)
          .setDataType(labelDataType)
          .setInputCol(labelColumn)
          .setOutputCol(labelColumn)
          .transform(dataframe)
        (reindexedDF, levels)
      } else {
        if (!SparkSchema.isCategorical(dataframe, labelColumn)) {
          val model = new ValueIndexer().setInputCol(labelColumn).setOutputCol(labelColumn).fit(dataframe)
          val categoricalLabelDataset = model.transform(dataframe)
          levels = CategoricalUtilities.getLevels(categoricalLabelDataset.schema, labelColumn)
          (categoricalLabelDataset.withColumn(labelColumn,
            categoricalLabelDataset(labelColumn).cast(DoubleType).as(labelColumn,
              categoricalLabelDataset.schema(labelColumn).metadata)), levels)
        } else {
          levels = CategoricalUtilities.getLevels(dataframe.schema, labelColumn)
          (dataframe, levels)
        }
      }
    } else {
      if (labelValues.isDefined) {
        // Reindex is false, set label metadata (only strings supported, since we cannot infer types) on
        // column directly
        levels = labelValues
      }
      (dataset.na.drop(Seq(labelColumn)), levels)
    }
  }
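
  // Illustrative sketch (assumed data, not in the original source): with reindexLabel = true and
  // no explicit labels, e.g. `convertLabel(df, "label", None)`, a non-categorical label column is
  // value-indexed, cast to DoubleType with its categorical metadata preserved, and the recovered
  // label values are returned as the second element of the tuple.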

  override def copy(extra: ParamMap): Estimator[TrainedClassifierModel] = {
    setModel(getModel.copy(extra))
    defaultCopy(extra)
  }

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    val hasScoreCols =
      $(model) match {
        case _: GBTClassifier => false
        case _: MultilayerPerceptronClassifier => false
        case _ => true
      }
    TrainClassifier.validateTransformSchema(hasScoreCols, schema)
  }
}

object TrainClassifier extends ComplexParamsReadable[TrainClassifier] {
  def validateTransformSchema(hasScoreCols: Boolean, schema: StructType): StructType = {
    val scoresSchema =
      if (hasScoreCols) {
        StructType(schema.fields :+ StructField(SchemaConstants.ScoresColumn, DoubleType))
      } else schema
    val probSchema =
      if (hasScoreCols) {
        StructType(scoresSchema.fields :+ StructField(SchemaConstants.ScoredProbabilitiesColumn, DoubleType))
      } else scoresSchema
    StructType(probSchema.fields :+ StructField(SchemaConstants.ScoredLabelsColumn, DoubleType))
  }
}

/** Model produced by [[TrainClassifier]]. */
@InternalWrapper
class TrainedClassifierModel(val uid: String,
                             val labelColumn: String,
                             override val model: PipelineModel,
                             val levels: Option[Array[_]],
                             val featuresColumn: String)
    extends AutoTrainedModel[TrainedClassifierModel](model) {

  val ttag: TypeTag[TrainedClassifierModel] = typeTag[TrainedClassifierModel]
  val objectsToSave: List[AnyRef] = List(uid, labelColumn, model, levels, featuresColumn)

  override def copy(extra: ParamMap): TrainedClassifierModel =
    new TrainedClassifierModel(uid,
      labelColumn,
      model.copy(extra),
      levels,
      featuresColumn)

  override def transform(dataset: Dataset[_]): DataFrame = {
    val hasScoreCols = hasScoreColumns(model.stages.last)

    // re-featurize and score the data
    val scoredData = model.transform(dataset)

    // Drop the vectorized features column
    val cleanedScoredData = scoredData.drop(featuresColumn)

    // Update the schema - TODO: create method that would generate GUID and add it to the scored model
    val moduleName = SchemaConstants.ScoreModelPrefix + UUID.randomUUID().toString
    val labelColumnExists = cleanedScoredData.columns.contains(labelColumn)
    val schematizedScoredDataWithLabel =
      if (labelColumnExists) {
        SparkSchema.setLabelColumnName(cleanedScoredData, moduleName, labelColumn, SchemaConstants.ClassificationKind)
      } else {
        cleanedScoredData
      }

    // Note: The GBT model does not have scores, only scored labels.  Same for OneVsRest with any binary model.
    val schematizedScoredDataWithScores =
      if (hasScoreCols) {
        setMetadataForColumnName(SparkSchema.setScoredProbabilitiesColumnName,
          SchemaConstants.SparkProbabilityColumn,
          SchemaConstants.ScoredProbabilitiesColumn,
          moduleName,
          setMetadataForColumnName(SparkSchema.setScoresColumnName,
            SchemaConstants.SparkRawPredictionColumn,
            SchemaConstants.ScoresColumn,
            moduleName,
            schematizedScoredDataWithLabel))
      } else schematizedScoredDataWithLabel

    val scoredDataWithUpdatedScoredLabels =
      setMetadataForColumnName(SparkSchema.setScoredLabelsColumnName,
        SchemaConstants.SparkPredictionColumn,
        SchemaConstants.ScoredLabelsColumn,
        moduleName,
        schematizedScoredDataWithScores)

    val scoredDataWithUpdatedScoredLevels =
      if (levels.isEmpty) scoredDataWithUpdatedScoredLabels
      else CategoricalUtilities.setLevels(scoredDataWithUpdatedScoredLabels,
        SchemaConstants.ScoredLabelsColumn,
        levels.get)

    // add metadata to the scored labels and true labels for the levels in label column
    if (levels.isEmpty || !labelColumnExists) scoredDataWithUpdatedScoredLevels
    else CategoricalUtilities.setLevels(scoredDataWithUpdatedScoredLevels,
      labelColumn,
      levels.get)
  }

  private def setMetadataForColumnName(setter: (DataFrame, String, String, String) => DataFrame,
                                       sparkColumnName: String,
                                       mmlColumnName: String,
                                       moduleName: String,
                                       dataset: DataFrame): DataFrame = {
    if (dataset.columns.contains(sparkColumnName)) {
      setter(dataset.withColumnRenamed(sparkColumnName, mmlColumnName),
        moduleName,
        mmlColumnName,
        SchemaConstants.ClassificationKind)
    } else {
      dataset
    }
  }

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType =
    TrainClassifier.validateTransformSchema(hasScoreColumns(model.stages.last), schema)

  def hasScoreColumns(model: Transformer): Boolean = {
    model match {
      case _: GBTClassificationModel => false
      case _: MultilayerPerceptronClassificationModel => false
      case _ => true
    }
  }
}
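
// Illustrative scoring sketch (not part of the original source): `trained` is a
// TrainedClassifierModel returned by TrainClassifier.fit and `testData` is an assumed DataFrame
// with the same feature columns used for training. Models without probability outputs (GBT,
// multilayer perceptron) omit the scores and scored-probabilities columns.
//
//   val scored = trained.transform(testData)
//   scored.select(SchemaConstants.ScoredLabelsColumn).show()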

object TrainedClassifierModel extends ConstructorReadable[TrainedClassifierModel]



