All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.radicalbit.flink.pmml.scala.api.PmmlModel.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (C) 2017  Radicalbit
 *
 * This file is part of flink-JPMML
 *
 * flink-JPMML is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * flink-JPMML is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with flink-JPMML.  If not, see .
 */

package io.radicalbit.flink.pmml.scala.api

import java.io.StringReader
import java.util

import io.radicalbit.flink.pmml.scala.api.exceptions.{InputValidationException, JPMMLExtractionException}
import io.radicalbit.flink.pmml.scala.api.pipeline.Pipeline
import io.radicalbit.flink.pmml.scala.api.reader.ModelReader
import io.radicalbit.flink.pmml.scala.models.prediction.Prediction
import org.apache.flink.ml.math.Vector
import org.dmg.pmml.FieldName
import org.jpmml.evaluator._
import org.jpmml.model.{ImportFilter, JAXBUtil}
import org.xml.sax.InputSource

import scala.collection.JavaConversions._
import scala.util.Try

/** Contains [[PmmlModel]] `fromReader` factory method.
  *
  * As singleton, it guarantees one and only copy of the model when the latter is requested.
  *
  */
object PmmlModel {

  private val evaluatorInstance: ModelEvaluatorFactory = ModelEvaluatorFactory.newInstance()

  /** Loads the distributed path by [[io.radicalbit.flink.pmml.scala.api.reader.FsReader.buildDistributedPath]]
    * and construct a [[PmmlModel]] instance starting from `evalauatorInstance`
    *
    * @param reader The instance providing method in order to read the model from distributed backends lazily
    * @return [[PmmlModel]] instance from the Reader
    */
  private[api] def fromReader(reader: ModelReader): PmmlModel = {
    val readerFromFs = reader.buildDistributedPath
    val result = fromFilteredSource(readerFromFs)

    new PmmlModel(Evaluator(evaluatorInstance.newModelEvaluator(JAXBUtil.unmarshalPMML(result))))
  }

  private def fromFilteredSource(PMMLPath: String) =
    JAXBUtil.createFilteredSource(new InputSource(new StringReader(PMMLPath)), new ImportFilter())

  /** It provides a new instance of the [[PmmlModel]] with [[EmptyEvaluator]]
    *
    * @return [[PmmlModel]] with [[EmptyEvaluator]]
    */
  private[api] def empty = new PmmlModel(Evaluator.empty)

}

/** Provides to the user the model instance and its methods.
  *
  * Users get the instance along the [[io.radicalbit.flink.pmml.scala.RichDataStream.evaluate]] UDF input.
  *
  * {{{
  *   toEvaluateStream.evaluate(reader) { (input, model) =>
  *     val pmmlModel: PmmlModel = model
  *     val prediction = model.predict(vectorized(input))
  *   }
  * }}}
  *
  * This class contains [[PmmlModel#predict]] method and implements the core logic of the project.
  *
  * @param evaluator The PMML model instance
  */
class PmmlModel(private[api] val evaluator: Evaluator) extends Pipeline {

  import io.radicalbit.flink.pmml.scala.api.converter.VectorConverter._

  final def modelName: String = evaluator.model.getModel.getModelName

  /** Implements the entire prediction pipeline, which can be described as 4 main steps:
    *
    * - `validateInput` validates the input to be conform to PMML model size
    *
    * - `prepareInput` prepares the input in full compliance to [[org.jpmml.evaluator.EvaluatorUtil.prepare]] JPMML method
    *
    * - `evaluateInput` evaluates the input against inner PMML model instance and returns a Java Map output
    *
    * - `extractTarget` extracts the target from evaluation result.
    *
    * As final action the pipelined statement is executed by [[Prediction]]
    *
    * @param inputVector the input event as a [[org.apache.flink.ml.math.Vector]] instance
    * @param replaceNan A [[scala.Option]] describing a replace value for not defined vector values
    * @tparam V subclass of [[org.apache.flink.ml.math.Vector]]
    * @return [[Prediction]] instance
    */
  final def predict[V <: Vector](inputVector: V, replaceNan: Option[Double] = None): Prediction = {
    val result = Try {
      val validatedInput = validateInput(inputVector)
      val preparedInput = prepareInput(validatedInput, replaceNan)
      val evaluationResult = evaluateInput(preparedInput)
      val extractResult = extractTarget(evaluationResult)
      extractResult
    }

    Prediction.extractPrediction(result)
  }

  /** Validates the input vector in size terms and converts it as a `Map[String, Any]` (see [[PmmlInput]])
    *
    * @param v The raw input vector
    * @param vec2Pmml The conversion function
    * @return The converted instance
    */
  private[api] def validateInput(v: Vector)(implicit vec2Pmml: (Vector, Evaluator) => PmmlInput): PmmlInput = {
    val modelSize = evaluator.model.getActiveFields.size

    if (v.size != modelSize)
      throw new InputValidationException(s"input vector $v size ${v.size} is not conform to model size $modelSize")
    else
      vec2Pmml(v, evaluator)
  }

  /** Binds each field with input value and prepare the record to be evaluated
    * by [[EvaluatorUtil.prepare]] method.
    *
    * @param input Validated input as a [[Map]] keyed by field name
    * @param replaceNaN Optional replace value in case of missing values
    * @return Prepared input to be evaluated
    */
  private[api] def prepareInput(input: PmmlInput, replaceNaN: Option[Double]): Map[FieldName, FieldValue] = {

    val activeFields = evaluator.model.getActiveFields

    activeFields.map { field =>
      val rawValue = input.get(field.getName.getValue).orElse(replaceNaN).orNull
      prepareAndEmit(Try { EvaluatorUtil.prepare(field, rawValue) }, field.getName)
    }.toMap

  }

  /** Evaluates the prepared input against the PMML model
    *
    * @param preparedInput JPMML prepared input as `Map[FieldName, FieldValue]`
    * @return JPMML output result as a Java map
    */
  private[api] def evaluateInput(preparedInput: Map[FieldName, FieldValue]): util.Map[FieldName, _] =
    evaluator.model.evaluate(preparedInput)

  /** Extracts the target from evaluation result
    * @throws JPMMLExtractionException if the target couldn't be extracted
    * @param evaluationResult outcome from JPMML evaluation
    * @return The prediction value as a [Double]
    */
  private[api] def extractTarget(evaluationResult: java.util.Map[FieldName, _]): Double = {
    val targets = extractTargetFields(evaluationResult)

    targets.headOption.flatMap {
      case (_, target) => extractTargetValue(target)
    } getOrElse (throw new JPMMLExtractionException("Target value is null."))

  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy