org.apache.flink.ml.pipeline.ChainedTransformer.scala Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.pipeline
import org.apache.flink.api.scala.DataSet
import org.apache.flink.ml.common.ParameterMap
/** A [[Transformer]] built by chaining two other [[Transformer]]s, `left` followed by `right`.
  *
  * Because a [[ChainedTransformer]] is itself a [[Transformer]], it can be used wherever a
  * regular [[Transformer]] is expected. When its fit or transform operation is invoked, the
  * data flows through every [[Transformer]] of the pipeline in order.
  *
  * The pipeline mechanism has been inspired by scikit-learn.
  *
  * @param left First (left) [[Transformer]] of the pipeline
  * @param right Second (right) [[Transformer]] of the pipeline
  * @tparam L Concrete type of the left [[Transformer]]
  * @tparam R Concrete type of the right [[Transformer]]
  */
case class ChainedTransformer[L <: Transformer[L], R <: Transformer[R]](left: L, right: R)
  extends Transformer[ChainedTransformer[L, R]]
object ChainedTransformer {

  /** Derives the [[TransformOperation]] of a [[ChainedTransformer]] from the operations of its
    * two parts.
    *
    * The input is first transformed by the left [[Transformer]]; the intermediate data produced
    * by that step is then handed to the transform operation of the right [[Transformer]]. Both
    * steps receive the same [[ParameterMap]].
    *
    * @param transformOpLeft [[TransformOperation]] for the left [[Transformer]]
    * @param transformOpRight [[TransformOperation]] for the right [[Transformer]]
    * @tparam L Type of the left [[Transformer]]
    * @tparam R Type of the right [[Transformer]]
    * @tparam I Type of the input data
    * @tparam T Type of the intermediate output data
    * @tparam O Type of the output data
    * @return [[TransformOperation]] turning a `DataSet[I]` into a `DataSet[O]` by applying the
    *         left and then the right transform operation
    */
  implicit def chainedTransformOperation[
      L <: Transformer[L],
      R <: Transformer[R],
      I,
      T,
      O](implicit
      transformOpLeft: TransformOperation[L, I, T],
      transformOpRight: TransformOperation[R, T, O])
    : TransformOperation[ChainedTransformer[L, R], I, O] =
    new TransformOperation[ChainedTransformer[L, R], I, O] {
      override def transform(
          chain: ChainedTransformer[L, R],
          transformParameters: ParameterMap,
          input: DataSet[I])
        : DataSet[O] = {
        // Stage 1: left transformer maps the input to the intermediate representation.
        val leftOutput: DataSet[T] =
          transformOpLeft.transform(chain.left, transformParameters, input)

        // Stage 2: right transformer maps the intermediate representation to the final output.
        transformOpRight.transform(chain.right, transformParameters, leftOutput)
      }
    }

  /** Derives the [[FitOperation]] of a [[ChainedTransformer]] from the operations of its two
    * parts.
    *
    * The left [[Transformer]] is fitted to the input data first. The input is then transformed
    * by the left [[Transformer]], and the result is given to the fit operation of the right
    * [[Transformer]]. All three steps receive the same [[ParameterMap]].
    *
    * @param leftFitOperation [[FitOperation]] for the left [[Transformer]]
    * @param leftTransformOperation [[TransformOperation]] for the left [[Transformer]]
    * @param rightFitOperation [[FitOperation]] for the right [[Transformer]]
    * @tparam L Type of the left [[Transformer]]
    * @tparam R Type of the right [[Transformer]]
    * @tparam I Type of the input data
    * @tparam T Type of the intermediate output data
    * @return [[FitOperation]] which fits both parts of the chain to input data of type `I`
    */
  implicit def chainedFitOperation[L <: Transformer[L], R <: Transformer[R], I, T](implicit
      leftFitOperation: FitOperation[L, I],
      leftTransformOperation: TransformOperation[L, I, T],
      rightFitOperation: FitOperation[R, T])
    : FitOperation[ChainedTransformer[L, R], I] =
    new FitOperation[ChainedTransformer[L, R], I] {
      override def fit(
          instance: ChainedTransformer[L, R],
          fitParameters: ParameterMap,
          input: DataSet[I])
        : Unit = {
        // The type-class instances are passed explicitly; they are the same ones implicit
        // resolution would select from the enclosing method's implicit parameters.
        instance.left.fit(input, fitParameters)(leftFitOperation)

        // The right transformer is fitted on what the left one produces, not on the raw input.
        val leftOutput: DataSet[T] =
          instance.left.transform(input, fitParameters)(leftTransformOperation)

        instance.right.fit(leftOutput, fitParameters)(rightFitOperation)
      }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy