// All downloads are free. Search and download functionality uses the official Maven repository.
//
// com.tribbloids.spookystuff.pipeline.TransformerLike.scala (Maven / Gradle / Ivy)
//
// The newest version!
package com.tribbloids.spookystuff.pipeline

import java.util.UUID

import com.tribbloids.spookystuff.PipelineException
import com.tribbloids.spookystuff.sparkbinding.PageRowRDD
import org.apache.spark.ml.param.{Param, ParamMap, Params}

import scala.language.dynamics

/**
 * Created by peng on 25/09/15.
 */

/**
 * Package-private base contract for the pipeline stages in this file.
 *
 * A stage carries Spark ML parameters (via [[Params]]), is serializable, maps one
 * [[PageRowRDD]] to another, and can be chained with further stages through `+>`.
 */
private[pipeline] trait TransformerLike extends Params with Serializable {

  /**
   * Applies this stage to `dataset`.
   *
   * @param dataset the input RDD
   * @return the RDD produced by this stage
   */
  def transform(dataset: PageRowRDD): PageRowRDD

  /**
   * Copies this stage by delegating to `Params.defaultCopy`.
   *
   * @param extra extra parameter values handed to `defaultCopy`
   * @return the copied stage
   */
  def copy(extra: ParamMap): TransformerLike = defaultCopy[TransformerLike](extra)

  /**
   * Chains `another` after this stage.
   *
   * @param another the stage to append
   * @return a [[ChainTransformer]] representing the combined chain
   */
  def +> (another: SpookyTransformer): ChainTransformer
}

/**
 * A single pipeline stage. Adds a default chaining operator and a helper that reads
 * a string-valued parameter as a [[Symbol]].
 */
trait SpookyTransformer extends TransformerLike {

  /**
   * Same as the inherited `copy`, with the result type narrowed to [[SpookyTransformer]].
   *
   * @param extra extra parameter values handed to `defaultCopy`
   * @return the copied stage
   */
  override def copy(extra: ParamMap): SpookyTransformer = defaultCopy[SpookyTransformer](extra)

  /**
   * Starts a chain: wraps this stage in a one-element [[ChainTransformer]] and appends
   * `another` to it.
   *
   * @param another the stage to run after this one
   * @return the chain containing this stage followed by `another`
   */
  def +> (another: SpookyTransformer): ChainTransformer = {
    val chainOfOne = new ChainTransformer(Seq(this))
    chainOfOne +> another
  }

  /**
   * Reads a string parameter (set value or default, via `getOrDefault`) as a [[Symbol]].
   *
   * @param col the string-valued parameter to read
   * @return the symbol named by the parameter value, or `null` when that value is `null`
   */
  def $(col: Param[String]): Symbol =
    getOrDefault(col) match {
      case null => null
      case name => Symbol(name)
    }
}

/**
 * Mix-in that turns calls to otherwise undefined `setXxx(value)` methods into parameter
 * assignments, using Scala's [[Dynamic]] dispatch plus reflection.
 */
trait DynamicSetter extends SpookyTransformer with Dynamic {

  /**
   * Handles a dynamic call such as `setFoo(value)`.
   *
   * The text after the `set` prefix (`Foo` in the example) is looked up reflectively as a
   * public zero-argument method on this object's class; its result is treated as the
   * [[Param]] to assign `value` to.
   *
   * @param methodName the name of the method that was invoked dynamically
   * @param args       the call arguments; exactly one is expected (checked with `assert`)
   * @return this instance, so calls can be chained
   * @throws PipelineException if `methodName` does not start with `set`
   */
  def applyDynamic(methodName: String)(args: Any*): this.type = {
    assert(args.length == 1)
    val newValue = args.head

    // Anything that is not a setter call is rejected before reflection is attempted.
    if (!methodName.startsWith("set"))
      throw new PipelineException(s"setting $methodName doesn't exist")

    // Scala exposes each member val through an accessor method of the same name, so the
    // Param can be fetched by invoking that accessor on this instance.
    val paramName = methodName.stripPrefix("set")
    val param = getClass
      .getMethod(paramName)
      .invoke(this)
      .asInstanceOf[Param[Any]]

    set(param, newValue)
    this
  }
}

/**
 * An ordered sequence of [[SpookyTransformer]] stages that is applied as a single stage.
 *
 * @param self the stages, in the order they are applied
 * @param uid  identifier of this chain; defaults to the class's canonical name, an
 *             underscore and a random UUID
 */
class ChainTransformer(
    self: Seq[SpookyTransformer],
    override val uid: String =
      s"${classOf[ChainTransformer].getCanonicalName}_${UUID.randomUUID()}"
) extends TransformerLike {

  /**
   * Builds an empty chain with the given `uid`.
   * This constructor is required by `Params.defaultCopy()`.
   */
  def this(uid: String) = this(Nil, uid)

  /**
   * Feeds `dataset` through every stage in order; each stage receives the output of the
   * previous one.
   *
   * @param dataset the input RDD
   * @return the output of the last stage, or `dataset` itself when the chain is empty
   */
  override def transform(dataset: PageRowRDD): PageRowRDD =
    self.foldLeft(dataset)((rdd, stage) => stage.transform(rdd))

  /**
   * Copies the chain by copying each stage with `extra`; the copy keeps this chain's `uid`.
   *
   * @param extra extra parameter values passed to every stage's `copy`
   * @return a new chain made of the copied stages
   */
  override def copy(extra: ParamMap): ChainTransformer = {
    val copiedStages = this.self.map(_.copy(extra))
    new ChainTransformer(self = copiedStages, uid = this.uid)
  }

  /**
   * Returns a new chain with `another` appended as the last stage; the new chain keeps
   * this chain's `uid`.
   *
   * @param another the stage to append
   * @return the extended chain
   */
  def +> (another: SpookyTransformer): ChainTransformer =
    new ChainTransformer(this.self :+ another, uid = this.uid)
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy