All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.spark.ml.h2o.H2OPipeline.scala Maven / Gradle / Ivy

There is a newer version: 1.6.8
Show newest version
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.h2o

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Pipeline.SharedReadWrite
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader}
import org.apache.spark.ml.{PipelineStage, Transformer, PipelineModel, Pipeline}
import _root_.org.apache.spark.sql.DataFrame

/**
  * A Spark [[Pipeline]] that additionally understands [[OneTimeTransformer]] stages.
  *
  * Fitting behaves exactly like Spark's `Pipeline.fit`, so a [[OneTimeTransformer]] stage can do extra work
  * while the model is being fitted. Afterwards every [[OneTimeTransformer]] is removed from the stages of the
  * resulting [[PipelineModel]], so it is not executed again at prediction time.
  *
  * NOTE(review): per Spark's `Pipeline.fit` implementation, transformer stages are only applied to the
  * training data when they precede the last estimator stage; a [[OneTimeTransformer]] placed after the last
  * estimator is therefore never executed at all.
  *
  * @param uid unique identifier of this pipeline
  */
class H2OPipeline(override val uid: String) extends Pipeline {
  def this() = this(Identifiable.randomUID("pipeline"))

  /**
    * Fits the pipeline exactly as Spark does, then drops all [[OneTimeTransformer]] stages from the result.
    *
    * @param dataset training data passed to `Pipeline.fit`
    * @return a [[PipelineModel]] with the same uid and parent as the one produced by `Pipeline.fit`,
    *         containing all of its stages except the [[OneTimeTransformer]] instances
    */
  override def fit(dataset: DataFrame): PipelineModel = {
    val model = super.fit(dataset)
    // One-time stages only make sense during fitting; keep them out of the prediction-time model.
    val retainedStages = model.stages.filterNot(_.isInstanceOf[OneTimeTransformer])
    new PipelineModel(model.uid, retainedStages).setParent(model.parent)
  }

  /**
    * Creates a copy of this pipeline (same uid) whose stages are copied with `extra` applied.
    *
    * Overridden because `Pipeline.copy` always builds a plain [[Pipeline]]. Without this override, any code
    * that copies the estimator before fitting (e.g. `Estimator.fit(dataset, paramMap)`, used by Spark's
    * tuning utilities) would silently lose the [[OneTimeTransformer]] filtering performed by `fit`.
    *
    * @param extra extra parameters applied to this pipeline and to each copied stage
    * @return a new [[H2OPipeline]] with copied stages
    */
  override def copy(extra: ParamMap): H2OPipeline = {
    val paramMap = extractParamMap(extra)
    val copiedStages = paramMap(stages).map(_.copy(extra))
    new H2OPipeline(uid).setStages(copiedStages)
  }
}

object H2OPipeline extends MLReadable[H2OPipeline] {

  /** Returns the reader that restores a previously saved [[H2OPipeline]]. */
  @Since("1.6.0")
  override def read: MLReader[H2OPipeline] = new H2OPipelineReader

  /** Loads an [[H2OPipeline]] stored under `path`; equivalent to `read.load(path)`. */
  @Since("1.6.0")
  override def load(path: String): H2OPipeline = read.load(path)

  /** Reader that delegates to Spark's shared pipeline persistence and rebuilds an [[H2OPipeline]]. */
  private class H2OPipelineReader extends MLReader[H2OPipeline] {

    /** Class name passed to `SharedReadWrite.load`, to be checked against the saved metadata. */
    private val expectedClassName: String = classOf[H2OPipeline].getName

    override def load(path: String): H2OPipeline = {
      val (loadedUid, loadedStages) = SharedReadWrite.load(expectedClassName, sc, path)
      val pipeline = new H2OPipeline(loadedUid)
      pipeline.setStages(loadedStages)
    }
  }
}

/**
  * Base class for transformer stages intended to run only while an [[H2OPipeline]] is being fitted.
  *
  * `H2OPipeline.fit` removes every instance of this class from the fitted [[PipelineModel]], so such a
  * stage is not part of the model used at prediction time.
  */
abstract class OneTimeTransformer extends Transformer





© 2015 - 2025 Weber Informatics LLC | Privacy Policy