All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.sparkml.PMMLBuilder Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2018 Villu Ruusmann
 *
 * This file is part of JPMML-SparkML
 *
 * JPMML-SparkML is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SparkML is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SparkML.  If not, see .
 */
package org.jpmml.sparkml;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import com.google.common.collect.Iterables;
import jakarta.xml.bind.JAXBException;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.param.shared.HasProbabilityCol;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.VerificationField;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.model.metro.MetroJAXBUtil;
import org.jpmml.sparkml.model.HasFeatureImportances;
import org.jpmml.sparkml.model.HasTreeOptions;

public class PMMLBuilder {

	private StructType schema = null;

	private PipelineModel pipelineModel = null;

	private Map> options = new LinkedHashMap<>();

	private Verification verification = null;


	public PMMLBuilder(StructType schema, PipelineModel pipelineModel){
		setSchema(schema);
		setPipelineModel(pipelineModel);
	}

	public PMMLBuilder(StructType schema, PipelineStage pipelineStage){
		throw new IllegalArgumentException("Expected a fitted pipeline model (class " + PipelineModel.class.getName() + "), got a pipeline stage (" + (pipelineStage != null ? ("class " + (pipelineStage.getClass()).getName()) : null) + ")");
	}

	public PMML build(){
		StructType schema = getSchema();
		PipelineModel pipelineModel = getPipelineModel();
		Map> options = getOptions();
		Verification verification = getVerification();

		ConverterFactory converterFactory = new ConverterFactory(options);

		SparkMLEncoder encoder = new SparkMLEncoder(schema, converterFactory);

		Map derivedFields = encoder.getDerivedFields();

		List models = new ArrayList<>();

		List predictionColumns = new ArrayList<>();
		List probabilityColumns = new ArrayList<>();

		// Transformations preceding the last model
		List preProcessorNames = Collections.emptyList();

		Iterable transformers = getTransformers(pipelineModel);
		for(Transformer transformer : transformers){
			TransformerConverter converter = converterFactory.newConverter(transformer);

			if(converter instanceof FeatureConverter){
				FeatureConverter featureConverter = (FeatureConverter)converter;

				featureConverter.registerFeatures(encoder);
			} else

			if(converter instanceof ModelConverter){
				ModelConverter modelConverter = (ModelConverter)converter;

				org.dmg.pmml.Model model = modelConverter.registerModel(encoder);

				models.add(model);

				featureImportances:
				if(modelConverter instanceof HasFeatureImportances){
					HasFeatureImportances hasFeatureImportances = (HasFeatureImportances)modelConverter;

					Boolean estimateFeatureImportances = (Boolean)modelConverter.getOption(HasTreeOptions.OPTION_ESTIMATE_FEATURE_IMPORTANCES, Boolean.FALSE);
					if(!estimateFeatureImportances){
						break featureImportances;
					}

					List featureImportances = VectorUtil.toList(hasFeatureImportances.getFeatureImportances());

					List features = modelConverter.getFeatures(encoder);

					SchemaUtil.checkSize(featureImportances.size(), features);

					for(int i = 0; i < featureImportances.size(); i++){
						Double featureImportance = featureImportances.get(i);
						Feature feature = features.get(i);

						encoder.addFeatureImportance(model, feature, featureImportance);
					}
				} // End if

				hasPredictionCol:
				if(transformer instanceof HasPredictionCol){
					HasPredictionCol hasPredictionCol = (HasPredictionCol)transformer;

					// XXX
					if((transformer instanceof GeneralizedLinearRegressionModel) && (model.requireMiningFunction() == MiningFunction.CLASSIFICATION)){
						break hasPredictionCol;
					}

					predictionColumns.add(hasPredictionCol.getPredictionCol());
				} // End if

				if(transformer instanceof HasProbabilityCol){
					HasProbabilityCol hasProbabilityCol = (HasProbabilityCol)transformer;

					probabilityColumns.add(hasProbabilityCol.getProbabilityCol());
				}

				preProcessorNames = new ArrayList<>(derivedFields.keySet());
			} else

			{
				throw new IllegalArgumentException("Expected a subclass of " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + ", got " + (converter != null ? ("class " + (converter.getClass()).getName()) : null));
			}
		}

		// Transformations following the last model
		List postProcessorNames = new ArrayList<>(derivedFields.keySet());
		postProcessorNames.removeAll(preProcessorNames);

		org.dmg.pmml.Model model;

		if(models.size() == 0){
			model = null;
		} else

		if(models.size() == 1){
			model = Iterables.getOnlyElement(models);
		} else

		{
			model = MiningModelUtil.createModelChain(models, Segmentation.MissingPredictionTreatment.CONTINUE);
		} // End if

		if((model != null) && !postProcessorNames.isEmpty()){
			org.dmg.pmml.Model finalModel = MiningModelUtil.getFinalModel(model);

			Output output = ModelUtil.ensureOutput(finalModel);

			for(String postProcessorName : postProcessorNames){
				DerivedField derivedField = derivedFields.get(postProcessorName);

				encoder.removeDerivedField(postProcessorName);

				OutputField outputField = new OutputField(derivedField.requireName(), derivedField.requireOpType(), derivedField.requireDataType())
					.setResultFeature(ResultFeature.TRANSFORMED_VALUE)
					.setExpression(derivedField.requireExpression());

				output.addOutputFields(outputField);
			}
		}

		PMML pmml = encoder.encodePMML(model);

		if((model != null) && (!predictionColumns.isEmpty() || !probabilityColumns.isEmpty()) && (verification != null)){
			Dataset dataset = verification.getDataset();
			Dataset transformedDataset = verification.getTransformedDataset();
			Double precision = verification.getPrecision();
			Double zeroThreshold = verification.getZeroThreshold();

			List inputColumns = new ArrayList<>();

			MiningSchema miningSchema = model.requireMiningSchema();

			List miningFields = miningSchema.getMiningFields();
			for(MiningField miningField : miningFields){
				MiningField.UsageType usageType = miningField.getUsageType();

				switch(usageType){
					case ACTIVE:
						String name = miningField.getName();

						inputColumns.add(name);
						break;
					default:
						break;
				}
			}

			Map> data = new LinkedHashMap<>();

			for(String inputColumn : inputColumns){
				VerificationField verificationField = ModelUtil.createVerificationField(inputColumn);

				data.put(verificationField, getColumn(dataset, inputColumn));
			}

			for(String predictionColumn : predictionColumns){
				Feature feature = encoder.getOnlyFeature(predictionColumn);

				VerificationField verificationField = ModelUtil.createVerificationField(feature.getName())
					.setPrecision(precision)
					.setZeroThreshold(zeroThreshold);

				data.put(verificationField, getColumn(transformedDataset, predictionColumn));
			}

			for(String probabilityColumn : probabilityColumns){
				List features = encoder.getFeatures(probabilityColumn);

				for(int i = 0; i < features.size(); i++){
					Feature feature = features.get(i);

					VerificationField verificationField = ModelUtil.createVerificationField(feature.getName())
						.setPrecision(precision)
						.setZeroThreshold(zeroThreshold);

					data.put(verificationField, getVectorColumn(transformedDataset, probabilityColumn, i));
				}
			}

			model.setModelVerification(ModelUtil.createModelVerification(data));
		}

		return pmml;
	}

	public byte[] buildByteArray(){
		return buildByteArray(1024 * 1024);
	}

	private byte[] buildByteArray(int size){
		PMML pmml = build();

		ByteArrayOutputStream os = new ByteArrayOutputStream(size);

		try {
			MetroJAXBUtil.marshalPMML(pmml, os);
		} catch(JAXBException je){
			throw new RuntimeException(je);
		}

		return os.toByteArray();
	}

	public File buildFile(File file) throws IOException {
		PMML pmml = build();

		OutputStream os = new FileOutputStream(file);

		try {
			MetroJAXBUtil.marshalPMML(pmml, os);
		} catch(JAXBException je){
			throw new RuntimeException(je);
		} finally {
			os.close();
		}

		return file;
	}

	public PMMLBuilder extendSchema(Set names){
		StructType schema = getSchema();
		PipelineModel pipelineModel = getPipelineModel();

		StructType transformedSchema = pipelineModel.transformSchema(schema);

		for(String name : names){
			StructField field = transformedSchema.apply(name);

			schema = schema.add(field);
		}

		setSchema(schema);

		return this;
	}

	public PMMLBuilder putOption(String key, Object value){
		return putOptions(Collections.singletonMap(key, value));
	}

	public PMMLBuilder putOptions(Map map){
		return putOptions(Pattern.compile(".*"), map);
	}

	public PMMLBuilder putOption(PipelineStage pipelineStage, String key, Object value){
		return putOptions(pipelineStage, Collections.singletonMap(key, value));
	}

	public PMMLBuilder putOptions(PipelineStage pipelineStage, Map map){
		return putOptions(Pattern.compile(pipelineStage.uid(), Pattern.LITERAL), map);
	}

	public PMMLBuilder putOptions(Pattern pattern, Map map){
		Map> options = getOptions();

		RegexKey key = new RegexKey(pattern);

		Map patternOptions = options.get(key);
		if(patternOptions == null){
			patternOptions = new LinkedHashMap<>();

			options.put(key, patternOptions);
		}

		patternOptions.putAll(map);

		return this;
	}

	public PMMLBuilder verify(Dataset dataset){
		return verify(dataset, 1e-14, 1e-14);
	}

	public PMMLBuilder verify(Dataset dataset, double precision, double zeroThreshold){
		PipelineModel pipelineModel = getPipelineModel();

		Dataset transformedDataset = pipelineModel.transform(dataset);

		Verification verification = new Verification(dataset, transformedDataset)
			.setPrecision(precision)
			.setZeroThreshold(zeroThreshold);

		return setVerification(verification);
	}

	public StructType getSchema(){
		return this.schema;
	}

	public PMMLBuilder setSchema(StructType schema){
		this.schema = Objects.requireNonNull(schema);

		return this;
	}

	public PipelineModel getPipelineModel(){
		return this.pipelineModel;
	}

	public PMMLBuilder setPipelineModel(PipelineModel pipelineModel){
		this.pipelineModel = Objects.requireNonNull(pipelineModel);

		return this;
	}

	public Map> getOptions(){
		return this.options;
	}

	private PMMLBuilder setOptions(Map> options){
		this.options = Objects.requireNonNull(options);

		return this;
	}

	public Verification getVerification(){
		return this.verification;
	}

	private PMMLBuilder setVerification(Verification verification){
		this.verification = verification;

		return this;
	}

	static
	private Iterable getTransformers(PipelineModel pipelineModel){
		List result = new ArrayList<>();
		result.add(pipelineModel);

		Function> function = new Function>(){

			@Override
			public List apply(Transformer transformer){

				if(transformer instanceof PipelineModel){
					PipelineModel pipelineModel = (PipelineModel)transformer;

					return Arrays.asList(pipelineModel.stages());
				} else

				if(transformer instanceof CrossValidatorModel){
					CrossValidatorModel crossValidatorModel = (CrossValidatorModel)transformer;

					return Collections.singletonList(crossValidatorModel.bestModel());
				} else

				if(transformer instanceof TrainValidationSplitModel){
					TrainValidationSplitModel trainValidationSplitModel = (TrainValidationSplitModel)transformer;

					return Collections.singletonList(trainValidationSplitModel.bestModel());
				}

				return null;
			}
		};

		while(true){
			boolean modified = false;

			ListIterator transformerIt = result.listIterator();
			while(transformerIt.hasNext()){
				Transformer transformer = transformerIt.next();

				List childTransformers = function.apply(transformer);
				if(childTransformers != null){
					transformerIt.remove();

					for(Transformer childTransformer : childTransformers){
						transformerIt.add(childTransformer);
					}

					modified = true;
				}
			}

			if(!modified){
				break;
			}
		}

		return result;
	}

	static
	private List getColumn(Dataset dataset, String name){
		List rows = dataset.select(name)
			.collectAsList();

		return rows.stream()
			.map(row -> row.apply(0))
			.collect(Collectors.toList());
	}

	static
	private List getVectorColumn(Dataset dataset, String name, int index){
		List column = getColumn(dataset, name);

		return column.stream()
			.map(Vector.class::cast)
			.map(vector -> vector.apply(index))
			.collect(Collectors.toList());
	}

	static
	private void init(){
		ConverterFactory.checkVersion();
		ConverterFactory.checkApplicationClasspath();
		ConverterFactory.checkNoShading();
	}

	static
	public class Verification {

		private Dataset dataset = null;

		private Dataset transformedDataset = null;

		public Double precision = null;

		public Double zeroThreshold = null;


		private Verification(Dataset dataset, Dataset transformedDataset){
			setDataset(dataset);
			setTransformedDataset(transformedDataset);
		}

		public Dataset getDataset(){
			return this.dataset;
		}

		private Verification setDataset(Dataset dataset){
			this.dataset = dataset;

			return this;
		}

		public Dataset getTransformedDataset(){
			return this.transformedDataset;
		}

		private Verification setTransformedDataset(Dataset transformedDataset){
			this.transformedDataset = transformedDataset;

			return this;
		}

		public Double getPrecision(){
			return this.precision;
		}

		public Verification setPrecision(Double precision){
			this.precision = precision;

			return this;
		}

		public Double getZeroThreshold(){
			return this.zeroThreshold;
		}

		public Verification setZeroThreshold(Double zeroThreshold){
			this.zeroThreshold = zeroThreshold;

			return this;
		}
	}

	static {
		init();
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy