org.apache.spark.ml.attribute.attributes.scala Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.attribute
import scala.annotation.varargs
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField}
/**
* :: DeveloperApi ::
* Abstract class for ML attributes.
*/
@DeveloperApi
sealed abstract class Attribute extends Serializable {
name.foreach { n =>
require(n.nonEmpty, "Cannot have an empty string for name.")
}
index.foreach { i =>
require(i >= 0, s"Index cannot be negative but got $i")
}
/** Attribute type. */
def attrType: AttributeType
/** Name of the attribute. None if it is not set. */
def name: Option[String]
/** Copy with a new name. */
def withName(name: String): Attribute
/** Copy without the name. */
def withoutName: Attribute
/** Index of the attribute. None if it is not set. */
def index: Option[Int]
/** Copy with a new index. */
def withIndex(index: Int): Attribute
/** Copy without the index. */
def withoutIndex: Attribute
/**
* Tests whether this attribute is numeric, true for [[NumericAttribute]] and [[BinaryAttribute]].
*/
def isNumeric: Boolean
/**
* Tests whether this attribute is nominal, true for [[NominalAttribute]] and [[BinaryAttribute]].
*/
def isNominal: Boolean
/**
* Converts this attribute to [[Metadata]].
* @param withType whether to include the type info
*/
private[attribute] def toMetadataImpl(withType: Boolean): Metadata
/**
* Converts this attribute to [[Metadata]]. For numeric attributes, the type info is excluded to
* save space, because numeric type is the default attribute type. For nominal and binary
* attributes, the type info is included.
*/
private[attribute] def toMetadataImpl(): Metadata = {
if (attrType == AttributeType.Numeric) {
toMetadataImpl(withType = false)
} else {
toMetadataImpl(withType = true)
}
}
/** Converts to ML metadata with some existing metadata. */
def toMetadata(existingMetadata: Metadata): Metadata = {
new MetadataBuilder()
.withMetadata(existingMetadata)
.putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl())
.build()
}
/** Converts to ML metadata */
def toMetadata(): Metadata = toMetadata(Metadata.empty)
/**
* Converts to a [[StructField]] with some existing metadata.
* @param existingMetadata existing metadata to carry over
*/
def toStructField(existingMetadata: Metadata): StructField = {
val newMetadata = new MetadataBuilder()
.withMetadata(existingMetadata)
.putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadataImpl())
.build()
StructField(name.get, DoubleType, nullable = false, newMetadata)
}
/** Converts to a [[StructField]]. */
def toStructField(): StructField = toStructField(Metadata.empty)
override def toString: String = toMetadataImpl(withType = true).toString
}
/** Trait for ML attribute factories. */
private[attribute] trait AttributeFactory {
/**
* Creates an [[Attribute]] from a [[Metadata]] instance.
*/
private[attribute] def fromMetadata(metadata: Metadata): Attribute
/**
* Creates an [[Attribute]] from a [[StructField]] instance, optionally preserving name.
*/
private[ml] def decodeStructField(field: StructField, preserveName: Boolean): Attribute = {
require(field.dataType.isInstanceOf[NumericType])
val metadata = field.metadata
val mlAttr = AttributeKeys.ML_ATTR
if (metadata.contains(mlAttr)) {
val attr = fromMetadata(metadata.getMetadata(mlAttr))
if (preserveName) {
attr
} else {
attr.withName(field.name)
}
} else {
UnresolvedAttribute
}
}
/**
* Creates an [[Attribute]] from a [[StructField]] instance.
*/
def fromStructField(field: StructField): Attribute = decodeStructField(field, false)
}
/**
* :: DeveloperApi ::
*/
@DeveloperApi
object Attribute extends AttributeFactory {
private[attribute] override def fromMetadata(metadata: Metadata): Attribute = {
import org.apache.spark.ml.attribute.AttributeKeys._
val attrType = if (metadata.contains(TYPE)) {
metadata.getString(TYPE)
} else {
AttributeType.Numeric.name
}
getFactory(attrType).fromMetadata(metadata)
}
/** Gets the attribute factory given the attribute type name. */
private def getFactory(attrType: String): AttributeFactory = {
if (attrType == AttributeType.Numeric.name) {
NumericAttribute
} else if (attrType == AttributeType.Nominal.name) {
NominalAttribute
} else if (attrType == AttributeType.Binary.name) {
BinaryAttribute
} else {
throw new IllegalArgumentException(s"Cannot recognize type $attrType.")
}
}
}
/**
* :: DeveloperApi ::
* A numeric attribute with optional summary statistics.
* @param name optional name
* @param index optional index
* @param min optional min value
* @param max optional max value
* @param std optional standard deviation
* @param sparsity optional sparsity (ratio of zeros)
*/
@DeveloperApi
class NumericAttribute private[ml] (
override val name: Option[String] = None,
override val index: Option[Int] = None,
val min: Option[Double] = None,
val max: Option[Double] = None,
val std: Option[Double] = None,
val sparsity: Option[Double] = None) extends Attribute {
std.foreach { s =>
require(s >= 0.0, s"Standard deviation cannot be negative but got $s.")
}
sparsity.foreach { s =>
require(s >= 0.0 && s <= 1.0, s"Sparsity must be in [0, 1] but got $s.")
}
override def attrType: AttributeType = AttributeType.Numeric
override def withName(name: String): NumericAttribute = copy(name = Some(name))
override def withoutName: NumericAttribute = copy(name = None)
override def withIndex(index: Int): NumericAttribute = copy(index = Some(index))
override def withoutIndex: NumericAttribute = copy(index = None)
/** Copy with a new min value. */
def withMin(min: Double): NumericAttribute = copy(min = Some(min))
/** Copy without the min value. */
def withoutMin: NumericAttribute = copy(min = None)
/** Copy with a new max value. */
def withMax(max: Double): NumericAttribute = copy(max = Some(max))
/** Copy without the max value. */
def withoutMax: NumericAttribute = copy(max = None)
/** Copy with a new standard deviation. */
def withStd(std: Double): NumericAttribute = copy(std = Some(std))
/** Copy without the standard deviation. */
def withoutStd: NumericAttribute = copy(std = None)
/** Copy with a new sparsity. */
def withSparsity(sparsity: Double): NumericAttribute = copy(sparsity = Some(sparsity))
/** Copy without the sparsity. */
def withoutSparsity: NumericAttribute = copy(sparsity = None)
/** Copy without summary statistics. */
def withoutSummary: NumericAttribute = copy(min = None, max = None, std = None, sparsity = None)
override def isNumeric: Boolean = true
override def isNominal: Boolean = false
/** Convert this attribute to metadata. */
override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder()
if (withType) bldr.putString(TYPE, attrType.name)
name.foreach(bldr.putString(NAME, _))
index.foreach(bldr.putLong(INDEX, _))
min.foreach(bldr.putDouble(MIN, _))
max.foreach(bldr.putDouble(MAX, _))
std.foreach(bldr.putDouble(STD, _))
sparsity.foreach(bldr.putDouble(SPARSITY, _))
bldr.build()
}
/** Creates a copy of this attribute with optional changes. */
private def copy(
name: Option[String] = name,
index: Option[Int] = index,
min: Option[Double] = min,
max: Option[Double] = max,
std: Option[Double] = std,
sparsity: Option[Double] = sparsity): NumericAttribute = {
new NumericAttribute(name, index, min, max, std, sparsity)
}
override def equals(other: Any): Boolean = {
other match {
case o: NumericAttribute =>
(name == o.name) &&
(index == o.index) &&
(min == o.min) &&
(max == o.max) &&
(std == o.std) &&
(sparsity == o.sparsity)
case _ =>
false
}
}
override def hashCode: Int = {
var sum = 17
sum = 37 * sum + name.hashCode
sum = 37 * sum + index.hashCode
sum = 37 * sum + min.hashCode
sum = 37 * sum + max.hashCode
sum = 37 * sum + std.hashCode
sum = 37 * sum + sparsity.hashCode
sum
}
}
/**
* :: DeveloperApi ::
* Factory methods for numeric attributes.
*/
@DeveloperApi
object NumericAttribute extends AttributeFactory {
/** The default numeric attribute. */
val defaultAttr: NumericAttribute = new NumericAttribute
private[attribute] override def fromMetadata(metadata: Metadata): NumericAttribute = {
import org.apache.spark.ml.attribute.AttributeKeys._
val name = if (metadata.contains(NAME)) Some(metadata.getString(NAME)) else None
val index = if (metadata.contains(INDEX)) Some(metadata.getLong(INDEX).toInt) else None
val min = if (metadata.contains(MIN)) Some(metadata.getDouble(MIN)) else None
val max = if (metadata.contains(MAX)) Some(metadata.getDouble(MAX)) else None
val std = if (metadata.contains(STD)) Some(metadata.getDouble(STD)) else None
val sparsity = if (metadata.contains(SPARSITY)) Some(metadata.getDouble(SPARSITY)) else None
new NumericAttribute(name, index, min, max, std, sparsity)
}
}
/**
* :: DeveloperApi ::
* A nominal attribute.
* @param name optional name
* @param index optional index
* @param isOrdinal whether this attribute is ordinal (optional)
* @param numValues optional number of values. At most one of `numValues` and `values` can be
* defined.
* @param values optional values. At most one of `numValues` and `values` can be defined.
*/
@DeveloperApi
class NominalAttribute private[ml] (
override val name: Option[String] = None,
override val index: Option[Int] = None,
val isOrdinal: Option[Boolean] = None,
val numValues: Option[Int] = None,
val values: Option[Array[String]] = None) extends Attribute {
numValues.foreach { n =>
require(n >= 0, s"numValues cannot be negative but got $n.")
}
require(!(numValues.isDefined && values.isDefined),
"Cannot have both numValues and values defined.")
override def attrType: AttributeType = AttributeType.Nominal
override def isNumeric: Boolean = false
override def isNominal: Boolean = true
private lazy val valueToIndex: Map[String, Int] = {
values.map(_.zipWithIndex.toMap).getOrElse(Map.empty)
}
/** Index of a specific value. */
def indexOf(value: String): Int = {
valueToIndex(value)
}
/** Tests whether this attribute contains a specific value. */
def hasValue(value: String): Boolean = valueToIndex.contains(value)
/** Gets a value given its index. */
def getValue(index: Int): String = values.get(index)
override def withName(name: String): NominalAttribute = copy(name = Some(name))
override def withoutName: NominalAttribute = copy(name = None)
override def withIndex(index: Int): NominalAttribute = copy(index = Some(index))
override def withoutIndex: NominalAttribute = copy(index = None)
/** Copy with new values and empty `numValues`. */
def withValues(values: Array[String]): NominalAttribute = {
copy(numValues = None, values = Some(values))
}
/** Copy with new values and empty `numValues`. */
@varargs
def withValues(first: String, others: String*): NominalAttribute = {
copy(numValues = None, values = Some((first +: others).toArray))
}
/** Copy without the values. */
def withoutValues: NominalAttribute = {
copy(values = None)
}
/** Copy with a new `numValues` and empty `values`. */
def withNumValues(numValues: Int): NominalAttribute = {
copy(numValues = Some(numValues), values = None)
}
/** Copy without the `numValues`. */
def withoutNumValues: NominalAttribute = copy(numValues = None)
/**
* Get the number of values, either from `numValues` or from `values`.
* Return None if unknown.
*/
def getNumValues: Option[Int] = {
if (numValues.nonEmpty) {
numValues
} else if (values.nonEmpty) {
Some(values.get.length)
} else {
None
}
}
/** Creates a copy of this attribute with optional changes. */
private def copy(
name: Option[String] = name,
index: Option[Int] = index,
isOrdinal: Option[Boolean] = isOrdinal,
numValues: Option[Int] = numValues,
values: Option[Array[String]] = values): NominalAttribute = {
new NominalAttribute(name, index, isOrdinal, numValues, values)
}
override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder()
if (withType) bldr.putString(TYPE, attrType.name)
name.foreach(bldr.putString(NAME, _))
index.foreach(bldr.putLong(INDEX, _))
isOrdinal.foreach(bldr.putBoolean(ORDINAL, _))
numValues.foreach(bldr.putLong(NUM_VALUES, _))
values.foreach(v => bldr.putStringArray(VALUES, v))
bldr.build()
}
override def equals(other: Any): Boolean = {
other match {
case o: NominalAttribute =>
(name == o.name) &&
(index == o.index) &&
(isOrdinal == o.isOrdinal) &&
(numValues == o.numValues) &&
(values.map(_.toSeq) == o.values.map(_.toSeq))
case _ =>
false
}
}
override def hashCode: Int = {
var sum = 17
sum = 37 * sum + name.hashCode
sum = 37 * sum + index.hashCode
sum = 37 * sum + isOrdinal.hashCode
sum = 37 * sum + numValues.hashCode
sum = 37 * sum + values.map(_.toSeq).hashCode
sum
}
}
/**
* :: DeveloperApi ::
* Factory methods for nominal attributes.
*/
@DeveloperApi
object NominalAttribute extends AttributeFactory {
/** The default nominal attribute. */
final val defaultAttr: NominalAttribute = new NominalAttribute
private[attribute] override def fromMetadata(metadata: Metadata): NominalAttribute = {
import org.apache.spark.ml.attribute.AttributeKeys._
val name = if (metadata.contains(NAME)) Some(metadata.getString(NAME)) else None
val index = if (metadata.contains(INDEX)) Some(metadata.getLong(INDEX).toInt) else None
val isOrdinal = if (metadata.contains(ORDINAL)) Some(metadata.getBoolean(ORDINAL)) else None
val numValues =
if (metadata.contains(NUM_VALUES)) Some(metadata.getLong(NUM_VALUES).toInt) else None
val values =
if (metadata.contains(VALUES)) Some(metadata.getStringArray(VALUES)) else None
new NominalAttribute(name, index, isOrdinal, numValues, values)
}
}
/**
* :: DeveloperApi ::
* A binary attribute.
* @param name optional name
* @param index optional index
* @param values optionla values. If set, its size must be 2.
*/
@DeveloperApi
class BinaryAttribute private[ml] (
override val name: Option[String] = None,
override val index: Option[Int] = None,
val values: Option[Array[String]] = None)
extends Attribute {
values.foreach { v =>
require(v.length == 2, s"Number of values must be 2 for a binary attribute but got ${v.toSeq}.")
}
override def attrType: AttributeType = AttributeType.Binary
override def isNumeric: Boolean = true
override def isNominal: Boolean = true
override def withName(name: String): BinaryAttribute = copy(name = Some(name))
override def withoutName: BinaryAttribute = copy(name = None)
override def withIndex(index: Int): BinaryAttribute = copy(index = Some(index))
override def withoutIndex: BinaryAttribute = copy(index = None)
/**
* Copy with new values.
* @param negative name for negative
* @param positive name for positive
*/
def withValues(negative: String, positive: String): BinaryAttribute =
copy(values = Some(Array(negative, positive)))
/** Copy without the values. */
def withoutValues: BinaryAttribute = copy(values = None)
/** Creates a copy of this attribute with optional changes. */
private def copy(
name: Option[String] = name,
index: Option[Int] = index,
values: Option[Array[String]] = values): BinaryAttribute = {
new BinaryAttribute(name, index, values)
}
override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder
if (withType) bldr.putString(TYPE, attrType.name)
name.foreach(bldr.putString(NAME, _))
index.foreach(bldr.putLong(INDEX, _))
values.foreach(v => bldr.putStringArray(VALUES, v))
bldr.build()
}
override def equals(other: Any): Boolean = {
other match {
case o: BinaryAttribute =>
(name == o.name) &&
(index == o.index) &&
(values.map(_.toSeq) == o.values.map(_.toSeq))
case _ =>
false
}
}
override def hashCode: Int = {
var sum = 17
sum = 37 * sum + name.hashCode
sum = 37 * sum + index.hashCode
sum = 37 * sum + values.map(_.toSeq).hashCode
sum
}
}
/**
* :: DeveloperApi ::
* Factory methods for binary attributes.
*/
@DeveloperApi
object BinaryAttribute extends AttributeFactory {
/** The default binary attribute. */
final val defaultAttr: BinaryAttribute = new BinaryAttribute
private[attribute] override def fromMetadata(metadata: Metadata): BinaryAttribute = {
import org.apache.spark.ml.attribute.AttributeKeys._
val name = if (metadata.contains(NAME)) Some(metadata.getString(NAME)) else None
val index = if (metadata.contains(INDEX)) Some(metadata.getLong(INDEX).toInt) else None
val values =
if (metadata.contains(VALUES)) Some(metadata.getStringArray(VALUES)) else None
new BinaryAttribute(name, index, values)
}
}
/**
* :: DeveloperApi ::
* An unresolved attribute.
*/
@DeveloperApi
object UnresolvedAttribute extends Attribute {
override def attrType: AttributeType = AttributeType.Unresolved
override def withIndex(index: Int): Attribute = this
override def isNumeric: Boolean = false
override def withoutIndex: Attribute = this
override def isNominal: Boolean = false
override def name: Option[String] = None
override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
Metadata.empty
}
override def withoutName: Attribute = this
override def index: Option[Int] = None
override def withName(name: String): Attribute = this
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy