// org.apache.flink.ml.pipeline.ChainedPredictor.scala (Maven / Gradle / Ivy) - the newest version
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.pipeline
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.scala.DataSet
import org.apache.flink.ml.common.ParameterMap
/** A [[Predictor]] made of a pipeline of possibly multiple [[Transformer]] followed by a
  * trailing [[Predictor]].
  *
  * A [[ChainedPredictor]] behaves like any other [[Predictor]]. When it is fitted, the training
  * data first passes through the preceding [[Transformer]] and the transformed data is handed to
  * the trailing [[Predictor]]. Prediction follows the same route: transform first, then predict.
  *
  * The pipeline mechanism has been inspired by scikit-learn.
  *
  * @param transformer Preceding [[Transformer]] of the pipeline
  * @param predictor Trailing [[Predictor]] of the pipeline
  * @tparam T Type of the preceding [[Transformer]]
  * @tparam P Type of the trailing [[Predictor]]
  */
case class ChainedPredictor[T <: Transformer[T], P <: Predictor[P]](
    transformer: T,
    predictor: P)
  extends Predictor[ChainedPredictor[T, P]]
object ChainedPredictor {

  /** Provides the [[PredictDataSetOperation]] of a [[ChainedPredictor]].
    *
    * The operation is built from the [[TransformDataSetOperation]] of the preceding
    * [[Transformer]] and the [[PredictDataSetOperation]] of the trailing [[Predictor]]. When
    * predict is called, the testing data is transformed by the preceding [[Transformer]] and the
    * transformed data is then passed to the trailing [[Predictor]]. Both stages receive the same
    * parameter map.
    *
    * @param transformOperation [[TransformDataSetOperation]] for the preceding [[Transformer]]
    * @param predictOperation [[PredictDataSetOperation]] for the trailing [[Predictor]]
    * @tparam T Type of the preceding [[Transformer]]
    * @tparam P Type of the trailing [[Predictor]]
    * @tparam Testing Type of the testing data
    * @tparam Intermediate Type of the intermediate data produced by the preceding [[Transformer]]
    * @tparam Prediction Type of the predicted data generated by the trailing [[Predictor]]
    * @return [[PredictDataSetOperation]] turning `Testing` data into `Prediction` data for the
    *         whole chain
    */
  implicit def chainedPredictOperation[
      T <: Transformer[T],
      P <: Predictor[P],
      Testing,
      Intermediate,
      Prediction](
      implicit transformOperation: TransformDataSetOperation[T, Testing, Intermediate],
      predictOperation: PredictDataSetOperation[P, Intermediate, Prediction])
    : PredictDataSetOperation[ChainedPredictor[T, P], Testing, Prediction] =
    new PredictDataSetOperation[ChainedPredictor[T, P], Testing, Prediction] {
      override def predictDataSet(
          instance: ChainedPredictor[T, P],
          predictParameters: ParameterMap,
          input: DataSet[Testing])
        : DataSet[Prediction] = {
        // Stage 1: run the testing data through the preceding transformer.
        val preprocessed: DataSet[Intermediate] =
          instance.transformer.transform(input, predictParameters)

        // Stage 2: let the trailing predictor work on the transformed data.
        instance.predictor.predict(preprocessed, predictParameters)
      }
    }

  /** Provides the [[FitOperation]] of a [[ChainedPredictor]].
    *
    * The operation needs the [[FitOperation]] and the [[TransformDataSetOperation]] of the
    * preceding [[Transformer]] as well as the [[FitOperation]] of the trailing [[Predictor]].
    * When fit is called, the preceding [[Transformer]] is fitted to the training data first.
    * The training data is then transformed by that [[Transformer]], and the transformed data is
    * finally used to fit the trailing [[Predictor]]. All three steps receive the same parameter
    * map.
    *
    * @param fitOperation [[FitOperation]] of the preceding [[Transformer]]
    * @param transformOperation [[TransformDataSetOperation]] of the preceding [[Transformer]]
    * @param predictorFitOperation [[FitOperation]] of the trailing [[Predictor]]
    * @tparam L Type of the preceding [[Transformer]]
    * @tparam R Type of the trailing [[Predictor]]
    * @tparam I Type of the training data
    * @tparam T Type of the intermediate data
    * @return [[FitOperation]] fitting the whole chain on training data of type `I`
    */
  implicit def chainedFitOperation[L <: Transformer[L], R <: Predictor[R], I, T](
      implicit fitOperation: FitOperation[L, I],
      transformOperation: TransformDataSetOperation[L, I, T],
      predictorFitOperation: FitOperation[R, T])
    : FitOperation[ChainedPredictor[L, R], I] =
    new FitOperation[ChainedPredictor[L, R], I] {
      override def fit(
          instance: ChainedPredictor[L, R],
          fitParameters: ParameterMap,
          input: DataSet[I])
        : Unit = {
        val preprocessor = instance.transformer

        // Order matters: the transformer is fitted before it is asked to transform.
        preprocessor.fit(input, fitParameters)
        val transformedTraining: DataSet[T] = preprocessor.transform(input, fitParameters)

        instance.predictor.fit(transformedTraining, fitParameters)
      }
    }

  /** Provides the [[EvaluateDataSetOperation]] of a [[ChainedPredictor]].
    *
    * When evaluate is called, the testing data is transformed by the preceding [[Transformer]]
    * and the transformed data is passed to the evaluate method of the trailing [[Predictor]],
    * whose result is returned unchanged. Both stages receive the same parameter map.
    *
    * The two [[TypeInformation]] parameters are required implicitly but are not referenced
    * directly in the body of the operation.
    *
    * @param transformOperation [[TransformDataSetOperation]] for the preceding [[Transformer]]
    * @param evaluateOperation [[EvaluateDataSetOperation]] for the trailing [[Predictor]]
    * @param testingTypeInformation [[TypeInformation]] of the testing data
    * @param predictionValueTypeInformation [[TypeInformation]] of the prediction values
    * @tparam T Type of the preceding [[Transformer]]
    * @tparam P Type of the trailing [[Predictor]]
    * @tparam Testing Type of the testing data
    * @tparam Intermediate Type of the intermediate data produced by the preceding [[Transformer]]
    * @tparam PredictionValue Type of the values paired up in the evaluation result
    * @return [[EvaluateDataSetOperation]] producing pairs of `PredictionValue` for the whole chain
    */
  implicit def chainedEvaluationOperation[
      T <: Transformer[T],
      P <: Predictor[P],
      Testing,
      Intermediate,
      PredictionValue](
      implicit transformOperation: TransformDataSetOperation[T, Testing, Intermediate],
      evaluateOperation: EvaluateDataSetOperation[P, Intermediate, PredictionValue],
      testingTypeInformation: TypeInformation[Testing],
      predictionValueTypeInformation: TypeInformation[PredictionValue])
    : EvaluateDataSetOperation[ChainedPredictor[T, P], Testing, PredictionValue] =
    new EvaluateDataSetOperation[ChainedPredictor[T, P], Testing, PredictionValue] {
      override def evaluateDataSet(
          instance: ChainedPredictor[T, P],
          evaluateParameters: ParameterMap,
          testing: DataSet[Testing])
        : DataSet[(PredictionValue, PredictionValue)] = {
        // Stage 1: run the testing data through the preceding transformer.
        val preprocessed: DataSet[Intermediate] =
          instance.transformer.transform(testing, evaluateParameters)

        // Stage 2: delegate the evaluation to the trailing predictor.
        instance.predictor.evaluate(preprocessed, evaluateParameters)
      }
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy