org.pmml4s.xml.GeneralRegressionBuilder.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml4s_3 Show documentation
Show all versions of pmml4s_3 Show documentation
A PMML scoring library in Scala
The newest version!
/*
* Copyright (c) 2017-2023 AutoDeployAI
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.pmml4s.xml
import org.pmml4s.common.{Interval, Matrix, Value}
import org.pmml4s.model._
/**
* Builder of General Regression Model
*/
class GeneralRegressionBuilder extends Builder[GeneralRegressionModel] {
protected var attributes: GeneralRegressionAttributes = _
private var parameterList: ParameterList = _
private var factorList: FactorList = _
private var covariateList: CovariateList = _
private var ppMatrix: PPMatrix = _
private var pCovMatrix: PCovMatrix = _
private var paramMatrix: ParamMatrix = _
private var eventValues: EventValues = _
private var baseCumHazardTables: BaseCumHazardTables = _
/** Builds a PMML model from a specified XML reader. */
override def build(reader: XMLEventReader, attrs: XmlAttrs, parent: Model): GeneralRegressionModel = {
this.parent = parent
this.attributes = makeAttributes(attrs)
traverseModel(reader, ElemTags.GENERAL_REGRESSION_MODEL, {
case EvElemStart(_, ElemTags.PARAMETER_LIST, _, _) => parameterList = makeParameterList(reader)
case EvElemStart(_, ElemTags.FACTOR_LIST, _, _) => factorList = makeFactorList(reader)
case EvElemStart(_, ElemTags.COVARIATE_LIST, _, _) => covariateList = makeCovariateList(reader)
case EvElemStart(_, ElemTags.PP_MATRIX, _, _) => ppMatrix = makePPMatrix(reader)
case EvElemStart(_, ElemTags.P_COV_MATRIX, attrs, _) => pCovMatrix = makePCovMatrix(reader, attrs)
case EvElemStart(_, ElemTags.PARAM_MATRIX, _, _) => paramMatrix = makeParamMatrix(reader)
case EvElemStart(_, ElemTags.EVENT_VALUES, _, _) => eventValues = makeEventValues(reader)
case EvElemStart(_, ElemTags.BASE_CUM_HAZARD_TABLES, attrs, _) =>
baseCumHazardTables = makeBaseCumHazardTables(reader, attrs)
})
new GeneralRegressionModel(parent, attributes, miningSchema,
parameterList, Option(factorList), Option(covariateList), ppMatrix, Option(pCovMatrix), paramMatrix, Option(eventValues), Option(baseCumHazardTables),
output, targets, localTransformations, modelStats, modelExplanation, modelVerification, extensions.toIndexedSeq)
}
def makeBaselineCell(reader: XMLEventReader, attrs: XmlAttrs): BaselineCell = {
val time = attrs.double(AttrTags.TIME)
val cumHazard = attrs.double(AttrTags.CUM_HAZARD)
new BaselineCell(time, cumHazard)
}
def makeBaseCumHazardTables(reader: XMLEventReader, attrs: XmlAttrs): BaseCumHazardTables = makeElem(reader, attrs,
new ElemBuilder[BaseCumHazardTables] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): BaseCumHazardTables = {
val (baselineStratums, baselineCells) = makeElems(reader, ElemTags.BASE_CUM_HAZARD_TABLES,
ElemTags.BASELINE_STRATUM, new ElemBuilder[BaselineStratum] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): BaselineStratum = {
val value = attrs(AttrTags.VALUE)
val label = attrs.get(AttrTags.LABEL)
val maxTime = attrs.double(AttrTags.MAX_TIME)
val cells = makeElems(reader, ElemTags.BASELINE_STRATUM, ElemTags.BASELINE_CELL, new ElemBuilder[BaselineCell] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): BaselineCell = makeBaselineCell(reader, attrs)
})
new BaselineStratum(cells, attributes.baselineStrataVariable.get.toVal(value), maxTime, label)
}
}, ElemTags.BASELINE_CELL, new ElemBuilder[BaselineCell] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): BaselineCell = makeBaselineCell(reader, attrs)
})
val maxTime: Option[Double] = attrs.getDouble(AttrTags.MAX_TIME)
new BaseCumHazardTables(baselineStratums, baselineCells, maxTime)
}
})
def makeEventValues(reader: XMLEventReader): EventValues = {
val (values, intervals) = makeElems(reader, ElemTags.EVENT_VALUES, ElemTags.VALUE, new ElemBuilder[Value] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): Value = makeValue(reader, attrs)
}, ElemTags.INTERVAL, new ElemBuilder[Interval] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): Interval = makeInterval(reader, attrs)
})
new EventValues(values, intervals)
}
def makeParamMatrix(reader: XMLEventReader): ParamMatrix = {
val cells = makeElems(reader, ElemTags.PARAM_MATRIX, ElemTags.P_CELL, new ElemBuilder[PCell] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): PCell = {
// 'targetCategory' is required for the classification model, while useless for the regression model. Some
// examples still write an empty string for regression, so just ignore it
val targetCategory = if (attributes.isRegression) None else attrs.get(AttrTags.TARGET_CATEGORY).map(target.toVal(_))
val parameterName = attrs(AttrTags.PARAMETER_NAME)
// 'beta' is a required attribute defined by the PMML standard, but some examples exported from R could be absent.
val beta = attrs.getDouble(AttrTags.BETA, 0.0)
val df = attrs.getInt(AttrTags.DF)
new PCell(parameterName, beta, targetCategory, df)
}
})
new ParamMatrix(cells)
}
def makePCovMatrix(reader: XMLEventReader, attrs: XmlAttrs): PCovMatrix = makeElem(reader, attrs, new ElemBuilder[PCovMatrix] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): PCovMatrix = {
val tpe = attrs.getEnum(AttrTags.TYPE, PCovMatrixType)
val cells = makeElems(reader, ElemTags.P_COV_MATRIX, ElemTags.P_COV_CELL, new ElemBuilder[PCovCell] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): PCovCell = {
val pRow = attrs(AttrTags.P_ROW)
val pCol = attrs(AttrTags.P_COL)
val tRow = attrs.get(AttrTags.T_ROW)
val tCol = attrs.get(AttrTags.T_COL)
val value = attrs.double(AttrTags.VALUE)
val targetCategory = attrs.get(AttrTags.TARGET_CATEGORY).map(target.toVal(_))
new PCovCell(pRow, pCol, value, tRow, tCol, targetCategory)
}
})
new PCovMatrix(cells, tpe)
}
})
def makePPMatrix(reader: XMLEventReader): PPMatrix = {
val cells = makeElems(reader, ElemTags.PP_MATRIX, ElemTags.PP_CELL, new ElemBuilder[PPCell] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): PPCell = {
val predictorName = attrs(AttrTags.PREDICTOR_NAME)
val f = field(predictorName)
val value = f.toVal(attrs(AttrTags.VALUE))
val parameterName = attrs(AttrTags.PARAMETER_NAME)
val targetCategory = attrs.get(AttrTags.TARGET_CATEGORY).map(target.toVal)
new PPCell(value, f, parameterName, targetCategory)
}
})
new PPMatrix(cells)
}
def makeCovariateList(reader: XMLEventReader): CovariateList = {
val predictors = makeElems(reader, ElemTags.COVARIATE_LIST, ElemTags.PREDICTOR, new ElemBuilder[Predictor] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): Predictor = makePredictor(reader, attrs)
})
new CovariateList(predictors)
}
def makePredictor(reader: XMLEventReader, attrs: XmlAttrs): Predictor = {
val name = attrs(AttrTags.NAME)
val f = field(name)
val contrastMatrixType = attrs.get((AttrTags.CONTRAST_MATRIX_TYPE))
var categories: Categories = null
var matrix: Matrix = null
traverseElems(reader, ElemTags.PREDICTOR, {
case EvElemStart(_, ElemTags.CATEGORIES, attrs, _) => categories = makeElem(reader, attrs, new ElemBuilder[Categories] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): Categories = {
val categories = makeElems(reader, ElemTags.CATEGORIES, ElemTags.CATEGORY, new ElemBuilder[Category] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): Category = {
val value = f.toVal(attrs(AttrTags.VALUE))
new Category(value)
}
})
new Categories(categories)
}
})
case EvElemStart(_, ElemTags.MATRIX, attrs, _) => matrix = makeMatrix(reader, attrs)
})
new Predictor(name, contrastMatrixType, Option(categories), Option(matrix))
}
def makeFactorList(reader: XMLEventReader): FactorList = {
val predictors = makeElems(reader, ElemTags.FACTOR_LIST, ElemTags.PREDICTOR, new ElemBuilder[Predictor] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): Predictor = makePredictor(reader, attrs)
})
new FactorList(predictors)
}
def makeParameterList(reader: XMLEventReader): ParameterList = {
val parameters = makeElems(reader, ElemTags.PARAMETER_LIST, ElemTags.PARAMETER, new ElemBuilder[Parameter] {
override def build(reader: XMLEventReader, attrs: XmlAttrs): Parameter = {
val name = attrs(AttrTags.NAME)
val label = attrs.get(AttrTags.LABEL)
val referencePoint = attrs.getDouble(AttrTags.REFERENCE_POINT).getOrElse(0.0)
new Parameter(name, label, referencePoint)
}
})
new ParameterList(parameters)
}
/** Extracts these common attributes from a model */
override protected def makeAttributes(attrs: XmlAttrs): GeneralRegressionAttributes = {
val attributes = super.makeAttributes(attrs)
new GeneralRegressionAttributes(
functionName = attributes.functionName,
modelType = attrs.`enum`(AttrTags.MODEL_TYPE, GeneralModelType),
targetVariableName = attrs.get(AttrTags.TARGET_VARIABLE_NAME),
targetReferenceCategory = attrs.get(AttrTags.TARGET_REFERENCE_CATEGORY),
cumulativeLink = attrs.getEnum(AttrTags.CUMULATIVE_LINK, CumulativeLinkFunction),
linkFunction = attrs.getEnum(AttrTags.LINK_FUNCTION, LinkFunction),
linkParameter = attrs.getDouble(AttrTags.LINK_PARAMETER),
trialsVariable = attrs.get(AttrTags.TRIALS_VARIABLE).map(field(_)),
trialsValue = attrs.getInt(AttrTags.TRIALS_VALUE),
distribution = attrs.getEnum(AttrTags.DISTRIBUTION, Distribution),
distParameter = attrs.getDouble(AttrTags.DIST_PARAMETER),
offsetVariable = attrs.get(AttrTags.OFFSET_VARIABLE).map(field(_)),
offsetValue = attrs.getDouble(AttrTags.OFFSET_VALUE),
modelDF = attrs.getDouble(AttrTags.MODEL_DF),
endTimeVariable = attrs.get(AttrTags.END_TIME_VARIABLE).map(field(_)),
startTimeVariable = attrs.get(AttrTags.START_TIME_VARIABLE).map(field(_)),
subjectIDVariable = attrs.get(AttrTags.SUBJECT_ID_VARIABLE).map(field(_)),
statusVariable = attrs.get(AttrTags.STATUS_VARIABLE).map(field(_)),
baselineStrataVariable = attrs.get(AttrTags.BASELINE_STRATA_VARIABLE).map(field(_)),
modelName = attributes.modelName,
algorithmName = attributes.algorithmName,
isScorable = attributes.isScorable)
}
/** Name of the builder. */
override def name: String = ElemTags.GENERAL_REGRESSION_MODEL
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy