cloud.eppo.ufc.dto.BanditCategoricalAttributeCoefficients Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of sdk-common-jvm Show documentation
Eppo SDK for JVM shared library
package cloud.eppo.ufc.dto;
import cloud.eppo.api.EppoValue;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Coefficients for a categorical bandit attribute.
 *
 * <p>Categorical attributes are scored as one-hot booleans. When the attribute's string value
 * matches a key in the value-coefficient map, that key's coefficient is the score. Otherwise the
 * missing-value coefficient is used. That includes values that are absent, null, numeric, or not
 * a known category.
 */
public class BanditCategoricalAttributeCoefficients implements BanditAttributeCoefficients {
  // One logger per class rather than one per instance.
  private static final Logger logger =
      LoggerFactory.getLogger(BanditCategoricalAttributeCoefficients.class);

  private final String attributeKey;
  private final Double missingValueCoefficient;
  private final Map<String, Double> valueCoefficients;

  /**
   * Creates categorical coefficients for a single attribute.
   *
   * @param attributeKey name of the attribute these coefficients apply to
   * @param missingValueCoefficient score used when the value is missing, numeric, or not a key of
   *     {@code valueCoefficients}
   * @param valueCoefficients coefficient per categorical value, keyed by the value's string form;
   *     a {@code null} map is treated as having no known categories
   */
  public BanditCategoricalAttributeCoefficients(
      String attributeKey, Double missingValueCoefficient, Map<String, Double> valueCoefficients) {
    this.attributeKey = attributeKey;
    this.missingValueCoefficient = missingValueCoefficient;
    this.valueCoefficients = valueCoefficients;
  }

  @Override
  public String getAttributeKey() {
    return attributeKey;
  }

  /**
   * Scores one attribute value against these coefficients.
   *
   * @param attributeValue the subject/action attribute value; may be {@code null}
   * @return the coefficient for the value's string form if it is a known category, otherwise the
   *     missing-value coefficient
   * @throws NullPointerException if the missing-value coefficient is needed but was constructed as
   *     {@code null} (it is unboxed to {@code double})
   */
  public double scoreForAttributeValue(EppoValue attributeValue) {
    if (attributeValue == null || attributeValue.isNull()) {
      return missingValueCoefficient;
    }
    if (attributeValue.isNumeric()) {
      // A numeric value cannot match a categorical coefficient; fall back rather than fail.
      logger.warn("Unexpected numeric attribute value for attribute {}", attributeKey);
      return missingValueCoefficient;
    }
    String valueKey = attributeValue.toString();
    Double coefficient = valueCoefficients == null ? null : valueCoefficients.get(valueKey);
    // Categorical attributes are treated as one-hot booleans, so it's just the coefficient * 1 when
    // present
    return coefficient != null ? coefficient : missingValueCoefficient;
  }

  /**
   * Returns the coefficient used when the value is missing, numeric, or unrecognized.
   *
   * @return the missing-value coefficient, possibly {@code null} if constructed that way
   */
  public Double getMissingValueCoefficient() {
    return missingValueCoefficient;
  }

  /**
   * Returns the per-category coefficients, keyed by the value's string form.
   *
   * @return the map passed to the constructor (not copied)
   */
  public Map<String, Double> getValueCoefficients() {
    return valueCoefficients;
  }
}