All downloads are free. Search and download functionality uses the official Maven repository.

com.tencent.angel.sona.ml.attribute.AttributeGroup.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tencent.angel.sona.ml.attribute

import org.apache.spark.linalg.VectorUDT

import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField}

/**
  * :: DeveloperApi ::
  * Describes the attributes of a vector-valued ML column.
  *
  * @param name          name of the attribute group (the ML column name); must not be empty
  * @param numAttributes optional number of attributes. `numAttributes` and `attrs` must not both
  *                      be defined.
  * @param attrs         optional array of attributes. Each attribute is re-indexed through
  *                      `withIndex` using its position in the array.
  */
class AttributeGroup private(
                              val name: String,
                              val numAttributes: Option[Long],
                              attrs: Option[Array[Attribute]]) extends Serializable {

  require(name.nonEmpty, "Cannot have an empty string for name.")
  require(numAttributes.isEmpty || attrs.isEmpty,
    "Cannot have both numAttributes and attrs defined.")

  /**
    * Builds a group that carries no attribute information at all.
    *
    * @param name name of the attribute group
    */
  def this(name: String) = this(name, None, None)

  /**
    * Builds a group for which only the attribute count is known.
    *
    * @param name          name of the attribute group
    * @param numAttributes number of attributes
    */
  def this(name: String, numAttributes: Long) = this(name, Some(numAttributes), None)

  /**
    * Builds a group from explicit attributes.
    *
    * @param name  name of the attribute group
    * @param attrs array of attributes; each one is re-indexed with its position in the array
    */
  def this(name: String, attrs: Array[Attribute]) = this(name, None, Some(attrs))

  /**
    * The attributes of this group, each carrying its array position as index.
    * `numAttributes` and `attributes` are never both defined.
    */
  val attributes: Option[Array[Attribute]] = attrs.map { supplied =>
    supplied.zipWithIndex.map { case (attribute, position) =>
      attribute.withIndex(position)
    }.toArray
  }

  // Lookup table from attribute name to index; unnamed attributes are left out.
  private lazy val nameToIndex: Map[String, Int] = attributes match {
    case Some(known) =>
      val namedPairs = for {
        attribute <- known
        attributeName <- attribute.name
      } yield (attributeName, attribute.index.get.toInt)
      namedPairs.toMap
    case None =>
      Map.empty
  }

  /** Number of attributes in this group, or -1 when that number is unknown. */
  def size: Long = (numAttributes, attributes) match {
    case (Some(count), _) => count
    case (None, Some(known)) => known.length
    case _ => -1
  }

  /** Returns true when an attribute with the given name belongs to this group. */
  def hasAttr(attrName: String): Boolean = nameToIndex.contains(attrName)

  /** Looks up the index of the attribute with the given name. */
  def indexOf(attrName: String): Int = nameToIndex(attrName)

  /** Retrieves an attribute by name. */
  def apply(attrName: String): Attribute = {
    // Resolve the attribute array first, then the index, so failures surface in that order.
    val known = attributes.get
    known(indexOf(attrName))
  }

  /** Retrieves an attribute by name. */
  def getAttr(attrName: String): Attribute = apply(attrName)

  /** Retrieves an attribute by index. */
  def apply(attrIndex: Int): Attribute = {
    val known = attributes.get
    known(attrIndex)
  }

  /** Retrieves an attribute by index. */
  def getAttr(attrIndex: Int): Attribute = apply(attrIndex)

  /** Builds the metadata of this group; the group name is not part of the result. */
  private[attribute] def toMetadataImpl: Metadata = {
    import AttributeKeys._
    val groupBuilder = new MetadataBuilder()
    attributes match {
      case Some(known) =>
        val numerics = ArrayBuffer.empty[Metadata]
        val nominals = ArrayBuffer.empty[Metadata]
        val binaries = ArrayBuffer.empty[Metadata]
        for (attribute <- known) {
          attribute match {
            case numeric: NumericAttribute =>
              // Numeric attributes equal to the default (ignoring the index) are not written.
              if (numeric.withoutIndex != NumericAttribute.defaultAttr) {
                numerics += numeric.toMetadataImpl(withType = false)
              }
            case nominal: NominalAttribute =>
              nominals += nominal.toMetadataImpl(withType = false)
            case binary: BinaryAttribute =>
              binaries += binary.toMetadataImpl(withType = false)
            case UnresolvedAttribute =>
          }
        }
        // One metadata array per attribute type, written only when that type is present.
        val typedBuilder = new MetadataBuilder
        Seq(
          AttributeType.Numeric.name -> numerics,
          AttributeType.Nominal.name -> nominals,
          AttributeType.Binary.name -> binaries
        ).foreach { case (typeName, entries) =>
          if (entries.nonEmpty) {
            typedBuilder.putMetadataArray(typeName, entries.toArray)
          }
        }
        groupBuilder.putMetadata(ATTRIBUTES, typedBuilder.build())
        groupBuilder.putLong(NUM_ATTRIBUTES, known.length)
      case None =>
        numAttributes.foreach(count => groupBuilder.putLong(NUM_ATTRIBUTES, count))
    }
    groupBuilder.build()
  }

  /** Produces ML metadata layered on top of the given existing metadata. */
  def toMetadata(existingMetadata: Metadata): Metadata = {
    val merged = new MetadataBuilder().withMetadata(existingMetadata)
    merged.putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl)
    merged.build()
  }

  /** Produces ML metadata starting from empty metadata. */
  def toMetadata: Metadata = toMetadata(Metadata.empty)

  /** Produces a non-nullable vector StructField whose metadata extends the given metadata. */
  def toStructField(existingMetadata: Metadata): StructField = {
    val fieldMetadata = toMetadata(existingMetadata)
    StructField(name, new VectorUDT, nullable = false, fieldMetadata)
  }

  /** Produces a non-nullable vector StructField starting from empty metadata. */
  def toStructField: StructField = toStructField(Metadata.empty)

  // Arrays are compared element-wise by viewing them as sequences.
  override def equals(other: Any): Boolean = other match {
    case that: AttributeGroup =>
      name == that.name &&
        numAttributes == that.numAttributes &&
        attributes.map(_.toSeq) == that.attributes.map(_.toSeq)
    case _ =>
      false
  }

  // Combines the same three components that equals compares.
  override def hashCode: Int = {
    val componentHashes = Seq(
      name.hashCode,
      numAttributes.hashCode,
      attributes.map(_.toSeq).hashCode)
    componentHashes.foldLeft(17)((acc, componentHash) => 37 * acc + componentHash)
  }

  override def toString: String = toMetadata.toString
}

/**
  * :: DeveloperApi ::
  * Factory methods that rebuild attribute groups from metadata.
  */
object AttributeGroup {

  import AttributeKeys._

  /**
    * Rebuilds a named attribute group from a [[Metadata]] instance.
    *
    * When the metadata lists attributes, positions that no listed attribute claims are filled
    * with the default numeric attribute. Otherwise the group carries only the attribute count,
    * or nothing at all when the count is absent too.
    */
  private[attribute] def fromMetadata(metadata: Metadata, name: String): AttributeGroup = {
    import AttributeType._
    if (!metadata.contains(ATTRIBUTES)) {
      if (metadata.contains(NUM_ATTRIBUTES)) {
        new AttributeGroup(name, metadata.getLong(NUM_ATTRIBUTES))
      } else {
        new AttributeGroup(name)
      }
    } else {
      val slotCount = metadata.getLong(NUM_ATTRIBUTES).toInt
      val slots = new Array[Attribute](slotCount)
      val byType = metadata.getMetadata(ATTRIBUTES)

      // Decodes every entry stored under `typeName` and places it at its own index.
      def restore(typeName: String)(decode: Metadata => Attribute): Unit = {
        if (byType.contains(typeName)) {
          for (attribute <- byType.getMetadataArray(typeName).map(decode)) {
            slots(attribute.index.get.toInt) = attribute
          }
        }
      }

      // Later types overwrite earlier ones when two entries share an index.
      restore(Numeric.name)(entry => NumericAttribute.fromMetadata(entry))
      restore(Nominal.name)(entry => NominalAttribute.fromMetadata(entry))
      restore(Binary.name)(entry => BinaryAttribute.fromMetadata(entry))

      // Positions still empty receive the default numeric attribute.
      for (position <- slots.indices if slots(position) == null) {
        slots(position) = NumericAttribute.defaultAttr
      }
      new AttributeGroup(name, slots)
    }
  }

  /**
    * Rebuilds an attribute group from a `StructField`, whose data type must be a vector.
    * A field without ML attribute metadata yields a group that only carries the field name.
    */
  def fromStructField(field: StructField): AttributeGroup = {
    require(field.dataType == new VectorUDT)
    val fieldMetadata = field.metadata
    if (fieldMetadata.contains(ML_ATTR)) {
      fromMetadata(fieldMetadata.getMetadata(ML_ATTR), field.name)
    } else {
      new AttributeGroup(field.name)
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy