// org.apache.spark.ml.feature.StandardScaler.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.ml.feature
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
/**
 * Params for [[StandardScaler]] and [[StandardScalerModel]].
 */
private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {

  /**
   * Whether to center the data with mean before scaling.
   * It will build a dense output, so this does not work on sparse input
   * and will raise an exception.
   * Default: false
   * @group param
   */
  val withMean: BooleanParam = new BooleanParam(this, "withMean",
    "Whether to center data with mean")

  /** @group getParam */
  def getWithMean: Boolean = $(withMean)

  /**
   * Whether to scale the data to unit standard deviation.
   * Default: true
   * @group param
   */
  val withStd: BooleanParam = new BooleanParam(this, "withStd",
    "Whether to scale the data to unit standard deviation")

  /** @group getParam */
  def getWithStd: Boolean = $(withStd)

  // Register the defaults documented on the params above:
  // no mean-centering, scaling to unit standard deviation enabled.
  setDefault(withMean -> false, withStd -> true)
}
/**
 * :: Experimental ::
 * Standardizes features by removing the mean and scaling to unit variance using column summary
 * statistics on the samples in the training set.
 */
@Experimental
class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel]
  with StandardScalerParams with DefaultParamsWritable {

  /** Creates an estimator with a randomly generated uid prefixed with "stdScal". */
  def this() = this(Identifiable.randomUID("stdScal"))

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  def setWithMean(value: Boolean): this.type = set(withMean, value)

  /** @group setParam */
  def setWithStd(value: Boolean): this.type = set(withStd, value)

  /**
   * Fits the scaler on the input column of `dataset`.
   *
   * The schema is validated first, then the input column is extracted as vectors and handed
   * to the `mllib` [[feature.StandardScaler]] configured with this estimator's `withMean` and
   * `withStd` values.
   *
   * @param dataset data whose input column holds the vectors to summarize
   * @return a [[StandardScalerModel]] carrying the fitted `std` and `mean`, with this
   *         estimator's param values copied over and this estimator set as its parent
   */
  override def fit(dataset: DataFrame): StandardScalerModel = {
    transformSchema(dataset.schema, logging = true)
    val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
    val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
    val scalerModel = scaler.fit(input)
    copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this))
  }

  /**
   * Validates the input schema and derives the output schema.
   *
   * @param schema schema of the dataset to be fitted
   * @return `schema` with the output column appended as a non-nullable vector field
   * @throws IllegalArgumentException if the input column is not a vector column, or if the
   *                                  output column already exists
   */
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    require(inputType.isInstanceOf[VectorUDT],
      s"Input column ${$(inputCol)} must be a vector column")
    require(!schema.fieldNames.contains($(outputCol)),
      s"Output column ${$(outputCol)} already exists.")
    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
    StructType(outputFields)
  }

  /** Copies this estimator, applying the extra param values in `extra`. */
  override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
}
/** Companion object that makes a saved [[StandardScaler]] loadable via `DefaultParamsReadable`. */
object StandardScaler extends DefaultParamsReadable[StandardScaler] {

  /** Loads a [[StandardScaler]] from `path`; overridden only to narrow the return type. */
  override def load(path: String): StandardScaler = super.load(path)
}
/**
 * :: Experimental ::
 * Model fitted by [[StandardScaler]].
 *
 * @param std Standard deviation of the StandardScalerModel
 * @param mean Mean of the StandardScalerModel
 */
@Experimental
class StandardScalerModel private[ml] (
    override val uid: String,
    val std: Vector,
    val mean: Vector)
  extends Model[StandardScalerModel] with StandardScalerParams with MLWritable {

  import StandardScalerModel._

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Scales the input column of `dataset` and appends the result as the output column.
   *
   * The schema is validated first. The scaling itself is delegated to an `mllib`
   * [[feature.StandardScalerModel]] built from this model's `std`, `mean`, `withStd` and
   * `withMean`, applied row by row through a UDF.
   *
   * @param dataset data containing the input vector column
   * @return `dataset` with the scaled vectors added under the output column name
   */
  override def transform(dataset: DataFrame): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
    val scale = udf { scaler.transform _ }
    dataset.withColumn($(outputCol), scale(col($(inputCol))))
  }

  /**
   * Validates the input schema and derives the output schema.
   *
   * @param schema schema of the dataset to be transformed
   * @return `schema` with the output column appended as a non-nullable vector field
   * @throws IllegalArgumentException if the input column is not a vector column, or if the
   *                                  output column already exists
   */
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    require(inputType.isInstanceOf[VectorUDT],
      s"Input column ${$(inputCol)} must be a vector column")
    require(!schema.fieldNames.contains($(outputCol)),
      s"Output column ${$(outputCol)} already exists.")
    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
    StructType(outputFields)
  }

  /**
   * Creates a new model with the same `uid`, `std` and `mean`, copies this model's param
   * values plus those in `extra` onto it, and keeps the same parent.
   */
  override def copy(extra: ParamMap): StandardScalerModel = {
    val copied = new StandardScalerModel(uid, std, mean)
    copyValues(copied, extra).setParent(parent)
  }

  /** Returns a writer that persists this model's metadata together with `std` and `mean`. */
  override def write: MLWriter = new StandardScalerModelWriter(this)
}
/** Companion object providing persistence (writer and reader) for [[StandardScalerModel]]. */
object StandardScalerModel extends MLReadable[StandardScalerModel] {

  /** Writer that saves the model's metadata plus its `std` and `mean` vectors. */
  class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter {

    // Row layout of the persisted data; field names must match the columns
    // ("std", "mean") that StandardScalerModelReader selects.
    private case class Data(std: Vector, mean: Vector)

    /**
     * Saves the param metadata under `path` and the model vectors as parquet under
     * `path`/data.
     */
    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val data = Data(instance.std, instance.mean)
      val dataPath = new Path(path, "data").toString
      // The model is a single row, so it is collapsed into one partition before writing.
      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    }
  }

  /** Reader that rebuilds a [[StandardScalerModel]] from saved metadata and parquet data. */
  private class StandardScalerModelReader extends MLReader[StandardScalerModel] {

    /** Class name passed to the metadata loader. */
    private val className = classOf[StandardScalerModel].getName

    /**
     * Loads the metadata from `path` and the `std` / `mean` vectors from `path`/data, then
     * restores the saved param values onto the reconstructed model.
     */
    override def load(path: String): StandardScalerModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val dataPath = new Path(path, "data").toString
      // The saved data holds one row; take it and destructure the two vectors.
      val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath)
        .select("std", "mean")
        .head()
      val model = new StandardScalerModel(metadata.uid, std, mean)
      DefaultParamsReader.getAndSetParams(model, metadata)
      model
    }
  }

  /** Returns a reader for loading a saved [[StandardScalerModel]]. */
  override def read: MLReader[StandardScalerModel] = new StandardScalerModelReader

  /** Loads a [[StandardScalerModel]] from `path`; overridden only to narrow the return type. */
  override def load(path: String): StandardScalerModel = super.load(path)
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy