com.eharmony.aloha.models.h2o.H2oModel.scala Maven / Gradle / Ivy
package com.eharmony.aloha.models.h2o
import com.eharmony.aloha.factory.{ModelParser, ModelParserWithSemantics, ParserProviderCompanion}
import com.eharmony.aloha.id.ModelIdentity
import com.eharmony.aloha.io.AlohaReadable
import com.eharmony.aloha.io.sources.ModelSource
import com.eharmony.aloha.models.BaseModel
import com.eharmony.aloha.models.h2o.H2oModel.Features
import com.eharmony.aloha.models.h2o.categories._
import com.eharmony.aloha.models.h2o.compiler.Compiler
import com.eharmony.aloha.models.h2o.json.{H2oSpec, H2oAst}
import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps}
import com.eharmony.aloha.score.Scores.Score
import com.eharmony.aloha.score.basic.ModelOutput
import com.eharmony.aloha.score.conversions.ScoreConverter
import com.eharmony.aloha.semantics.Semantics
import com.eharmony.aloha.semantics.func.GenAggFunc
import com.eharmony.aloha.util.{EitherHelpers, Logging}
import hex.genmodel.GenModel
import hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException
import hex.genmodel.easy.{EasyPredictModelWrapper, RowData}
import spray.json.{DeserializationException, JsValue, JsonReader}
import scala.annotation.tailrec
import scala.collection.{immutable => sci}
import scala.util.{Failure, Success, Try}
/**
* Created by deak on 9/30/15.
*/
// TODO: Need to make sure the model source is used rather than a model.
// This is because we want to protect against a user compiling a model locally and serializing it.
// Then deserializing on a cluster where the class files don't exist. If we follow the same methodology as
// the VW model, then we can get the same guarantees.
/**
 * An Aloha model that scores values of type `A` with an H2o `GenModel` compiled from `modelSource`,
 * producing outputs of type `B`.
 *
 * @param modelId identity of this model.
 * @param modelSource source from which the H2o model is compiled.
 * @param featureNames H2o column names; `featureNames(i)` is the column populated by `featureFunctions(i)`.
 * @param featureFunctions functions extracting each feature value from the input.
 * @param numMissingThreshold when defined, the maximum number of features that may produce no value before
 *                            scoring fails outright. When `None`, any number of features may be missing.
 * @param scb score converter for `B`. Its `ri` supplies the runtime type used to pick the H2o prediction function.
 */
final case class H2oModel[-A, +B](
    modelId: ModelIdentity,
    modelSource: ModelSource,
    featureNames: sci.IndexedSeq[String],
    featureFunctions: sci.IndexedSeq[FeatureFunction[A]],
    numMissingThreshold: Option[Int] = None)(implicit private[this] val scb: ScoreConverter[B])
  extends BaseModel[A, B]
  with Logging {

  // Because H2o's RowData object is essentially a Map of String to Object, we unapply the wrapper
  // and throw away the type information on the function return type. We have type safety because
  // FeatureFunction is sealed (ADT).
  @transient private[this] lazy val lazyAnyRefFF = featureFunctions map {
    case DoubleFeatureFunction(f) => f
    case StringFeatureFunction(f) => f
  }

  /**
   * Prediction function compiled from the file at `modelSource.localVfs.descriptor`.
   *
   * If `modelSource.shouldDelete` is set, deletion of that file is attempted (best effort) whether or
   * not compilation succeeds, so a failed load does not leave the file behind. Compilation failures
   * are rethrown.
   */
  @transient private[this] lazy val h2oPredictor: RowData => Either[IllConditioned, B] = {
    val sourceFile = new java.io.File(modelSource.localVfs.descriptor)
    try getH2oPredictor(sourceFile, _.fromFile).get
    finally {
      if (modelSource.shouldDelete)
        Try[Unit] { sourceFile.delete() }
    }
  }

  // Force initialization of lazy vals.
  require(lazyAnyRefFF != null)
  require(h2oPredictor != null)

  /**
   * Scores `a`.
   *
   * Fails immediately if too many features are missing (see `numMissingThreshold`). Otherwise it runs
   * the H2o predictor. H2o's unknown-categorical exception, and `IllegalArgumentException`s that appear
   * to stem from a missing categorical value, become failures. Other exceptions propagate.
   */
  override private[aloha] def getScore(a: A)(implicit audit: Boolean): (ModelOutput[B], Option[Score]) = {
    val f = constructFeatures(a)
    if (!f.missingOk)
      failureDueToMissing(f.missing)
    else
      try {
        predict(f)
      } catch {
        // We know about this specifically from the H2o documentation.
        case e: PredictUnknownCategoricalLevelException => handleBadCategorical(e, f)
        case e: IllegalArgumentException if isCategoricalMissing(e, f) => handleMissingCategorical(e, f)
      }
  }

  /**
   * Runs the H2o predictor on the assembled row data. An ill-conditioned result becomes a failure
   * that carries its error message and the missing variables.
   */
  protected[this] def predict(f: Features[RowData])(implicit audit: Boolean) =
    h2oPredictor(f.features).fold(ill => failure(Seq(ill.errorMsg), getMissingVariables(f.missing)),
                                  s => success(s))

  /**
   * ''Attempt'' to determine if a categorical was missing in the h2o model.
   *
   * Currently (3.6.0.3), the h2o generated model says: "" when a categorical value is not supplied.
   * This is determined from inspecting the generated H2o model code, so it's likely brittle and subject
   * to change. But it's better than throwing an IllegalArgumentException with no diagnostic information
   * when there is missing data in a categorical variable.
   * @param e exception thrown by h2o. Its message may be null, in which case this returns false.
   * @param f the data passed in.
   * @return whether to attempt to recover. Don't attempt to recover unless a string-based feature appears to
   *         be missing. This is so that we can diagnose when the model will fail every time.
   */
  protected[this] def isCategoricalMissing(e: IllegalArgumentException, f: Features[RowData]): Boolean =
    e.getClass == classOf[IllegalArgumentException] &&
      // getMessage may legitimately be null; guard so this catch guard never throws an NPE itself.
      Option(e.getMessage).exists(_.toLowerCase.contains("categorical")) &&
      featureFunctions.exists {
        case StringFeatureFunction(sff) => f.missing contains sff.specification
        case _ => false
      }

  /**
   * Report a problem presumably resulting from a missing categorical variable.
   * @param t the error to be reported.
   * @param f the feature values
   * @param audit whether to audit the score
   * @return a failure naming the string-based features that produced no value and pointing at the
   *         top frame of `t`'s stack trace. Like the other failure paths, it reports the missing
   *         variables (not the feature specifications).
   */
  protected[this] def handleMissingCategorical(t: IllegalArgumentException, f: Features[RowData])(implicit audit: Boolean) = {
    val missing = featureFunctions.view.zipWithIndex.collect {
      case (StringFeatureFunction(sff), i) if f.missing.contains(sff.specification) => featureNames(i)
    }
    val prefix = "H2o model may have encountered a missing categorical variable. Likely features: " + missing.mkString(", ")
    val stackError = t.getStackTrace.headOption.fold(List.empty[String])(s =>
      List("See: " + s.getClassName + "." + s.getMethodName + "(" + s.getFileName + ":" + s.getLineNumber + ")"))
    failure(prefix :: stackError, getMissingVariables(f.missing))
  }

  /** Failure for a categorical level the H2o model does not recognize. */
  protected[this] def handleBadCategorical(e: PredictUnknownCategoricalLevelException, f: Features[RowData])(implicit audit: Boolean) =
    failure(Seq(s"unknown categorical value ${e.getUnknownLevel} for variable: ${e.getColumnName}"), getMissingVariables(f.missing))

  // TODO: Extract to trait.
  /** Distinct missing variable names across all features, sorted. */
  protected[this] def getMissingVariables(missing: Map[String, Seq[String]]): Seq[String] =
    missing.values.flatten.toSet.toIndexedSeq.sorted

  /** Failure used when more features are missing than `numMissingThreshold` allows. */
  protected[this] def failureDueToMissing(missing: Map[String, Seq[String]])(implicit audit: Boolean) =
    failure(Seq(s"Too many features with missing variables: ${missing.count(_._2.nonEmpty)}"), getMissingVariables(missing))

  /**
   * Evaluates every feature function on `a`, building the H2o row data.
   *
   * Features that yield `None` are omitted from the row and recorded in `missing` (keyed by feature
   * specification, valued by the accessors reported missing). `missingOk` is false when the number of
   * such features exceeds `numMissingThreshold`.
   */
  protected[this] def constructFeatures(a: A): Features[RowData] = {
    // If we are going to err out, allow a linear scan (with repeated work) so that we can get richer error
    // diagnostics. Only include the values where the list of missing accessor variables is not empty.
    def fullMissing(ff: sci.IndexedSeq[GenAggFunc[A, _]]): Map[String, Seq[String]] =
      ff.foldLeft(Map.empty[String, Seq[String]])((missing, f) => f.accessorOutputMissing(a) match {
        case m if m.nonEmpty => missing + (f.specification -> m)
        case _ => missing
      })

    @tailrec def features(i: Int,
                          n: Int,
                          rowData: RowData,
                          missing: Map[String, Seq[String]],
                          ff: sci.IndexedSeq[GenAggFunc[A, Option[AnyRef]]]): Features[RowData] =
      if (i >= n) {
        val numMissingOk = numMissingThreshold.fold(true)(missing.size <= _)
        val m = if (numMissingOk) missing else fullMissing(ff)
        Features(rowData, m, numMissingOk)
      }
      else ff(i)(a) match {
        case Some(x) => features(i + 1, n, rowData + (featureNames(i), x), missing, ff)
        case None => features(i + 1, n, rowData, missing + (ff(i).specification -> ff(i).accessorOutputMissing(a)), ff)
      }

    // Store lazyAnyRefFF to a local variable to avoid the repeated cost of asking for the lazy val.
    val ff = lazyAnyRefFF
    features(0, ff.size, new RowData, Map.empty, ff)
  }

  /**
   * Converts the result of looking up a prediction function into a `Try`. An unsupported model
   * category becomes an `UnsupportedOperationException`. A missing output-type coercion becomes an
   * `IllegalArgumentException`.
   */
  protected[h2o] def mapRetrievalError[B: RefInfo](genModel: GenModel, retrieval: Either[PredictionFuncRetrievalError, RowData => Either[IllConditioned, B]]) = retrieval match {
    case Right(f) => Success(f)
    case Left(UnsupportedModelCategory(category)) => Failure(new UnsupportedOperationException(s"In model ${genModel.getClass.getCanonicalName}: ModelCategory ${category.name} non supported."))
    case Left(TypeCoercionNotFound(category)) => Failure(new IllegalArgumentException(s"In model ${genModel.getClass.getCanonicalName}: Could not ${category.name} model to Aloha output type: ${RefInfoOps.toString[B]}."))
  }

  /**
   * Compiles a `GenModel` via `f(compiler)(input)`, wraps it in an `EasyPredictModelWrapper`, and
   * selects the prediction function matching the model's category and the output type `B`.
   */
  protected[h2o] def getH2oPredictor[C](input: => C, f: AlohaReadable[Try[GenModel]] => C => Try[GenModel]) = {
    val compiler = new Compiler[GenModel]
    implicit val rib = scb.ri
    for {
      genModel <- f(compiler)(input)
      predictorRetrieval = H2oModelCategory.predictor[B](new EasyPredictModelWrapper(genModel))
      predictor <- mapRetrievalError[B](genModel, predictorRetrieval)
    } yield predictor
  }
}
/**
* {{{
*
* {
* "modelType": "H2o",
* "modelId": { "id": 0, "name": "" },
* "features": {
* "Gender": "gender.toString"
* },
* "model": "b64 encoded model"
* }
*
* }}}
*
* {{{
* {
* "modelType": "H2o",
* "modelId": { "id": 0, "name": "" },
* "features": {
* "Gender": "gender.toString"
* },
* "modelUrl": "hdfs://asdf",
* "via": "vfs1"
* }
* }}}
*/
object H2oModel extends ParserProviderCompanion
  with Logging {

  /**
   * Feature values assembled for one input, plus missing-data diagnostics.
   *
   * @param features the assembled feature values (an H2o `RowData` when built by `constructFeatures`).
   * @param missing feature specification mapped to the variables reported missing for it.
   * @param missingOk whether the amount of missing data is within the model's tolerance.
   */
  protected[h2o] case class Features[F](features: F,
                                        missing: Map[String, Seq[String]] = Map.empty,
                                        missingOk: Boolean = true)

  override def parser: ModelParser = Parser

  /** JSON parser for [[H2oModel]] specifications whose `modelType` is "H2o". */
  object Parser extends ModelParserWithSemantics with EitherHelpers { self =>
    val modelType = "H2o"

    override def modelJsonReader[A, B](semantics: Semantics[A])(implicit jrB: JsonReader[B], scB: ScoreConverter[B]): JsonReader[H2oModel[A, B]] = new JsonReader[H2oModel[A, B]] {

      /**
       * Parses the H2o AST and compiles each feature spec with `semantics`.
       * Throws a `DeserializationException` listing every compilation error if any spec fails.
       */
      override def read(json: JsValue): H2oModel[A, B] = {
        val ast = json.convertTo[H2oAst]
        features(ast.features.toSeq, semantics).fold(
          errs => throw new DeserializationException(errs.mkString("errors: ", "\n ", "")),
          compiled => {
            val (names, functions) = compiled.toIndexedSeq.unzip
            H2oModel[A, B](ast.modelId, ast.modelSource, names, functions, ast.numMissingThreshold)
          }
        )
      }

      // TODO: Copied from RegressionModel. Refactor for reuse.
      // Compiles each (name, spec) pair. Errors for a spec are prefixed with a line identifying that spec.
      private[this] def features[T](featureMap: Seq[(String, H2oSpec)], semantics: Semantics[T]) =
        mapSeq(featureMap) { case (name, spec) =>
          spec.compile(semantics).
            left.map(Seq(s"Error processing spec '${spec.spec}'") ++ _).
            right.map(compiledFn => (name, compiledFn))
        }
    }
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy