ai.libs.mlplan.multilabel.MekaPipelineFactory Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mlplancli Show documentation
Show all versions of mlplancli Show documentation
This project provides an implementation of the AutoML tool ML-Plan.
The newest version!
package ai.libs.mlplan.multilabel;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map.Entry;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.libs.hasco.exceptions.ComponentInstantiationFailedException;
import ai.libs.hasco.model.ComponentInstance;
import ai.libs.hasco.model.NumericParameterDomain;
import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.mlplan.multiclass.wekamlplan.IClassifierFactory;
import meka.classifiers.multilabel.MultiLabelClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.MultipleClassifiersCombiner;
import weka.classifiers.SingleClassifierEnhancer;
import weka.classifiers.functions.supportVector.Kernel;
import weka.core.OptionHandler;
/**
 * A pipeline factory that converts a given {@link ComponentInstance} whose components
 * correspond to MEKA (or WEKA) algorithms into a {@link MultiLabelClassifier}.
 *
 * <p>Every component is instantiated reflectively from its class name (the component name).
 * Parameter values are translated into command-line style options and passed to
 * {@link OptionHandler#setOptions(String[])}. Most required interfaces are encoded as options,
 * too. The following ones are instead set directly on the created object:
 * <ul>
 * <li>{@code W} of a {@link SingleClassifierEnhancer}: the base classifier;</li>
 * <li>{@code B} of a {@link MultipleClassifiersCombiner}: the base classifiers, given as a chain
 * of {@code MultipleBaseLearnerListChain} / {@code MultipleBaseLearnerListElement} components;</li>
 * <li>{@code K} of a component whose name ends with {@code SMO}: the kernel.</li>
 * </ul>
 */
public class MekaPipelineFactory implements IClassifierFactory {

    /** Warning logged when a required-interface id starts with a dash or an underscore. */
    private static final String REQUIRED_INTERFACE_WITH_DASH_WARNING = "Required interface of component {} has dash or underscore in interface id {}";

    /** Warning logged when a parameter name starts with a dash or an underscore. */
    private static final String PARAMETER_NAME_WITH_DASH_WARNING = "Parameter of component {} has dash or underscore in parameter name {}";

    /* logging */
    private static final Logger logger = LoggerFactory.getLogger(MekaPipelineFactory.class);

    /**
     * Instantiates the multi-label classifier described by the given component instance.
     *
     * @param ci the root of the component-instance tree; its component must denote a
     *           {@link MultiLabelClassifier}
     * @return the configured classifier (always a {@link MultiLabelClassifier})
     * @throws ComponentInstantiationFailedException if any component cannot be created or
     *         configured, or if the root is not a {@link MultiLabelClassifier}
     */
    @Override
    public Classifier getComponentInstantiation(final ComponentInstance ci) throws ComponentInstantiationFailedException {
        try {
            // The cast rejects roots that are not multi-label classifiers; the resulting
            // ClassCastException is reported like every other instantiation failure.
            return (MultiLabelClassifier) this.getClassifier(ci);
        } catch (Exception e) {
            throw new ComponentInstantiationFailedException(e, "Could not instantiate " + ci.getComponent().getName());
        }
    }

    /**
     * Recursively creates and configures the classifier described by {@code ci}.
     *
     * <p>Pass 1 passes all parameter values as options, together with every required interface
     * that is not set directly (see {@link #isConfiguredDirectly}). Pass 2 then sets the
     * remaining interfaces (base classifier(s), SMO kernel) on the created object.
     *
     * @param ci the component instance describing the classifier
     * @return the configured classifier
     * @throws Exception if a class cannot be loaded or instantiated, or rejects its options
     */
    private Classifier getClassifier(final ComponentInstance ci) throws Exception {
        String className = ci.getComponent().getName();
        Classifier c = (Classifier) Class.forName(className).getDeclaredConstructor().newInstance();

        // Pass 1: everything that is expressed as options.
        List<String> optionsList = this.getOptionsForParameterValues(ci);
        for (Entry<String, ComponentInstance> reqI : ci.getSatisfactionOfRequiredInterfaces().entrySet()) {
            String interfaceId = reqI.getKey();
            if (startsWithDashOrUnderscore(interfaceId)) {
                logger.warn(REQUIRED_INTERFACE_WITH_DASH_WARNING, ci.getComponent(), interfaceId);
            }
            if (!this.isConfiguredDirectly(c, ci, interfaceId)) {
                logger.warn("Classifier {} is not a single classifier enhancer and still has an unexpected required interface: {}. Try to set this configuration in the form of options.", className, reqI);
                this.addComponentAsOption(optionsList, interfaceId, reqI.getValue());
            }
        }
        if (c instanceof OptionHandler) {
            ((OptionHandler) c).setOptions(optionsList.toArray(new String[0]));
        }

        // Pass 2: set the directly configured interfaces. This runs after setOptions,
        // presumably so option parsing cannot overwrite what is set here — confirm.
        for (Entry<String, ComponentInstance> reqI : ci.getSatisfactionOfRequiredInterfaces().entrySet()) {
            String interfaceId = reqI.getKey();
            ComponentInstance provider = reqI.getValue();
            if (this.isSmoKernel(ci, interfaceId)) {
                this.setSmoKernel(c, provider);
            } else if ("B".equals(interfaceId) && c instanceof MultipleClassifiersCombiner) {
                Classifier[] classifiers = this.getListOfBaseLearners(provider).toArray(new Classifier[0]);
                ((MultipleClassifiersCombiner) c).setClassifiers(classifiers);
            } else if ("W".equals(interfaceId) && c instanceof SingleClassifierEnhancer) {
                logger.trace("Set {} as a base classifier for {}", provider.getComponent().getName(), className);
                ((SingleClassifierEnhancer) c).setClassifier(this.getClassifier(provider));
            } else if (this.isConfiguredDirectly(c, ci, interfaceId)) {
                // Not passed as an option in pass 1 and not handled above, so it would be lost silently.
                logger.warn("Required interface {} of classifier {} could not be configured and is ignored.", interfaceId, className);
            }
        }
        return c;
    }

    /**
     * Tells whether the required interface {@code interfaceId} of {@code ci} is set directly
     * on the created object (pass 2 of {@link #getClassifier}) instead of being passed as an option.
     */
    private boolean isConfiguredDirectly(final Classifier c, final ComponentInstance ci, final String interfaceId) {
        return "B".equals(interfaceId) || c instanceof SingleClassifierEnhancer || this.isSmoKernel(ci, interfaceId);
    }

    /** Tells whether {@code interfaceId} is the kernel interface {@code K} of a component whose name ends with {@code SMO}. */
    private boolean isSmoKernel(final ComponentInstance ci, final String interfaceId) {
        return "K".equals(interfaceId) && ci.getComponent().getName().endsWith("SMO");
    }

    /**
     * Creates the kernel described by {@code kernelCI}, configures it with its parameter values,
     * and installs it on {@code smo} through the public {@code setKernel(Kernel)} method.
     *
     * @param smo      the already created SMO classifier
     * @param kernelCI the component instance describing the kernel
     * @throws Exception if the kernel cannot be created or configured, or if {@code smo} has no
     *         public {@code setKernel(Kernel)} method
     */
    private void setSmoKernel(final Classifier smo, final ComponentInstance kernelCI) throws Exception {
        logger.debug("Set kernel for SMO to be {}", kernelCI.getComponent().getName());
        Kernel k = (Kernel) Class.forName(kernelCI.getComponent().getName()).getDeclaredConstructor().newInstance();
        k.setOptions(this.getOptionsForParameterValues(kernelCI).toArray(new String[0]));
        // The SMO is identified only by its name suffix, so the setter is looked up
        // reflectively instead of casting to one concrete class.
        smo.getClass().getMethod("setKernel", Kernel.class).invoke(smo, k);
    }

    /**
     * Flattens a base-learner list into the classifiers it describes.
     *
     * @param ci either a {@code MultipleBaseLearnerListElement} (one classifier under the
     *           interface {@code classifier}) or a {@code MultipleBaseLearnerListChain} (one
     *           classifier plus the rest of the list under the interface {@code chain})
     * @return the instantiated classifiers in list order; empty for an unexpected component
     * @throws Exception if one of the classifiers cannot be instantiated
     */
    private List<Classifier> getListOfBaseLearners(final ComponentInstance ci) throws Exception {
        List<Classifier> baseLearnerList = new LinkedList<>();
        String name = ci.getComponent().getName();
        if ("MultipleBaseLearnerListElement".equals(name)) {
            baseLearnerList.add(this.getClassifier(ci.getSatisfactionOfRequiredInterfaces().get("classifier")));
        } else if ("MultipleBaseLearnerListChain".equals(name)) {
            baseLearnerList.add(this.getClassifier(ci.getSatisfactionOfRequiredInterfaces().get("classifier")));
            baseLearnerList.addAll(this.getListOfBaseLearners(ci.getSatisfactionOfRequiredInterfaces().get("chain")));
        } else {
            logger.warn("Unexpected component {} in base learner list; it is ignored.", name);
        }
        return baseLearnerList;
    }

    /**
     * Translates the parameter values of {@code ci} into options.
     *
     * <p>A value {@code "true"} becomes the bare flag {@code -name}. A value {@code "false"}, and
     * any parameter whose name contains {@code activator} (case-insensitive), is not passed on.
     * All other parameters become {@code -name value}. For integer-domain numeric parameters the
     * value is parsed as a double and truncated to an {@code int} literal.
     *
     * @param ci the component instance whose parameter values are translated
     * @return a new, mutable list of option tokens
     * @throws IllegalArgumentException if a parameter has a {@code null} value
     */
    private List<String> getOptionsForParameterValues(final ComponentInstance ci) {
        List<String> optionsList = new LinkedList<>();
        for (Entry<String, String> parameterValue : ci.getParameterValues().entrySet()) {
            String name = parameterValue.getKey();
            String value = parameterValue.getValue();
            if (startsWithDashOrUnderscore(name)) {
                logger.warn(PARAMETER_NAME_WITH_DASH_WARNING, ci.getComponent(), name);
            }
            if (value == null) {
                throw new IllegalArgumentException("Parameter " + name + " of component " + ci.getComponent().getName() + " has no value.");
            }
            if ("true".equals(value)) {
                optionsList.add("-" + name);
            } else if (!"false".equals(value) && !name.toLowerCase(Locale.ROOT).contains("activator")) {
                optionsList.add("-" + name);
                optionsList.add(this.formatParameterValue(ci, name, value));
            }
            // otherwise: "false" flags and activator parameters are deliberately not passed on
        }
        return optionsList;
    }

    /**
     * Formats a single parameter value as an option token. An integer-domain value may arrive
     * in floating-point form (e.g. {@code "3.0"}), so it is truncated to an int literal.
     */
    private String formatParameterValue(final ComponentInstance ci, final String name, final String value) {
        if (ci.getComponent().getParameterWithName(name).isNumeric()) {
            NumericParameterDomain numDom = (NumericParameterDomain) ci.getComponent().getParameterWithName(name).getDefaultDomain();
            if (numDom.isInteger()) {
                return String.valueOf((int) Double.parseDouble(value));
            }
        }
        return value;
    }

    /**
     * Builds the complete option list of a nested component: its own parameter values followed
     * by every required interface as {@code -<id> ...}.
     *
     * <p>For {@code B} and {@code K} the nested class name and its options are joined, separated
     * by spaces, into one option value. For all other ids they follow as separate tokens (see
     * {@link #addComponentAsOption}).
     *
     * @param ci the nested component instance
     * @return a new, mutable list of option tokens
     */
    private List<String> getOptionsRecursively(final ComponentInstance ci) {
        List<String> optionsList = this.getOptionsForParameterValues(ci);
        for (Entry<String, ComponentInstance> reqI : ci.getSatisfactionOfRequiredInterfaces().entrySet()) {
            String interfaceId = reqI.getKey();
            ComponentInstance provider = reqI.getValue();
            if (startsWithDashOrUnderscore(interfaceId)) {
                logger.warn(REQUIRED_INTERFACE_WITH_DASH_WARNING, ci.getComponent(), interfaceId);
            }
            if ("B".equals(interfaceId) || "K".equals(interfaceId)) {
                List<String> valueList = new LinkedList<>();
                valueList.add(provider.getComponent().getName());
                valueList.addAll(this.getOptionsRecursively(provider));
                optionsList.add("-" + interfaceId);
                optionsList.add(SetUtil.implode(valueList, " "));
            } else {
                this.addComponentAsOption(optionsList, interfaceId, provider);
            }
        }
        return optionsList;
    }

    /**
     * Appends {@code -<interfaceId> <className>} for a nested component. If that component has
     * parameters or required interfaces, it also appends {@code --} followed by its recursive
     * options. Presumably this follows WEKA's convention for handing options to a nested
     * component — confirm.
     */
    private void addComponentAsOption(final List<String> optionsList, final String interfaceId, final ComponentInstance provider) {
        optionsList.add("-" + interfaceId);
        optionsList.add(provider.getComponent().getName());
        if (!provider.getParameterValues().isEmpty() || !provider.getSatisfactionOfRequiredInterfaces().isEmpty()) {
            optionsList.add("--");
            optionsList.addAll(this.getOptionsRecursively(provider));
        }
    }

    /** Tells whether a parameter name or interface id starts with {@code -} or {@code _}. */
    private static boolean startsWithDashOrUnderscore(final String id) {
        return id.startsWith("-") || id.startsWith("_");
    }
}