// com.feedzai.openml.datarobot.DataRobotModelCreator (Maven / Gradle / Ivy)
// Artifact: openml-datarobot
/*
* Copyright 2018 Feedzai
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.feedzai.openml.datarobot;
import com.datarobot.prediction.Predictor;
import com.feedzai.openml.data.schema.AbstractValueSchema;
import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.data.schema.FieldSchema;
import com.feedzai.openml.java.utils.JavaFileUtils;
import com.feedzai.openml.model.MachineLearningModel;
import com.feedzai.openml.provider.descriptor.fieldtype.ParamValidationError;
import com.feedzai.openml.provider.exception.ModelLoadingException;
import com.feedzai.openml.provider.model.MachineLearningModelLoader;
import com.feedzai.openml.util.load.LoadModelUtils;
import com.feedzai.openml.util.load.LoadSchemaUtils;
import com.feedzai.openml.util.validate.ClassificationValidationUtils;
import com.feedzai.openml.util.validate.ValidationUtils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.URLClassLoader;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Implementation of the {@link MachineLearningModelLoader}.
 *
 * This class is responsible for the initialization of a {@link MachineLearningModel} that was generated in DataRobot.
 *
 * @author Paulo Pereira ([email protected])
 * @since 0.1.0
 */
public class DataRobotModelCreator implements MachineLearningModelLoader<ClassificationBinaryDataRobotModel> {

    /**
     * Logger for this class.
     */
    private static final Logger logger = LoggerFactory.getLogger(DataRobotModelCreator.class);

    /**
     * Template of the package generated by DataRobot that contains the model to import.
     */
    private static final String MODEL_PACKAGE_TEMPLATE = "com.datarobot.prediction.dr%s.DRModel";

    /**
     * This is useful since DR uses the following hard-coded values regardless of the actual boolean values (e.g.,
     * "True" will be assumed even if the data has "true" or "TRUE").
     *
     * @since 0.5.2
     */
    private static final Set<String> BOOLEAN_VALUES = ImmutableSet.of("True", "False");

    /**
     * A function that tells whether the given target variable values are for a binary boolean DR model, i.e. whether
     * the distinct values of the array are exactly the ones in {@link #BOOLEAN_VALUES}.
     *
     * @see #BOOLEAN_VALUES
     *
     * @since 0.5.2
     */
    private static final Predicate<String[]> IS_BOOLEAN_MODEL = possibleVals ->
            Objects.equals(BOOLEAN_VALUES, Arrays.stream(possibleVals).collect(Collectors.toSet()));

    /**
     * {@inheritDoc}
     *
     * This provider assumes the {@code modelPath} is a directory.
     *
     * If anything fails after the class loader of the model jar has been opened, that class loader is closed before
     * the failure is propagated, so that a failed load does not keep the jar file open.
     */
    @Override
    public ClassificationBinaryDataRobotModel loadModel(final Path modelPath,
                                                        final DatasetSchema schema) throws ModelLoadingException {
        if (!schema.getTargetIndex().isPresent()) {
            throw new ModelLoadingException("Cannot load a model with a schema that has no target variable.");
        }

        logger.info("Trying to load a model in path [{}]...", modelPath);

        ClassificationValidationUtils.validateParamsModelToLoad(this, modelPath, schema, ImmutableMap.of());

        final Pair<Predictor, URLClassLoader> predictorPair = createPredictorInstance(modelPath);
        final Predictor predictor = predictorPair.getKey();
        final URLClassLoader classLoader = predictorPair.getValue();

        // Tracks whether the model was fully built and validated; when it was not, nobody else holds a reference to
        // the class loader, so it has to be released here.
        boolean loaded = false;
        try {
            final int predictorSize =
                    predictor.get_double_predictors().length + predictor.get_string_predictors().length;

            // We ignore the target variable in the schema for schema matching purposes.
            if (predictorSize != schema.getFieldSchemas().size() - 1) {
                final String errorMsg = String.format(
                        "Wrong number of fields in the given schema. The model expected %d feature fields + 1 target field," +
                                " but the schema had a total of %d fields only (encompassing both features and target fields).",
                        predictorSize,
                        schema.getFieldSchemas().size()
                );
                // The detailed schemas are only logged, they are not part of the exception message.
                final String extraMsg = String.format(
                        " Schema expected by the model %s. Schema provided %s.",
                        predictor2Str(predictor),
                        schema
                );
                logger.error(errorMsg + extraMsg);
                throw new ModelLoadingException(errorMsg);
            }

            final String[] targetModelValues = getTargetModelValues(predictor);
            final SortedSet<String> nominalValues = checkTargetModelValuesWithSchema(schema, targetModelValues);

            final ClassificationBinaryDataRobotModel resultingModel = new ClassificationBinaryDataRobotModel(
                    predictor,
                    // Whether the first class of the schema is also the first class label of the model.
                    nominalValues.first().equals(targetModelValues[0]),
                    modelPath,
                    schema,
                    classLoader
            );

            ClassificationValidationUtils.validateClassificationModel(schema, resultingModel);

            logger.info("Model in path [{}] loaded successfully.", modelPath);
            loaded = true;
            return resultingModel;
        } finally {
            if (!loaded) {
                closeQuietly(classLoader);
            }
        }
    }

    /**
     * Converts a predictor to string, by joining the names of its double predictors followed by the names of its
     * string predictors, separated by commas.
     *
     * @param predictor The predictor.
     * @return The string.
     * @since 0.5.1
     */
    private String predictor2Str(final Predictor predictor) {
        return Stream.concat(
                Arrays.stream(predictor.get_double_predictors()),
                Arrays.stream(predictor.get_string_predictors())
        ).collect(Collectors.joining(","));
    }

    /**
     * Creates a predictor instance for the binary model generated in DataRobot.
     *
     * The class loader opened for the model jar is returned together with the predictor. If the predictor cannot be
     * instantiated, the class loader is closed before the failure is propagated.
     *
     * @param modelPath The path from where the model was initially loaded.
     * @return The predictor for the binary model generated in DataRobot, paired with the class loader that loaded it.
     * @throws ModelLoadingException If there is a problem creating the predictor.
     */
    private Pair<Predictor, URLClassLoader> createPredictorInstance(final Path modelPath) throws ModelLoadingException {
        final String modelFilePath = LoadModelUtils.getModelFilePath(modelPath).toAbsolutePath().toString();

        final URLClassLoader urlClassLoader = JavaFileUtils.getUrlClassLoader(
                modelFilePath,
                ClassificationBinaryDataRobotModel.class.getClassLoader()
        );

        boolean created = false;
        try {
            final Predictor predictor = (Predictor) JavaFileUtils.createNewInstanceFromClassLoader(
                    modelFilePath,
                    MODEL_PACKAGE_TEMPLATE,
                    urlClassLoader
            );
            created = true;
            return Pair.of(predictor, urlClassLoader);
        } finally {
            if (!created) {
                closeQuietly(urlClassLoader);
            }
        }
    }

    /**
     * Closes the class loader of a model that could not be loaded. A failure to close is only logged, so that it
     * never hides the failure that triggered the clean up.
     *
     * @param classLoader The class loader to close. May be {@code null}, in which case nothing is done.
     */
    private static void closeQuietly(final URLClassLoader classLoader) {
        if (classLoader == null) {
            return;
        }
        try {
            classLoader.close();
        } catch (final IOException e) {
            logger.warn("Unable to close the class loader of a model that failed to load.", e);
        }
    }

    /**
     * Gets the values of the target field used to train the DataRobot model.
     *
     * The values are read reflectively from the {@code classLabels} field of the predictor. If that read fails for
     * any reason, a warning is logged and the values {@code "0"} and {@code "1"} are assumed, in that order.
     *
     * @param predictor Predictor for the binary model generated in DataRobot.
     * @return The target values used to train the model.
     */
    @VisibleForTesting
    String[] getTargetModelValues(final Predictor predictor) {
        final String targetVarField = "classLabels";
        try {
            return (String[]) FieldUtils.readField(predictor, targetVarField, true);
        } catch (final Exception e) {
            // Deliberate best effort: any failure to read the field falls back to the default labels below.
            final String errorMsg = String.format(
                    "Jar file of the DataRobot model may not be supported. A possible cause is that the model might be too " +
                            "old and a newer version is required because it lacks the \"%s\" field with the target " +
                            "values. Ideally, you should create a new project on DataRobot and train new models. " +
                            "As a workaround, the load will assume the target variable values to be 0 and 1, in that order.",
                    targetVarField
            );
            logger.warn(errorMsg, e);
            return new String[] { "0", "1" };
        }
    }

    /**
     * Check that the target values used to train the DataRobot model are compatible with the ones declared in the schema.
     *
     * For a boolean model (see {@link #BOOLEAN_VALUES}) the schema values are compared ignoring case, and on success
     * the hard-coded boolean values are returned instead of the ones in the schema. For any other model the schema
     * must declare exactly the values of the model.
     *
     * @param schema            The {@link DatasetSchema} the model uses.
     * @param targetModelValues Target values used to train the model.
     * @return The nominal values of the target field declared in the schema (or {@link #BOOLEAN_VALUES}, sorted, for
     * boolean models).
     * @throws ModelLoadingException If the target values are incompatible or the schema has no target variable.
     */
    SortedSet<String> checkTargetModelValuesWithSchema(final DatasetSchema schema,
                                                       final String[] targetModelValues) throws ModelLoadingException {
        final SortedSet<String> nominalValues = schema.getTargetFieldSchema()
                .map(FieldSchema::getValueSchema)
                .map(CategoricalValueSchema.class::cast)
                .map(CategoricalValueSchema::getNominalValues)
                .orElseThrow(() -> new ModelLoadingException("Cannot load a model with a schema that has no target variable."));

        if (IS_BOOLEAN_MODEL.test(targetModelValues)) {
            // The schema may spell the boolean values with any casing (e.g. "true" or "TRUE").
            if (Objects.equals(
                    nominalValues.stream().map(StringUtils::lowerCase).collect(Collectors.toSet()),
                    BOOLEAN_VALUES.stream().map(StringUtils::lowerCase).collect(Collectors.toSet())
            )) {
                return new TreeSet<>(BOOLEAN_VALUES);
            }

            final String delimiter = ",";
            final String errorMsg = String.format(
                    "Incompatible target values. The model is binary and thus expects some form of: [%s], but the schema had: %s.",
                    String.join(delimiter, targetModelValues),
                    String.join(delimiter, nominalValues)
            );
            logger.error(errorMsg);
            throw new ModelLoadingException(errorMsg);
        }

        if (nominalValues.size() != targetModelValues.length ||
                !nominalValues.containsAll(Arrays.asList(targetModelValues))) {
            final String delimiter = ",";
            final String errorMsg = String.format(
                    "Incompatible target values. model: [%s], schema: %s.",
                    String.join(delimiter, targetModelValues),
                    String.join(delimiter, nominalValues)
            );
            logger.error(errorMsg);
            throw new ModelLoadingException(errorMsg);
        }

        return nominalValues;
    }

    /**
     * {@inheritDoc}
     *
     * The schema is read from JSON through {@link LoadSchemaUtils#datasetSchemaFromJson(Path)}.
     */
    @Override
    public DatasetSchema loadSchema(final Path modelPath) throws ModelLoadingException {
        return LoadSchemaUtils.datasetSchemaFromJson(modelPath);
    }

    /**
     * Validates that the target field only has two possible values.
     *
     * Only {@link CategoricalValueSchema} targets are checked; any other kind of value schema produces no error here.
     *
     * @param targetValue The target value of a given {@link DatasetSchema} (fetched through {@link DatasetSchema#getTargetFieldSchema()}).
     * @return A list of {@link ParamValidationError} with the problems/error found during the validation.
     */
    public List<ParamValidationError> validateTargetIsBinary(final AbstractValueSchema targetValue) {
        final ImmutableList.Builder<ParamValidationError> validationErrors = ImmutableList.builder();

        if (targetValue instanceof CategoricalValueSchema) {
            final int nominalValuesSize = ((CategoricalValueSchema) targetValue).getNominalValues().size();
            if (nominalValuesSize != 2) {
                validationErrors.add(
                        new ParamValidationError("At the moment only binary classification models are supported")
                );
            }
        }

        return validationErrors.build();
    }

    /**
     * Validates that the model file is in the expected file format (jar).
     *
     * A model file that cannot be located is reported as a validation error rather than as an exception.
     *
     * @param modelPath The path from where the model was initially loaded.
     * @return A list of {@link ParamValidationError} with the problems/error found during the validation.
     */
    List<ParamValidationError> validateModelFileFormat(final Path modelPath) {
        final ImmutableList.Builder<ParamValidationError> validationErrors = ImmutableList.builder();

        try {
            final String modelFilePath = LoadModelUtils.getModelFilePath(modelPath).toAbsolutePath().toString();
            if (!JavaFileUtils.isJarFile(modelFilePath)) {
                validationErrors.add(
                        new ParamValidationError(
                                String.format(
                                        "Extension [%s] not recognized for a DataRobot model, the model should be " +
                                                "exported in the [%s] extension.",
                                        modelFilePath,
                                        JavaFileUtils.JAR_EXTENSION
                                ))
                );
            }
        } catch (final ModelLoadingException e) {
            // Not being able to resolve the model file is a validation problem, not a failure of the validation.
            validationErrors.add(
                    new ParamValidationError(
                            String.format(
                                    "Unable to find a model file in [%s].",
                                    modelPath
                            ))
            );
        }

        return validationErrors.build();
    }

    /**
     * {@inheritDoc}
     *
     * The errors are reported in the following order: base load validations, target related validations, model file
     * format validations and, finally, model directory validations.
     */
    @Override
    public List<ParamValidationError> validateForLoad(final Path modelPath,
                                                      final DatasetSchema schema,
                                                      final Map<String, String> params) {
        final ImmutableList.Builder<ParamValidationError> errorBuilder = ImmutableList.builder();

        errorBuilder.addAll(ValidationUtils.baseLoadValidations(schema, params));

        final List<ParamValidationError> targetRelatedErrors = schema.getTargetFieldSchema()
                .map(target -> validateForLoad(schema, target))
                .orElse(ImmutableList.of(new ParamValidationError("Cannot load a model with a schema that has no target variable.")));
        errorBuilder.addAll(targetRelatedErrors);

        errorBuilder.addAll(validateModelFileFormat(modelPath));
        errorBuilder.addAll(ValidationUtils.validateModelInDir(modelPath));

        return errorBuilder.build();
    }

    /**
     * Validates the state of a dataset schema for a load operation. Assumes that the schema has a target field, which is passed through {@code target}.
     *
     * @param schema The dataset schema, which must include a target variable.
     * @param target The target field of the schema. Must match {@code schema.getTargetFieldSchema()}.
     * @return A list of validation errors detected. If empty, then the load operation is safe from the perspective of this method.
     */
    private List<ParamValidationError> validateForLoad(final DatasetSchema schema,
                                                       final FieldSchema target) {
        final ImmutableList.Builder<ParamValidationError> errorBuilder = ImmutableList.builder();

        ValidationUtils.validateCategoricalSchema(schema).ifPresent(errorBuilder::add);
        errorBuilder.addAll(validateTargetIsBinary(target.getValueSchema()));

        return errorBuilder.build();
    }
}