All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.ml.pipeline.ChainedPredictor.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.ml.pipeline

import org.apache.flink.api.scala.DataSet
import org.apache.flink.ml.common.ParameterMap

/** A [[Predictor]] built from a pipeline of possibly multiple [[Transformer]]s followed by a
  * single trailing [[Predictor]].
  *
  * From the outside a [[ChainedPredictor]] is used exactly like any other [[Predictor]]. When it
  * is fitted, the input data first flows through the preceding [[Transformer]] stage of the
  * pipeline, and whatever that stage emits is handed on to the trailing [[Predictor]]. Prediction
  * takes the same route. The concrete fit and predict behaviour is supplied by the implicit
  * [[FitOperation]] and [[PredictOperation]] instances defined in the companion object.
  *
  * The design of this pipeline mechanism was inspired by scikit-learn.
  *
  * @param transformer Preceding [[Transformer]] of the pipeline
  * @param predictor Trailing [[Predictor]] of the pipeline
  * @tparam T Type of the preceding [[Transformer]]
  * @tparam P Type of the trailing [[Predictor]]
  */
case class ChainedPredictor[T <: Transformer[T], P <: Predictor[P]](
    transformer: T,
    predictor: P)
  extends Predictor[ChainedPredictor[T, P]]

object ChainedPredictor {

  /** [[PredictOperation]] for the [[ChainedPredictor]].
    *
    * The [[PredictOperation]] requires the [[TransformOperation]] of the preceding [[Transformer]]
    * and the [[PredictOperation]] of the trailing [[Predictor]]. Upon calling predict, the testing
    * data is first transformed by the preceding [[Transformer]] and the result is then used to
    * calculate the prediction via the trailing [[Predictor]].
    *
    * The same [[ParameterMap]] is forwarded unchanged to both the transform and the predict step.
    *
    * @param transformOperation [[TransformOperation]] for the preceding [[Transformer]]
    * @param predictOperation [[PredictOperation]] for the trailing [[Predictor]]
    * @tparam T Type of the preceding [[Transformer]]
    * @tparam P Type of the trailing [[Predictor]]
    * @tparam Testing Type of the testing data
    * @tparam Intermediate Type of the intermediate data produced by the preceding [[Transformer]]
    * @tparam Prediction Type of the predicted data generated by the trailing [[Predictor]]
    * @return [[PredictOperation]] for a [[ChainedPredictor]] which turns testing data of type
    *         `Testing` into predictions of type `Prediction`
    */
  implicit def chainedPredictOperation[
      T <: Transformer[T],
      P <: Predictor[P],
      Testing,
      Intermediate,
      Prediction](
      implicit transformOperation: TransformOperation[T, Testing, Intermediate],
      predictOperation: PredictOperation[P, Intermediate, Prediction])
    : PredictOperation[ChainedPredictor[T, P], Testing, Prediction] = {

    new PredictOperation[ChainedPredictor[T, P], Testing, Prediction] {
      override def predict(
          instance: ChainedPredictor[T, P],
          predictParameters: ParameterMap,
          input: DataSet[Testing])
        : DataSet[Prediction] = {

        // Run the testing data through the transformer stage first; the trailing predictor
        // only ever sees the intermediate representation.
        val testing = instance.transformer.transform(input, predictParameters)
        instance.predictor.predict(testing, predictParameters)
      }
    }
  }

  /** [[FitOperation]] for the [[ChainedPredictor]].
    *
    * The [[FitOperation]] requires the [[FitOperation]] and the [[TransformOperation]] of the
    * preceding [[Transformer]] as well as the [[FitOperation]] of the trailing [[Predictor]].
    * Upon calling fit, the preceding [[Transformer]] is first fitted to the training data.
    * The training data is then transformed by the fitted [[Transformer]]. The transformed data
    * is then used to fit the [[Predictor]].
    *
    * The same [[ParameterMap]] is forwarded unchanged to all three steps.
    *
    * @param fitOperation [[FitOperation]] of the preceding [[Transformer]]
    * @param transformOperation [[TransformOperation]] of the preceding [[Transformer]]
    * @param predictorFitOperation [[FitOperation]] of the trailing [[Predictor]]
    * @tparam T Type of the preceding [[Transformer]]
    * @tparam P Type of the trailing [[Predictor]]
    * @tparam Training Type of the training data
    * @tparam Intermediate Type of the intermediate data produced by the preceding [[Transformer]]
    * @return [[FitOperation]] for a [[ChainedPredictor]] which fits the whole chain on training
    *         data of type `Training`
    */
  implicit def chainedFitOperation[
      T <: Transformer[T],
      P <: Predictor[P],
      Training,
      Intermediate](
      implicit fitOperation: FitOperation[T, Training],
      transformOperation: TransformOperation[T, Training, Intermediate],
      predictorFitOperation: FitOperation[P, Intermediate])
    : FitOperation[ChainedPredictor[T, P], Training] = {

    new FitOperation[ChainedPredictor[T, P], Training] {
      override def fit(
          instance: ChainedPredictor[T, P],
          fitParameters: ParameterMap,
          input: DataSet[Training])
        : Unit = {

        // The transformer must be fitted before it is used to transform the training data
        // that the trailing predictor is then fitted on.
        instance.transformer.fit(input, fitParameters)
        val intermediateResult = instance.transformer.transform(input, fitParameters)
        instance.predictor.fit(intermediateResult, fitParameters)
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy