All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.radicalbit.flink.pmml.scala.api.functions.EvaluationCoFunction.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (C) 2017  Radicalbit
 *
 * This file is part of flink-JPMML
 *
 * flink-JPMML is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * flink-JPMML is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with flink-JPMML.  If not, see .
 */

package io.radicalbit.flink.pmml.scala.api.functions

import io.radicalbit.flink.pmml.scala.api.PmmlModel
import io.radicalbit.flink.pmml.scala.api.exceptions.ModelLoadingException
import io.radicalbit.flink.pmml.scala.api.managers.{MetadataManager, ModelsManager}
import io.radicalbit.flink.pmml.scala.api.reader.ModelReader
import io.radicalbit.flink.pmml.scala.logging.LazyLogging
import io.radicalbit.flink.pmml.scala.models.control.{AddMessage, DelMessage, ServingMessage}
import io.radicalbit.flink.pmml.scala.models.core.{ModelId, ModelInfo}
import io.radicalbit.flink.pmml.scala.models.state.CheckpointType.MetadataCheckpoint
import org.apache.flink.api.common.state.{ListState, ListStateDescriptor}
import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
import org.apache.flink.configuration.Configuration
import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext}
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction
import org.apache.flink.streaming.api.functions.co.CoProcessFunction
import org.apache.flink.util.Collector

import scala.collection.JavaConverters._
import scala.collection.{immutable, mutable}
import scala.util.{Failure, Success, Try}

/** Abstract class extending a [[CoProcessFunction]]; it provides:
  * two maps for caching both [[PmmlModel]] and Metadata of each Model
  * the open method in order to initialize the Metadata Map
  * the processElement2 in order to handle models and metadata against a control stream
  *
  * Abstract class extends [[CheckpointedFunction]] and provides therefore:
  * the snapshotState method in order to checkpoint the current state of the operator
  * the initializeState in order to provide an initial state at the operator and/or restore the latest
  *
  * @tparam EVENT The input Type of the event to predict
  * @tparam CTRL The control stream Type. Note: It must extend [[io.radicalbit.flink.pmml.scala.models.control.ServingMessage]]
  * @tparam OUT The output Type
  */
private[scala] abstract class EvaluationCoFunction[EVENT, CTRL <: ServingMessage, OUT]
    extends CoProcessFunction[EVENT, CTRL, OUT]
    with CheckpointedFunction
    with LazyLogging {

  @transient
  private var snapshotMetadata: ListState[MetadataCheckpoint] = _

  @transient
  final protected var servingMetadata: immutable.Map[ModelId, ModelInfo] = _

  final protected lazy val servingModels: mutable.WeakHashMap[Int, PmmlModel] =
    mutable.WeakHashMap.empty[Int, PmmlModel]

  override def processElement2(control: CTRL,
                               ctx: CoProcessFunction[EVENT, CTRL, OUT]#Context,
                               out: Collector[OUT]): Unit = {
    manageModels(control)
    manageMetadata(control)
  }

  override def snapshotState(context: FunctionSnapshotContext): Unit = {
    snapshotMetadata.clear()
    snapshotMetadata.add(new MetadataCheckpoint(servingMetadata.asJava))
  }

  override def initializeState(context: FunctionInitializationContext): Unit = {
    servingMetadata = immutable.Map.empty[ModelId, ModelInfo]

    val description = new ListStateDescriptor[MetadataCheckpoint](
      "metadata-snapshot",
      TypeInformation.of(new TypeHint[MetadataCheckpoint]() {}))

    snapshotMetadata = context.getOperatorStateStore.getUnionListState(description)

    if (context.isRestored) {
      Try(snapshotMetadata.get()) match {
        case Success(state) => servingMetadata ++= state.asScala.toSet[MetadataCheckpoint].flatMap(_.asScala).toMap
        case Failure(_) => logger.info("Not available state in ListState!")
      }
    }
  }

  final def loadModel(path: String): PmmlModel =
    Try(PmmlModel.fromReader(ModelReader(path))) match {
      case Success(model) =>
        logger.info("Model has been successfully loaded, model name: {}", model.modelName)
        model
      case Failure(e) => throw new ModelLoadingException(e.getMessage, e)
    }

  final def fromMetadata(modelId: String): PmmlModel = {
    val currentModelId: ModelId = ModelId.fromIdentifier(modelId)
    if (!servingMetadata.contains(currentModelId)) PmmlModel.empty
    else addAndRetrieveModel(modelId, currentModelId)
  }

  final private def addAndRetrieveModel(modelId: String, currentModelId: ModelId): PmmlModel = {
    val currentModelPath = servingMetadata(currentModelId).path
    val loadedModel = loadModel(currentModelPath)
    servingModels += (modelId.hashCode -> loadedModel)
    loadedModel
  }

  final private def manageMetadata(control: CTRL): Unit =
    control match {
      case add: AddMessage =>
        servingMetadata = MetadataManager(add, servingMetadata)
      case del: DelMessage =>
        servingMetadata = MetadataManager(del, servingMetadata)
    }

  final private def manageModels(control: CTRL): Unit =
    control match {
      case del: DelMessage => servingModels --= ModelsManager(del, servingModels)
      case _ =>
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy