All downloads are FREE. The search and download functionality uses the official Maven repository.

com.johnsnowlabs.nlp.RecursivePipeline.scala Maven / Gradle / Ivy

package com.johnsnowlabs.nlp

import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{Estimator, Pipeline, PipelineModel, Transformer}
import org.apache.spark.sql.Dataset

import scala.collection.mutable.ListBuffer

/**
  * A `Pipeline` whose `fit` hands every stage that mixes in `HasRecursiveFit` a
  * `PipelineModel` made of the transformers produced by the stages that precede it.
  *
  * Every other stage is treated the same way as in the loop below: estimators are fitted on the
  * current dataset, transformers are kept as they are.
  *
  * @param uid unique identifier of this pipeline; it is also used for the models it builds
  */
class RecursivePipeline(override val uid: String) extends Pipeline {

  /** Creates a pipeline with a random uid prefixed with `RECURSIVE_PIPELINE`. */
  def this() = this(Identifiable.randomUID("RECURSIVE_PIPELINE"))

  /**
    * Wraps already-fitted transformers into a `PipelineModel`.
    *
    * Workaround to the `PipelineModel` constructor being private in Spark: a helper `Pipeline`
    * that contains only transformers is fitted instead. That helper performs no transformation
    * of its own; `dataset` is handed to it only so it can validate the stage schemas.
    *
    * @param dataset      dataset whose schema the transformers are validated against; this must
    *                     be the dataset as it looks before any of `transformers` is applied
    * @param uid          uid given to the helper pipeline
    * @param transformers fitted stages, in execution order
    * @return the model returned by fitting the helper pipeline
    */
  private def createPipeline(
      dataset: Dataset[_],
      uid: String,
      transformers: Array[Transformer]): PipelineModel = {
    // Pass the uid on instead of letting the helper Pipeline generate a random one.
    new Pipeline(uid).setStages(transformers).fit(dataset)
  }


  /**
    * Fits the stages in order and returns the resulting model. Has to behave as of spark 2.x.x.
    *
    * Stages up to the last estimator are fitted (when they are estimators) and, except for the
    * last estimator itself, applied to the running dataset so that later stages see their
    * output. Stages after the last estimator are only collected, never applied here.
    *
    * @param dataset input dataset
    * @return a `PipelineModel` holding one transformer per stage, with this pipeline as parent
    * @throws IllegalArgumentException if a stage at or before the last estimator is neither an
    *                                  estimator nor a transformer
    */
  override def fit(dataset: Dataset[_]): PipelineModel = {
    transformSchema(dataset.schema, logging = true)
    val theStages = $(stages)
    // -1 when the pipeline contains no estimator at all.
    val indexOfLastEstimator = theStages.lastIndexWhere {
      case _: Estimator[_] => true
      case _ => false
    }
    var curDataset = dataset
    val transformers = ListBuffer.empty[Transformer]
    theStages.view.zipWithIndex.foreach { case (stage, index) =>
      if (index <= indexOfLastEstimator) {
        val transformer = stage match {
          case estimator: HasRecursiveFit[_] =>
            // The preceding transformers are validated against the ORIGINAL dataset: curDataset
            // already carries their output columns, which would make schema validation of
            // stages that refuse an existing output column fail.
            estimator.recursiveFit(curDataset, createPipeline(dataset, uid, transformers.toArray))
          case estimator: Estimator[_] =>
            estimator.fit(curDataset)
          case t: Transformer =>
            t
          case _ =>
            throw new IllegalArgumentException(
              s"Does not support stage $stage of type ${stage.getClass}")
        }
        // No stage after the last estimator needs fitting, so its output is never computed.
        if (index < indexOfLastEstimator) {
          curDataset = transformer.transform(curDataset)
        }
        transformers += transformer
      } else {
        transformers += stage.asInstanceOf[Transformer]
      }
    }

    // Validate the assembled stages against the input schema, not the transformed one.
    createPipeline(dataset, uid, transformers.toArray).setParent(this)
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy