com.feedzai.openml.provider.lightgbm.FairGBMDescriptorUtil Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of openml-lightgbm Show documentation
Show all versions of openml-lightgbm Show documentation
OpenML LightGBM Machine Learning Model and Classifier provider
The newest version!
/*
* Copyright 2022 Feedzai
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.feedzai.openml.provider.lightgbm;
import com.feedzai.openml.provider.descriptor.ModelParameter;
import com.feedzai.openml.provider.descriptor.fieldtype.ChoiceFieldType;
import com.feedzai.openml.provider.descriptor.fieldtype.FreeTextFieldType;
import com.feedzai.openml.provider.descriptor.fieldtype.NumericFieldType;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Utility to organize all the necessary Machine Learning Hyper-Parameters for configuring the training of FairGBM.
 * <p>
 * The parameter set exposed here is the union of the FairGBM-specific fairness parameters declared in this class
 * and every parameter of {@link LightGBMDescriptorUtil#PARAMS} except the boosting type, which is re-declared
 * here with a restricted set of choices.
 * <p>
 * The unqualified names {@code MANDATORY}, {@code NOT_MANDATORY}, {@code doubleRange} and
 * {@code BOOSTING_TYPE_PARAMETER_NAME} used below are not declared in this file; they are resolved through the
 * {@link LightGBMDescriptorUtil} superclass.
 *
 * @author Andre Cruz ([email protected])
 * @since 1.4.0
 */
public class FairGBMDescriptorUtil extends LightGBMDescriptorUtil {

    /**
     * Name of the model parameter that identifies the column over which fairness constraints are enforced.
     */
    public static final String CONSTRAINT_GROUP_COLUMN_PARAMETER_NAME = "constraint_group_column";

    /**
     * Defines the set of model parameters supported by the FairGBM algorithm.
     * <p>
     * This is a union view of the FairGBM-specific parameters and the LightGBM parameters, where the LightGBM
     * boosting type parameter is filtered out and replaced by the override declared in this class.
     */
    static final Set<ModelParameter> PARAMS = Sets.union(
            ImmutableSet.of(
                    // The single parameter that will change for every different dataset
                    new ModelParameter(
                            CONSTRAINT_GROUP_COLUMN_PARAMETER_NAME,
                            "(Fairness) Sensitive group column",
                            "Fairness constraints are enforced over this column.\n"
                                    + "This column must be in categorical format.\n"
                                    + "Start this string with `name:` to use the name of a column, \n"
                                    + "e.g., `name:age_group` for a column named `age_group`.",
                            MANDATORY,
                            // TODO: validate with new FreeTextFieldType("", ".+") once
                            //  https://github.com/feedzai/feedzai-openml/issues/68 is addressed
                            new FreeTextFieldType("")
                    ),

                    new ModelParameter(
                            "constraint_type",
                            "(Fairness) Constraint type",
                            "Enforces group-wise parity on the given target metric for the selected group column. "
                                    + "In general, FPR can be used for most detection settings "
                                    + "to equalize the negative outcomes on legitimate individuals "
                                    + "(false positives).",
                            NOT_MANDATORY,
                            new ChoiceFieldType(ImmutableSet.of("FPR", "FNR", "FPR,FNR"), "FPR")
                    ),

                    // Parameters related to global constraints
                    new ModelParameter(
                            "global_constraint_type",
                            "(Fairness) Global constraint type",
                            "FairGBM modifies the output scores to meet your target FPR and/or FNR as well as "
                                    + "fairness at a decision threshold of approximately 0.5 (or 500 in Pulse). Set parameters "
                                    + "(Fairness) Global target FPR/FNR accordingly. Using decision thresholds far from 0.5 "
                                    + "will not ensure fairness.",
                            NOT_MANDATORY,
                            new ChoiceFieldType(ImmutableSet.of("FPR", "FNR", "FPR,FNR"), "FPR,FNR")
                    ),

                    new ModelParameter(
                            "global_target_fpr",
                            "(Fairness) Global target FPR",
                            "This parameter is only active when '(Fairness) Global constraint type' includes "
                                    + "'FPR'. This is an inequality constraint: inactive when FPR is lower than the target. "
                                    + "Oftentimes, some tension is required between global FPR and FNR constraints in order to "
                                    + "achieve the target values (in these cases pick 'FPR,FNR' for the '(Fairness) Global "
                                    + "constraint type' parameter).",
                            NOT_MANDATORY,
                            doubleRange(0.0, 1.0, 0.05)
                    ),

                    new ModelParameter(
                            "global_target_fnr",
                            "(Fairness) Global target FNR",
                            "This parameter is only active when '(Fairness) Global constraint type' includes "
                                    + "'FNR'. This is an inequality constraint: inactive when FNR is lower than the target. "
                                    + "Oftentimes, some tension is required between global FPR and FNR constraints in order to "
                                    + "achieve the target values (in these cases pick 'FPR,FNR' for the '(Fairness) Global "
                                    + "constraint type' parameter).",
                            NOT_MANDATORY,
                            doubleRange(0.0, 1.0, 0.5)
                    ),

                    new ModelParameter(
                            "objective",
                            "(Fairness) Objective function",
                            "For FairGBM you must use a constrained optimization function. "
                                    + "`constrained_cross_entropy` is recommended for most cases.",
                            NOT_MANDATORY,
                            new ChoiceFieldType(
                                    ImmutableSet.of("constrained_cross_entropy", "constrained_recall_objective"),
                                    "constrained_cross_entropy")
                    ),

                    // Tolerance on the fairness constraints
                    new ModelParameter(
                            "constraint_fpr_threshold",
                            "(Fairness) FPR tolerance for fairness",
                            "The tolerance when fulfilling fairness FPR constraints. "
                                    + "The allowed difference between group-wise FPR. "
                                    + "The value 0.0 enforces group-wise FPR to be *exactly* equal. "
                                    + "Higher values lead to a less strict fairness enforcement.",
                            NOT_MANDATORY,
                            doubleRange(0.0, 1.0, 0.0)
                    ),

                    new ModelParameter(
                            "constraint_fnr_threshold",
                            "(Fairness) FNR tolerance for fairness",
                            "The tolerance when fulfilling fairness FNR constraints. "
                                    + "The allowed difference between group-wise FNR. "
                                    + "The value 0.0 enforces group-wise FNR to be *exactly* equal. "
                                    + "Higher values lead to a less strict fairness enforcement.",
                            NOT_MANDATORY,
                            doubleRange(0.0, 1.0, 0.0)
                    ),

                    // Eventually we want this parameter to not depend as much on the size of the dataset
                    // But currently this needs to be changed for each dataset considering its size
                    // (larger for larger datasets)
                    // See: https://github.com/feedzai/fairgbm/issues/7
                    new ModelParameter(
                            "multiplier_learning_rate",
                            "(Fairness) Multipliers' learning rate",
                            "The Lagrangian multipliers control how strict the constraint enforcement is.",
                            NOT_MANDATORY,
                            // NOTE: Float.MIN_VALUE is used here because the minimum value of a double in C++
                            // depends on the architecture it's run on; using float here is more conservative.
                            NumericFieldType.min(Float.MIN_VALUE, NumericFieldType.ParameterConfigType.DOUBLE, 1e3)
                    ),

                    new ModelParameter(
                            "init_multipliers",
                            "(Fairness) Initial multipliers",
                            "The Lagrangian multipliers control how strict the constraint enforcement is. "
                                    + "The default value is starting with zero `0` for each constraint.",
                            NOT_MANDATORY,
                            // TODO: validate with
                            //  new FreeTextFieldType("", "^((\\d+(\\.\\d*)?,)*(\\d+(\\.\\d*)?))?$") once
                            //  https://github.com/feedzai/feedzai-openml/issues/68 is addressed
                            new FreeTextFieldType("")
                    ),

                    // These parameters probably shouldn't be changed in 90% of cases
                    new ModelParameter(
                            "constraint_stepwise_proxy",
                            "(Fairness) Stepwise proxy for fairness constraints",
                            "The type of proxy function to use for the fairness constraint. "
                                    + "We need to use a differentiable proxy function, as FPR and FNR have discontinuous gradients.",
                            NOT_MANDATORY,
                            new ChoiceFieldType(ImmutableSet.of("cross_entropy", "quadratic", "hinge"), "cross_entropy")
                    ),

                    new ModelParameter(
                            "objective_stepwise_proxy",
                            "(Fairness) Stepwise proxy for global constraints",
                            "The proxy function to use for the objective function. "
                                    + "Only used when explicitly optimizing for Recall (or any other metric of the "
                                    + "confusion matrix). Leave blank when using standard objectives, such as cross-entropy.",
                            NOT_MANDATORY,
                            new ChoiceFieldType(ImmutableSet.of("cross_entropy", "quadratic", "hinge", ""), "")
                    ),

                    // Override this parameter from LightGBM so we can disallow using RF.
                    // The description deliberately omits 'rf' because it is not among the selectable choices.
                    new ModelParameter(
                            BOOSTING_TYPE_PARAMETER_NAME,
                            "Boosting type",
                            "Type of boosting model:\n"
                                    + "'gbdt' is a good starting point,\n"
                                    + "'goss' is faster but slightly less accurate,\n"
                                    + "'dart' is much slower but might improve performance.",
                            MANDATORY,
                            new ChoiceFieldType(
                                    ImmutableSet.of("gbdt", "dart", "goss"),
                                    "gbdt"
                            )
                    )

                    // TODO: assess whether the following parameters would ever be useful
                    //  (they probably shouldn't be changed in 99% of cases); each would be a NOT_MANDATORY
                    //  free-text parameter: "stepwise_proxy_margin", "score_threshold", "global_score_threshold".
            ),
            // All LightGBM parameters except the boosting type, which is overridden above.
            LightGBMDescriptorUtil.PARAMS.stream()
                    .filter(el -> !el.getName().equals(BOOSTING_TYPE_PARAMETER_NAME))
                    .collect(Collectors.toSet())
    );
}