All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.sparkml.ConverterFactory Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2018 Villu Ruusmann
 *
 * This file is part of JPMML-SparkML
 *
 * JPMML-SparkML is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SparkML is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SparkML.  If not, see .
 */
package org.jpmml.sparkml;

import java.io.IOException;
import java.io.InputStream;
import java.io.StringReader;
import java.lang.reflect.Constructor;
import java.net.URL;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.xml.transform.stream.StreamSource;

import jakarta.xml.bind.JAXBException;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.apache.spark.SparkContext;
import org.apache.spark.ml.Transformer;
import org.apache.spark.sql.SparkSession;
import org.jpmml.model.JAXBUtil;

public class ConverterFactory {

	private Map> options = null;


	public ConverterFactory(Map> options){
		setOptions(options);
	}

	public TransformerConverter newConverter(Transformer transformer){
		Class clazz = transformer.getClass();

		Class> converterClazz = ConverterFactory.converters.get(clazz);
		if(converterClazz == null){
			throw new IllegalArgumentException("Transformer class " + clazz.getName() + " is not supported");
		}

		TransformerConverter converter;

		try {
			Constructor> converterConstructor = converterClazz.getDeclaredConstructor(clazz);

			converter = converterConstructor.newInstance(transformer);
		} catch(ReflectiveOperationException roe){
			throw new IllegalArgumentException("Transformer class " + clazz.getName() + " is not supported", roe);
		}

		if(converter != null){
			Map> options = getOptions();

			Map converterOptions = new LinkedHashMap<>();

			options.entrySet().stream()
				.filter(entry -> (entry.getKey()).test(transformer.uid()))
				.map(entry -> entry.getValue())
				.forEach(converterOptions::putAll);

			converter.setOptions(converterOptions);
		}

		return converter;
	}

	public Map> getOptions(){
		return this.options;
	}

	private void setOptions(Map> options){
		this.options = Objects.requireNonNull(options);
	}

	static
	public void checkVersion(){
		SparkSession sparkSession;

		try {
			sparkSession = SparkSession.active();
		} catch(IllegalStateException ise){
			logger.warn("Failed to check Apache Spark ML version", ise);

			return;
		}

		SparkContext sparkContext = sparkSession.sparkContext();

		int[] version = parseVersion(sparkContext.version());

		if(!Arrays.equals(ConverterFactory.VERSION, version)){
			throw new IllegalArgumentException("Expected Apache Spark ML version " + formatVersion(ConverterFactory.VERSION) + ", got version " + formatVersion(version) + " (" + sparkContext.version() + ")");
		}
	}

	static
	public void checkApplicationClasspath(){
		String string = "";

		try {
			JAXBUtil.unmarshalPMML(new StreamSource(new StringReader(string)));
		} catch(JAXBException je){
			throw new IllegalArgumentException("Expected JPMML-Model version 1.5.X, got a legacy version. See https://issues.apache.org/jira/browse/SPARK-15526", je);
		}
	}

	static
	public void checkNoShading(){
		Package _package = TransformerConverter.class.getPackage();

		String name = _package.getName();

		if(!(name).equals("org.jpmml.sparkml")){
			throw new IllegalArgumentException("Expected JPMML-SparkML converter classes to have package name prefix \'org.jpmml.sparkml\', got package name prefix \'" + name + "\'");
		}
	}

	static
	private void init(ClassLoader classLoader){
		Enumeration urls;

		try {
			urls = classLoader.getResources("META-INF/sparkml2pmml.properties");
		} catch(IOException ioe){
			logger.warn("Failed to find resources", ioe);

			return;
		}

		while(urls.hasMoreElements()){
			URL url = urls.nextElement();

			logger.trace("Loading resource " + url);

			try(InputStream is = url.openStream()){
				Properties properties = new Properties();
				properties.load(is);

				init(classLoader, properties);
			} catch(IOException ioe){
				logger.warn("Failed to load resource", ioe);
			}
		}
	}

	@SuppressWarnings({"rawtypes", "unchecked"})
	static
	private void init(ClassLoader classLoader, Properties properties){

		if(properties.isEmpty()){
			return;
		}

		Set keys = properties.stringPropertyNames();
		for(String key : keys){
			String value = properties.getProperty(key);

			logger.trace("Mapping transformer class " + key + " to transformer converter class " + value);

			Class clazz;

			try {
				clazz = classLoader.loadClass(key);
			} catch(ClassNotFoundException cnfe){
				logger.warn("Failed to load transformer class", cnfe);

				continue;
			}

			if(!(Transformer.class).isAssignableFrom(clazz)){
				throw new IllegalArgumentException("Transformer class " + clazz.getName() + " is not a subclass of " + Transformer.class.getName());
			} // End if

			Class converterClazz;

			try {
				converterClazz = classLoader.loadClass(value);
			} catch(ClassNotFoundException cnfe){
				logger.warn("Failed to load transformer converter class", cnfe);

				continue;
			}

			if(!(TransformerConverter.class).isAssignableFrom(converterClazz)){
				throw new IllegalArgumentException("Transformer converter class " + converterClazz.getName() + " is not a subclass of " + TransformerConverter.class.getName());
			}

			ConverterFactory.converters.put((Class)clazz, (Class)converterClazz);
		}
	}

	static
	private int[] parseVersion(String string){
		Pattern pattern = Pattern.compile("^(\\d+)\\.(\\d+)(\\..*)?$");

		Matcher matcher = pattern.matcher(string);
		if(!matcher.matches()){
			return new int[]{-1, -1};
		}

		return new int[]{Integer.parseInt(matcher.group(1)), Integer.parseInt(matcher.group(2))};
	}

	static
	private String formatVersion(int[] version){
		return String.valueOf(version[0]) + "." + String.valueOf(version[1]);
	}

	private static final int[] VERSION = {3, 5};

	private static final Map, Class>> converters = new LinkedHashMap<>();

	private static final Logger logger = LogManager.getLogger(ConverterFactory.class);

	static {
		ClassLoader clazzLoader = ConverterFactory.class.getClassLoader();

		ConverterFactory.init(clazzLoader);
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy