All downloads are free. The search and download functionality uses the official Maven repository.

ai.libs.mlplan.multilabel.mekamlplan.MekaPipelineFactory Maven / Gradle / Ivy

package ai.libs.mlplan.multilabel.mekamlplan;

import java.util.LinkedList;
import java.util.List;
import java.util.Map.Entry;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.components.exceptions.ComponentInstantiationFailedException;
import ai.libs.jaicore.components.model.ComponentInstance;
import ai.libs.jaicore.components.model.NumericParameterDomain;
import ai.libs.jaicore.components.model.Parameter;
import ai.libs.jaicore.ml.classification.multilabel.learner.IMekaClassifier;
import ai.libs.jaicore.ml.classification.multilabel.learner.MekaClassifier;
import meka.classifiers.multilabel.MultiLabelClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.MultipleClassifiersCombiner;
import weka.classifiers.SingleClassifierEnhancer;
import weka.classifiers.functions.supportVector.Kernel;
import weka.core.OptionHandler;

/**
 * A pipeline factory that converts a given ComponentInstance that consists of
 * components that correspond to MEKA algorithms to a MultiLabelClassifier.
 *
 * <p>
 * The class named by the component is instantiated reflectively. Parameter
 * values are translated into an option list that is handed to the classifier
 * via {@link OptionHandler#setOptions(String[])}; required interfaces are
 * either turned into (nested) options or set programmatically (SMO kernel,
 * base learners of a {@link MultipleClassifiersCombiner}, base classifier of a
 * {@link SingleClassifierEnhancer}).
 */
public class MekaPipelineFactory implements IMekaPipelineFactory {

	private static final String PARAMETER_NAME_WITH_DASH_WARNING = "Required interface of component {} has dash or underscore in interface id {}";

	/* logging */
	private static final Logger logger = LoggerFactory.getLogger(MekaPipelineFactory.class);

	/**
	 * Builds the classifier described by the given component instance and wraps
	 * it into a {@link MekaClassifier}.
	 *
	 * @param ci the component instance describing the classifier to build
	 * @return the wrapped multi-label classifier
	 * @throws ComponentInstantiationFailedException if anything goes wrong while
	 *         instantiating or configuring the classifier (including the case
	 *         that the built classifier is not a {@link MultiLabelClassifier})
	 */
	@Override
	public IMekaClassifier getComponentInstantiation(final ComponentInstance ci) throws ComponentInstantiationFailedException {
		try {
			return new MekaClassifier((MultiLabelClassifier) this.getClassifier(ci));
		} catch (Exception e) {
			throw new ComponentInstantiationFailedException(e, "Could not instantiate " + ci.getComponent().getName());
		}
	}

	/**
	 * Recursively instantiates and configures the classifier described by the
	 * given component instance.
	 *
	 * @param ci the component instance whose component name is the fully
	 *        qualified class name of the classifier
	 * @return the configured classifier
	 * @throws Exception if a class cannot be loaded or instantiated, or if
	 *         setting options or nested components fails
	 */
	private Classifier getClassifier(final ComponentInstance ci) throws Exception {
		final String componentName = ci.getComponent().getName();
		Classifier c = (Classifier) Class.forName(componentName).getDeclaredConstructor().newInstance();
		List<String> optionsList = getOptionsForParameterValues(ci);

		/* first pass: required interfaces that can only be expressed as options */
		for (Entry<String, ComponentInstance> reqI : ci.getSatisfactionOfRequiredInterfaces().entrySet()) {
			warnOnSuspiciousInterfaceId(ci, reqI.getKey());
			boolean isSmoKernel = reqI.getKey().equals("K") && componentName.endsWith("SMO");
			if (!reqI.getKey().equals("B") && !(c instanceof SingleClassifierEnhancer) && !isSmoKernel) {
				logger.warn("Classifier {} is not a single classifier enhancer and still has an unexpected required interface: {}. Try to set this configuration in the form of options.", componentName, reqI);
				optionsList.add("-" + reqI.getKey());
				optionsList.add(reqI.getValue().getComponent().getName());
				if (!reqI.getValue().getParameterValues().isEmpty() || !reqI.getValue().getSatisfactionOfRequiredInterfaces().isEmpty()) {
					optionsList.add("--");
					optionsList.addAll(this.getOptionsRecursively(reqI.getValue()));
				}
			}
		}
		if (c instanceof OptionHandler) {
			((OptionHandler) c).setOptions(optionsList.toArray(new String[0]));
		}

		/* second pass: required interfaces that are set programmatically; this happens after setOptions so that the options cannot override them */
		for (Entry<String, ComponentInstance> reqI : ci.getSatisfactionOfRequiredInterfaces().entrySet()) {
			final String interfaceId = reqI.getKey();
			final ComponentInstance provider = reqI.getValue();
			if (interfaceId.equals("K") && componentName.endsWith("SMO")) {
				logger.debug("Set kernel for SMO to be {}", provider.getComponent().getName());
				Kernel k = (Kernel) Class.forName(provider.getComponent().getName()).getDeclaredConstructor().newInstance();
				k.setOptions(getOptionsForParameterValues(provider).toArray(new String[0]));
				/* attach the configured kernel; invoked reflectively so that no compile-time reference to the concrete SMO class is needed */
				c.getClass().getMethod("setKernel", Kernel.class).invoke(c, k);
			} else if (interfaceId.equals("B") && (c instanceof MultipleClassifiersCombiner)) {
				Classifier[] classifiers = this.getListOfBaseLearners(provider).toArray(new Classifier[0]);
				((MultipleClassifiersCombiner) c).setClassifiers(classifiers);
			} else if (interfaceId.equals("W") && (c instanceof SingleClassifierEnhancer)) {
				if (logger.isTraceEnabled()) {
					logger.trace("Set {} as a base classifier for {}", provider.getComponent().getName(), componentName);
				}
				((SingleClassifierEnhancer) c).setClassifier(this.getClassifier(provider));
			}
		}

		return c;
	}

	/**
	 * Unrolls a chain of "MultipleBaseLearnerListElement" /
	 * "MultipleBaseLearnerListChain" components into a flat list of classifiers.
	 *
	 * @param ci the head of the base learner list
	 * @return the instantiated base learners in chain order; empty if the
	 *         component is neither a list element nor a list chain
	 * @throws Exception if one of the base learners cannot be built
	 */
	private List<Classifier> getListOfBaseLearners(final ComponentInstance ci) throws Exception {
		List<Classifier> baseLearnerList = new LinkedList<>();
		final String listComponentName = ci.getComponent().getName();
		if (listComponentName.equals("MultipleBaseLearnerListElement")) {
			baseLearnerList.add(this.getClassifier(ci.getSatisfactionOfRequiredInterfaces().get("classifier")));
		} else if (listComponentName.equals("MultipleBaseLearnerListChain")) {
			baseLearnerList.add(this.getClassifier(ci.getSatisfactionOfRequiredInterfaces().get("classifier")));
			baseLearnerList.addAll(this.getListOfBaseLearners(ci.getSatisfactionOfRequiredInterfaces().get("chain")));
		}
		return baseLearnerList;
	}

	/**
	 * Translates the parameter values of the given component instance into an
	 * option list.
	 *
	 * <p>
	 * Parameters are skipped if they carry their default value, if their name
	 * contains "activator" (case-insensitively), or if their value is "false".
	 * A value of "true" yields only the flag "-name"; any other value yields
	 * "-name" followed by the value, where values of integer-valued numeric
	 * parameters are truncated to an int.
	 *
	 * @param ci the component instance whose parameter values are translated
	 * @return a new, mutable list of options
	 */
	public static List<String> getOptionsForParameterValues(final ComponentInstance ci) {
		List<String> optionsList = new LinkedList<>();
		for (Entry<String, String> parameterValue : ci.getParameterValues().entrySet()) {
			final String name = parameterValue.getKey();
			final String value = parameterValue.getValue();
			Parameter param = ci.getComponent().getParameterWithName(name);
			if (param.isDefaultValue(value) || name.toLowerCase().contains("activator") || value.equals("false")) {
				continue;
			}

			optionsList.add("-" + name);
			if (value.equals("true")) {
				/* boolean flag: presence of the option is sufficient */
				continue;
			}

			if (param.isNumeric() && ((NumericParameterDomain) param.getDefaultDomain()).isInteger()) {
				/* values may be given as e.g. "3.0"; pass them on as plain integers */
				optionsList.add(String.valueOf((int) Double.parseDouble(value)));
			} else {
				optionsList.add(value);
			}
		}
		return optionsList;
	}

	/**
	 * Translates the parameter values and, recursively, all required interfaces
	 * of the given component instance into an option list.
	 *
	 * <p>
	 * For the interface ids "B" and "K" the nested component and its options
	 * are joined into a single space-separated option value; for all other ids
	 * the nested component name is added and its options follow after "--".
	 *
	 * @param ci the component instance to translate
	 * @return a new, mutable list of options
	 */
	private List<String> getOptionsRecursively(final ComponentInstance ci) {
		List<String> optionsList = getOptionsForParameterValues(ci);

		for (Entry<String, ComponentInstance> reqI : ci.getSatisfactionOfRequiredInterfaces().entrySet()) {
			final String interfaceId = reqI.getKey();
			final ComponentInstance provider = reqI.getValue();
			warnOnSuspiciousInterfaceId(ci, interfaceId);

			optionsList.add("-" + interfaceId);
			if (interfaceId.equals("B") || interfaceId.equals("K")) {
				List<String> valueList = new LinkedList<>();
				valueList.add(provider.getComponent().getName());
				valueList.addAll(this.getOptionsRecursively(provider));
				optionsList.add(SetUtil.implode(valueList, " "));
			} else {
				optionsList.add(provider.getComponent().getName());
				if (!provider.getParameterValues().isEmpty() || !provider.getSatisfactionOfRequiredInterfaces().isEmpty()) {
					optionsList.add("--");
					optionsList.addAll(this.getOptionsRecursively(provider));
				}
			}
		}

		return optionsList;
	}

	/**
	 * Logs a warning if the given required interface id starts with "-" or "_".
	 *
	 * @param ci the component instance declaring the required interface
	 * @param interfaceId the id of the required interface
	 */
	private static void warnOnSuspiciousInterfaceId(final ComponentInstance ci, final String interfaceId) {
		if (interfaceId.startsWith("-") || interfaceId.startsWith("_")) {
			logger.warn(PARAMETER_NAME_WITH_DASH_WARNING, ci.getComponent(), interfaceId);
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy