Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
org.drools.pmml.pmml_4_2.PMML4Compiler Maven / Gradle / Ivy
/*
* Copyright 2011 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.pmml.pmml_4_2;
import org.dmg.pmml.pmml_4_2.descr.ClusteringModel;
import org.dmg.pmml.pmml_4_2.descr.NaiveBayesModel;
import org.dmg.pmml.pmml_4_2.descr.NeuralNetwork;
import org.dmg.pmml.pmml_4_2.descr.PMML;
import org.dmg.pmml.pmml_4_2.descr.RegressionModel;
import org.dmg.pmml.pmml_4_2.descr.Scorecard;
import org.dmg.pmml.pmml_4_2.descr.SupportVectorMachineModel;
import org.dmg.pmml.pmml_4_2.descr.TreeModel;
import org.drools.compiler.compiler.PMMLCompiler;
import org.drools.core.io.impl.ByteArrayResource;
import org.drools.core.io.impl.ClassPathResource;
import org.drools.core.util.IoUtils;
import org.kie.api.KieBase;
import org.kie.api.KieServices;
import org.kie.api.io.Resource;
import org.kie.api.io.ResourceType;
import org.kie.api.runtime.KieContainer;
import org.kie.api.runtime.KieSession;
import org.kie.internal.builder.KnowledgeBuilderResult;
import org.kie.internal.io.ResourceFactory;
import org.mvel2.templates.SimpleTemplateRegistry;
import org.mvel2.templates.TemplateCompiler;
import org.mvel2.templates.TemplateRegistry;
import org.xml.sax.SAXException;
import javax.xml.XMLConstants;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.validation.Schema;
import javax.xml.validation.SchemaFactory;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.UnsupportedEncodingException;
import java.io.Writer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
/**
 * Compiles PMML 4.2 documents into a textual "theory".
 *
 * A document is unmarshalled with JAXB into a descriptor, the descriptor is inserted into a
 * session created from a KieBase found on the classpath KieContainer, and the rules of that
 * KieBase append the generated text to a shared {@link StringBuilder} using the MVEL templates
 * held in a static registry.
 *
 * The template registry and the {@code *Loaded} flags are static and are mutated without any
 * synchronization.
 */
public class PMML4Compiler implements PMMLCompiler {

    /** Base package of the PMML 4.2 schema classes. */
    public static final String PMML_NAMESPACE = "org.dmg.pmml.pmml_4_2";
    /** Base package of this compiler. */
    public static final String PMML_DROOLS = "org.drools.pmml.pmml_4_2";
    /** JAXB context path (package of the descriptor classes) used by {@link #compile(InputStream, ClassLoader)}. */
    public static final String PMML = PMML_NAMESPACE + ".descr";
    /** Classpath location of the XSD used to validate incoming documents. */
    public static final String SCHEMA_PATH = "xsd/org/dmg/pmml/pmml_4_2/pmml-4-2.xsd";
    /** {@link #PMML_DROOLS} expressed as a resource path. */
    public static final String BASE_PACK = PMML_DROOLS.replace('.','/');

    protected static boolean globalLoaded = false;
    protected static final String[] GLOBAL_TEMPLATES = new String[] {
            "global/pmml_header.drlt",
            "global/pmml_import.drlt",
            "global/rule_meta.drlt",
            "global/modelMark.drlt",
            "global/dataDefinition/common.drlt",
            "global/dataDefinition/rootDataField.drlt",
            "global/dataDefinition/inputBinding.drlt",
            "global/dataDefinition/outputBinding.drlt",
            "global/dataDefinition/ioTypeDeclare.drlt",
            "global/dataDefinition/updateIOField.drlt",
            "global/dataDefinition/inputFromEP.drlt",
            "global/dataDefinition/inputBean.drlt",
            "global/dataDefinition/outputBean.drlt",
            "global/manipulation/confirm.drlt",
            "global/manipulation/mapMissingValues.drlt",
            "global/manipulation/propagateMissingValues.drlt",
            "global/validation/intervalsOnDomainRestriction.drlt",
            "global/validation/valuesOnDomainRestriction.drlt",
            "global/validation/valuesOnDomainRestrictionMissing.drlt",
            "global/validation/valuesOnDomainRestrictionInvalid.drlt",
    };

    protected static boolean transformationLoaded = false;
    protected static final String[] TRANSFORMATION_TEMPLATES = new String[] {
            "transformations/normContinuous/boundedLowerOutliers.drlt",
            "transformations/normContinuous/boundedUpperOutliers.drlt",
            "transformations/normContinuous/normContOutliersAsMissing.drlt",
            "transformations/normContinuous/linearTractNormalization.drlt",
            "transformations/normContinuous/lowerExtrapolateLinearTractNormalization.drlt",
            "transformations/normContinuous/upperExtrapolateLinearTractNormalization.drlt",
            "transformations/aggregate/aggregate.drlt",
            "transformations/aggregate/collect.drlt",
            "transformations/simple/constantField.drlt",
            "transformations/simple/aliasedField.drlt",
            "transformations/normDiscrete/indicatorFieldYes.drlt",
            "transformations/normDiscrete/indicatorFieldNo.drlt",
            "transformations/normDiscrete/predicateField.drlt",
            "transformations/discretize/intervalBinning.drlt",
            "transformations/discretize/outOfBinningDefault.drlt",
            "transformations/discretize/outOfBinningMissing.drlt",
            "transformations/mapping/mapping.drlt",
            "transformations/functions/apply.drlt",
            "transformations/functions/function.drlt"
    };

    protected static boolean miningLoaded = false;
    protected static final String[] MINING_TEMPLATES = new String[] {
            "models/common/mining/miningField.drlt",
            "models/common/mining/miningFieldInvalid.drlt",
            "models/common/mining/miningFieldMissing.drlt",
            "models/common/mining/miningFieldOutlierAsMissing.drlt",
            "models/common/mining/miningFieldOutlierAsExtremeLow.drlt",
            "models/common/mining/miningFieldOutlierAsExtremeUpp.drlt",
            "models/common/targets/targetReshape.drlt",
            "models/common/targets/aliasedOutput.drlt",
            "models/common/targets/addOutputFeature.drlt",
            "models/common/targets/addRelOutputFeature.drlt",
            "models/common/targets/outputQuery.drlt",
            "models/common/targets/outputQueryPredicate.drlt"
    };

    protected static boolean neuralLoaded = false;
    protected static final String[] NEURAL_TEMPLATES = new String[] {
            "models/neural/neuralBeans.drlt",
            "models/neural/neuralWireInput.drlt",
            "models/neural/neuralBuildSynapses.drlt",
            "models/neural/neuralBuildNeurons.drlt",
            "models/neural/neuralLinkSynapses.drlt",
            "models/neural/neuralFire.drlt",
            "models/neural/neuralLayerMaxNormalization.drlt",
            "models/neural/neuralLayerSoftMaxNormalization.drlt",
            "models/neural/neuralOutputField.drlt",
            "models/neural/neuralClean.drlt"
    };

    protected static boolean svmLoaded = false;
    protected static final String[] SVM_TEMPLATES = new String[] {
            "models/svm/svmParams.drlt",
            "models/svm/svmDeclare.drlt",
            "models/svm/svmFunctions.drlt",
            "models/svm/svmBuild.drlt",
            "models/svm/svmInitSupportVector.drlt",
            "models/svm/svmInitInputVector.drlt",
            "models/svm/svmKernelEval.drlt",
            "models/svm/svmOutputGeneration.drlt",
            "models/svm/svmOutputVoteDeclare.drlt",
            "models/svm/svmOutputVote1vN.drlt",
            "models/svm/svmOutputVote1v1.drlt",
    };

    protected static boolean naiveBayesLoaded = false;
    protected static final String[] NAIVE_BAYES_TEMPLATES = new String[] {
            "models/bayes/naiveBayesDeclare.drlt",
            "models/bayes/naiveBayesEvalDiscrete.drlt",
            "models/bayes/naiveBayesEvalContinuous.drlt",
            "models/bayes/naiveBayesBuildCounts.drlt",
            "models/bayes/naiveBayesBuildDistrs.drlt",
            "models/bayes/naiveBayesBuildOuts.drlt",
    };

    protected static boolean simpleRegLoaded = false;
    protected static final String[] SIMPLEREG_TEMPLATES = new String[] {
            "models/regression/regDeclare.drlt",
            "models/regression/regCommon.drlt",
            "models/regression/regParams.drlt",
            "models/regression/regEval.drlt",
            "models/regression/regClaxOutput.drlt",
            "models/regression/regNormalization.drlt",
            "models/regression/regDecumulation.drlt",
    };

    protected static boolean clusteringLoaded = false;
    protected static final String[] CLUSTERING_TEMPLATES = new String[] {
            "models/clustering/clusteringDeclare.drlt",
            "models/clustering/clusteringInit.drlt",
            "models/clustering/clusteringEvalDistance.drlt",
            "models/clustering/clusteringEvalSimilarity.drlt",
            "models/clustering/clusteringMatrixCompare.drlt"
    };

    protected static boolean treeLoaded = false;
    protected static final String[] TREE_TEMPLATES = new String[] {
            "models/tree/treeDeclare.drlt",
            "models/tree/treeCommon.drlt",
            "models/tree/treeInputDeclare.drlt",
            "models/tree/treeInit.drlt",
            "models/tree/treeAggregateEval.drlt",
            "models/tree/treeDefaultEval.drlt",
            "models/tree/treeEval.drlt",
            "models/tree/treeIOBinding.drlt",
            "models/tree/treeMissHandleAggregate.drlt",
            "models/tree/treeMissHandleWeighted.drlt",
            "models/tree/treeMissHandleLast.drlt",
            "models/tree/treeMissHandleNull.drlt",
            "models/tree/treeMissHandleNone.drlt"
    };

    protected static boolean scorecardLoaded = false;
    protected static final String[] SCORECARD_TEMPLATES = new String[] {
            "models/scorecard/scorecardInit.drlt",
            "models/scorecard/scorecardParamsInit.drlt",
            "models/scorecard/scorecardDeclare.drlt",
            "models/scorecard/scorecardDataDeclare.drlt",
            "models/scorecard/scorecardPartialScore.drlt",
            "models/scorecard/scorecardScoring.drlt",
            "models/scorecard/scorecardOutputGeneration.drlt",
            "models/scorecard/scorecardOutputRankCode.drlt"
    };

    protected static final String RESOURCE_PATH = BASE_PACK;
    protected static final String TEMPLATE_PATH = "/" + RESOURCE_PATH + "/templates/";

    /** Compiled templates, keyed by template file name; shared by every compiler instance. */
    private static TemplateRegistry registry;

    // NOTE(review): the result lists are raw; the otherwise unused KnowledgeBuilderResult import
    // suggests that is the intended element type - confirm against PMMLCompiler before tightening.
    /** Results shared by all instances; always included in {@link #getResults()}. */
    private static final List visitorBuildResults = new ArrayList();

    /** Errors and warnings of the most recent compilation performed by this instance. */
    private List results;
    /** Validation schema, or {@code null} when it could not be located or parsed. */
    private Schema schema;
    private PMML4Helper helper;

    /**
     * Creates a compiler and tries to load the validation schema from {@link #SCHEMA_PATH}.
     * When the schema cannot be found or parsed the compiler is still created and
     * {@link #loadModel(String, InputStream)} reports a warning instead of validating.
     */
    public PMML4Compiler() {
        this.results = new ArrayList();
        helper = new PMML4Helper();
        helper.setPack( "org.drools.pmml.pmml_4_2.test" );

        SchemaFactory sf = SchemaFactory.newInstance( XMLConstants.W3C_XML_SCHEMA_NS_URI );
        java.net.URL schemaUrl = locateSchema();
        // newSchema( URL ) rejects null with a NullPointerException, so only call it with a real URL
        if ( schemaUrl != null ) {
            try {
                schema = sf.newSchema( schemaUrl );
            } catch ( SAXException e ) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Looks the schema up through the thread context class loader first and, failing that,
     * through the class loader of this class.
     *
     * @return the schema location, or {@code null} if neither loader can see it
     */
    private static java.net.URL locateSchema() {
        java.net.URL url = null;
        ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
        if ( contextLoader != null ) {
            url = contextLoader.getResource( SCHEMA_PATH );
        }
        if ( url == null ) {
            ClassLoader ownLoader = PMML4Compiler.class.getClassLoader();
            if ( ownLoader != null ) {
                url = ownLoader.getResource( SCHEMA_PATH );
            }
        }
        return url;
    }

    public PMML4Helper getHelper() {
        return helper;
    }

    /**
     * Generates the theory for an already loaded PMML descriptor.
     *
     * @param pmml the descriptor to visit
     * @return the text accumulated by the visitor rules, or {@code null} if the building
     *         resources could not be prepared (an error is then added to the results)
     */
    public String generateTheory( PMML pmml ) {
        StringBuilder sb = new StringBuilder();

        KieBase visitor;
        try {
            visitor = checkBuildingResources( pmml );
        } catch ( IOException e ) {
            this.results.add( new PMMLError( e.getMessage() ) );
            return null;
        }

        KieSession visitorSession = visitor.newKieSession();
        try {
            helper.reset();
            visitorSession.setGlobal( "registry", registry );
            visitorSession.setGlobal( "fld2var", new HashMap() );
            visitorSession.setGlobal( "utils", helper );
            visitorSession.setGlobal( "theory", sb );

            visitorSession.insert( pmml );
            visitorSession.fireAllRules();

            return sb.toString();
        } finally {
            // release the session even when rule execution fails
            visitorSession.dispose();
        }
    }

    /** Creates the shared registry if needed and loads the model-independent template groups once. */
    private static void initRegistry() {
        if ( registry == null ) {
            registry = new SimpleTemplateRegistry();
        }

        if ( ! globalLoaded ) {
            loadTemplates( GLOBAL_TEMPLATES );
            globalLoaded = true;
        }

        if ( ! transformationLoaded ) {
            loadTemplates( TRANSFORMATION_TEMPLATES );
            transformationLoaded = true;
        }

        if ( ! miningLoaded ) {
            loadTemplates( MINING_TEMPLATES );
            miningLoaded = true;
        }
    }

    /** Compiles and registers every template of the given group. */
    private static void loadTemplates( String[] templates ) {
        for ( String ntempl : templates ) {
            prepareTemplate( ntempl );
        }
    }

    /**
     * Returns the KieBase name to use after encountering one more model kind: the specific
     * name while it is the first kind seen, the generic "PMML" name otherwise.
     */
    private static String narrowKieBase( String current, String specific ) {
        return current == null ? specific : "PMML";
    }

    /**
     * Makes sure the templates needed by the models contained in the document are registered
     * and picks the KieBase holding the matching visitor rules.
     *
     * @param pmml the descriptor whose models are inspected
     * @return the KieBase obtained from the classpath KieContainer; "PMML-Base" is requested
     *         when none of the handled model kinds is present
     */
    private static KieBase checkBuildingResources( PMML pmml ) throws IOException {
        KieServices ks = KieServices.Factory.get();
        KieContainer kieContainer = ks.getKieClasspathContainer();

        if ( registry == null ) {
            initRegistry();
        }

        String chosenKieBase = null;

        for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) {

            if ( o instanceof NaiveBayesModel ) {
                if ( ! naiveBayesLoaded ) {
                    loadTemplates( NAIVE_BAYES_TEMPLATES );
                    naiveBayesLoaded = true;
                }
                chosenKieBase = narrowKieBase( chosenKieBase, "PMML-Bayes" );
            }

            if ( o instanceof NeuralNetwork ) {
                if ( ! neuralLoaded ) {
                    loadTemplates( NEURAL_TEMPLATES );
                    neuralLoaded = true;
                }
                chosenKieBase = narrowKieBase( chosenKieBase, "PMML-Neural" );
            }

            if ( o instanceof ClusteringModel ) {
                if ( ! clusteringLoaded ) {
                    loadTemplates( CLUSTERING_TEMPLATES );
                    clusteringLoaded = true;
                }
                chosenKieBase = narrowKieBase( chosenKieBase, "PMML-Cluster" );
            }

            if ( o instanceof SupportVectorMachineModel ) {
                if ( ! svmLoaded ) {
                    loadTemplates( SVM_TEMPLATES );
                    svmLoaded = true;
                }
                chosenKieBase = narrowKieBase( chosenKieBase, "PMML-SVM" );
            }

            if ( o instanceof TreeModel ) {
                if ( ! treeLoaded ) {
                    loadTemplates( TREE_TEMPLATES );
                    treeLoaded = true;
                }
                chosenKieBase = narrowKieBase( chosenKieBase, "PMML-Tree" );
            }

            if ( o instanceof RegressionModel ) {
                if ( ! simpleRegLoaded ) {
                    loadTemplates( SIMPLEREG_TEMPLATES );
                    simpleRegLoaded = true;
                }
                chosenKieBase = narrowKieBase( chosenKieBase, "PMML-Regression" );
            }

            if ( o instanceof Scorecard ) {
                if ( ! scorecardLoaded ) {
                    loadTemplates( SCORECARD_TEMPLATES );
                    scorecardLoaded = true;
                }
                chosenKieBase = narrowKieBase( chosenKieBase, "PMML-Scorecard" );
            }
        }

        if ( chosenKieBase == null ) {
            chosenKieBase = "PMML-Base";
        }
        return kieContainer.getKieBase( chosenKieBase );
    }

    /**
     * Compiles one template from the classpath and registers it under its file name.
     * I/O failures are printed and otherwise ignored, leaving the template unregistered.
     *
     * @param ntempl template path relative to {@link #TEMPLATE_PATH}
     */
    private static void prepareTemplate( String ntempl ) {
        InputStream stream = null;
        try {
            String path = TEMPLATE_PATH + ntempl;
            Resource res = ResourceFactory.newClassPathResource( path, PMML4Compiler.class );
            if ( res != null ) {
                stream = res.getInputStream();
                if ( stream != null ) {
                    registry.addNamedTemplate( path.substring( path.lastIndexOf( '/' ) + 1 ),
                                               TemplateCompiler.compileTemplate( stream ) );
                }
            }
        } catch ( IOException e ) {
            e.printStackTrace();
        } finally {
            closeQuietly( stream );
        }
    }

    /** Closes a stream opened by this class; a failure to close is deliberately ignored. */
    private static void closeQuietly( InputStream stream ) {
        if ( stream != null ) {
            try {
                stream.close();
            } catch ( IOException ignored ) {
                // best-effort cleanup: the content has already been consumed
            }
        }
    }

    /**
     * Compiles a PMML document located on the classpath.
     *
     * @param resource    classpath name of the document
     * @param classLoader class loader handed to the helper as resolver
     * @return the generated theory, or {@code null} if the document could not be read or compiled
     */
    public String compile( String resource, ClassLoader classLoader ) {
        String theory = null;
        Resource cpr = new ClassPathResource( resource );
        InputStream stream = null;
        try {
            stream = cpr.getInputStream();
            theory = compile( stream, classLoader );
        } catch ( IOException e ) {
            results.add( new PMMLError( e.toString() ) );
            e.printStackTrace();
        } finally {
            closeQuietly( stream );
        }
        return theory;
    }

    /**
     * Compiles the given resource and wraps the generated theory in a new resource.
     *
     * @param resource    the PMML document
     * @param classLoader class loader handed to the helper as resolver
     * @return a single-element array holding the generated resource, or an empty array when
     *         the document could not be read or compiled (see {@link #getResults()})
     */
    public Resource[] transform( Resource resource, ClassLoader classLoader ) {
        String theory = null;
        InputStream stream = null;
        try {
            stream = resource.getInputStream();
            theory = compile( stream, classLoader );
        } catch ( IOException e ) {
            results.add( new PMMLError( e.toString() ) );
            e.printStackTrace();
            return new Resource[ 0 ];
        } finally {
            closeQuietly( stream );
        }
        if ( theory == null ) {
            // compilation reported problems; there is nothing to wrap
            return new Resource[ 0 ];
        }
        return new Resource[] { buildOutputResource( resource, theory ) };
    }

    /**
     * Wraps the theory in a UTF-8 byte array resource of type PMML. The source path is placed
     * under "generated-sources/", named after the file name of the original resource or, when
     * the original has no source path, after the helper's current context.
     */
    private Resource buildOutputResource( Resource resource, String theory ) {
        ByteArrayResource byteArrayResource = new ByteArrayResource( theory.getBytes( IoUtils.UTF8_CHARSET ) );
        byteArrayResource.setResourceType( ResourceType.PMML );

        String originalPath = resource.getSourcePath();
        if ( originalPath != null ) {
            // accept both separators: resource paths commonly use '/' regardless of platform,
            // and a path without any separator is used as is
            int lastSeparator = Math.max( originalPath.lastIndexOf( '/' ), originalPath.lastIndexOf( '\\' ) );
            String fileName = originalPath.substring( lastSeparator + 1 );
            byteArrayResource.setSourcePath( "generated-sources/" + fileName + ".pmml" );
        } else {
            byteArrayResource.setSourcePath( "generated-sources/" + helper.getContext() + ".pmml" );
        }
        return byteArrayResource;
    }

    /**
     * Compiles a PMML document read from a stream. The results of any previous compilation by
     * this instance are discarded first. The stream is not closed by this method.
     *
     * @param source      the PMML document
     * @param classLoader class loader handed to the helper as resolver
     * @return the generated theory, or {@code null} if {@link #getResults()} is not empty after
     *         loading the document (warnings included)
     */
    public String compile( InputStream source, ClassLoader classLoader ) {
        this.results = new ArrayList();
        PMML pmml = loadModel( PMML, source );
        helper.setResolver( classLoader );

        if ( getResults().isEmpty() ) {
            return generateTheory( pmml );
        } else {
            return null;
        }
    }

    /**
     * @return a new list with the results of this instance followed by the shared results
     */
    public List getResults() {
        List combinedResults = new ArrayList( this.results );
        combinedResults.addAll( visitorBuildResults );
        return combinedResults;
    }

    @Override
    public void clearResults() {
        this.results.clear();
    }

    /**
     * Writes the text to the stream as UTF-8 and flushes it; the stream is left open.
     * I/O failures are printed and otherwise ignored.
     *
     * @param s       the text to write
     * @param ostream the destination
     */
    public void dump( String s, OutputStream ostream ) {
        try {
            Writer writer = new OutputStreamWriter( ostream, "UTF-8" );
            writer.write( s );
            writer.flush();
        } catch ( IOException e ) {
            // also covers UnsupportedEncodingException, which is an IOException
            e.printStackTrace();
        }
    }

    /**
     * Imports a PMML document, returning a Java descriptor.
     * The document is validated only when the schema was loaded at construction time;
     * otherwise a warning is recorded in the results of this instance.
     *
     * @param model  the PMML package name (classes derived from a specific schema)
     * @param source the stream providing the PMML document storing the predictive model
     * @return the Java descriptor of the PMML document, or {@code null} if unmarshalling
     *         failed (an error is then added to the results)
     */
    public PMML loadModel( String model, InputStream source ) {
        try {
            if ( schema == null ) {
                // kept per instance: the static list is never cleared and would make this
                // warning pile up and show in the results of every compiler
                this.results.add( new PMMLWarning( ResourceFactory.newInputStreamResource( source ), "Could not validate PMML document, schema not available" ) );
            }
            JAXBContext jc = JAXBContext.newInstance( model );
            Unmarshaller unmarshaller = jc.createUnmarshaller();
            if ( schema != null ) {
                unmarshaller.setSchema( schema );
            }
            return (PMML) unmarshaller.unmarshal( source );
        } catch ( JAXBException e ) {
            this.results.add( new PMMLError( e.toString() ) );
            return null;
        }
    }

    /**
     * Marshals a descriptor as formatted XML. JAXB failures are printed and otherwise ignored.
     *
     * @param model  the descriptor to serialize
     * @param target the destination stream
     */
    public static void dumpModel( PMML model, OutputStream target ) {
        try {
            JAXBContext jc = JAXBContext.newInstance( PMML.class.getPackage().getName() );
            Marshaller marshaller = jc.createMarshaller();
            marshaller.setProperty( Marshaller.JAXB_FORMATTED_OUTPUT, Boolean.TRUE );
            marshaller.marshal( model, target );
        } catch ( JAXBException e ) {
            e.printStackTrace();
        }
    }
}