All downloads are free. Search and download functionality uses the official Maven repository.

org.jpmml.evaluator.ServiceFactory Maven / Gradle / Ivy

/*
 * Copyright (c) 2021 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Objects;

import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimaps;
import org.dmg.pmml.PMMLObject;

abstract
public class ServiceFactory implements Serializable {

	private Class keyClazz = null;

	private Class serviceClazz = null;

	transient
	private ListMultimap, Class> serviceProviderClazzes = null;


	protected ServiceFactory(){
	}

	protected ServiceFactory(Class keyClazz, Class serviceClazz){
		setKeyClass(keyClazz);
		setServiceClass(serviceClazz);
	}

	public List> getServiceProviderClasses(Class objectClazz) throws ClassNotFoundException, IOException {
		Class keyClazz = getKeyClass();
		ListMultimap, Class> serviceProviderClazzes = getServiceProviderClasses();

		Class clazz = objectClazz;

		while(clazz != null){

			if(serviceProviderClazzes.containsKey(clazz)){
				return serviceProviderClazzes.get(clazz);
			}

			Class superClazz = clazz.getSuperclass();

			try {
				clazz = superClazz.asSubclass(keyClazz);
			} catch(ClassCastException cce){
				break;
			}
		}

		return Collections.emptyList();
	}

	public ListMultimap, Class> getServiceProviderClasses() throws ClassNotFoundException, IOException {

		if(this.serviceProviderClazzes == null){
			ClassLoader clazzLoader = getClassLoader();

			Class keyClazz = getKeyClass();
			Class serviceClazz = getServiceClass();

			List> serviceProviderClazzes = loadServiceProviderClasses(clazzLoader, serviceClazz);

			this.serviceProviderClazzes = Multimaps.index(serviceProviderClazzes, serviceProviderClazz -> getKey(keyClazz, serviceClazz, serviceProviderClazz));
		}

		return this.serviceProviderClazzes;
	}

	public ClassLoader getClassLoader(){
		Class clazz = getClass();

		return clazz.getClassLoader();
	}

	public Class getKeyClass(){
		return this.keyClazz;
	}

	private void setKeyClass(Class keyClazz){
		this.keyClazz = Objects.requireNonNull(keyClazz);
	}

	public Class getServiceClass(){
		return this.serviceClazz;
	}

	private void setServiceClass(Class serviceClazz){
		this.serviceClazz = Objects.requireNonNull(serviceClazz);
	}

	static
	protected  List> loadServiceProviderClasses(ClassLoader clazzLoader, Class serviceClazz) throws ClassNotFoundException, IOException {
		List> result = new ArrayList<>();

		Enumeration urls = clazzLoader.getResources("META-INF/services/" + serviceClazz.getName());

		while(urls.hasMoreElements()){
			URL url = urls.nextElement();

			try(InputStream is = url.openStream()){
				List> serviceProviderClazzes = loadServiceProviderClasses(is, clazzLoader, serviceClazz);

				result.addAll(serviceProviderClazzes);
			}
		}

		return result;
	}

	static
	private  List> loadServiceProviderClasses(InputStream is, ClassLoader clazzLoader, Class serviceClazz) throws ClassNotFoundException, IOException {
		List> result = new ArrayList<>();

		BufferedReader reader = new BufferedReader(new InputStreamReader(is, "UTF-8"), 1024);

		while(true){
			String line = reader.readLine();

			if(line == null){
				break;
			}

			int hash = line.indexOf('#');
			if(hash > -1){
				line = line.substring(0, hash);
			}

			line = line.trim();

			if(line.isEmpty()){
				continue;
			}

			Class serviceProviderClazz = Class.forName(line, false, clazzLoader);

			if(!(serviceClazz).isAssignableFrom(serviceProviderClazz)){
				throw new IllegalArgumentException(line);
			}

			result.add((Class)serviceProviderClazz);
		}

		reader.close();

		return result;
	}

	static
	protected  Class getKey(Class keyClazz, Class serviceClazz, Class serviceProviderClazz){
		Class clazz = serviceProviderClazz;

		while(clazz != null){
			Class superClazz = clazz.getSuperclass();

			if((serviceClazz).equals(superClazz)){
				ParameterizedType parameterizedType = (ParameterizedType)clazz.getGenericSuperclass();

				Type[] arguments = parameterizedType.getActualTypeArguments();
				if(arguments.length != 1){
					throw new IllegalArgumentException(clazz.getName());
				}

				Class argumentClazz = (Class)arguments[0];

				return argumentClazz.asSubclass(keyClazz);
			}

			clazz = superClazz;
		}

		throw new IllegalArgumentException(serviceProviderClazz.getName());
	}
}