All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cloud.eppo.BanditEvaluator Maven / Gradle / Ivy

There is a newer version: 3.3.2
Show newest version
package cloud.eppo;

import static cloud.eppo.Utils.getShard;

import cloud.eppo.api.Actions;
import cloud.eppo.api.Attributes;
import cloud.eppo.api.DiscriminableAttributes;
import cloud.eppo.api.EppoValue;
import cloud.eppo.ufc.dto.*;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Selects an action for a subject using a contextual bandit model.
 *
 * <p>Evaluation proceeds in three steps:
 *
 * <ol>
 *   <li>Every candidate action is scored with the model's linear coefficients.
 *   <li>The scores are converted into selection weights. Each non-top action gets a probability
 *       bounded below by the model's probability floor; the top-scoring action receives the
 *       remaining mass.
 *   <li>An action is drawn deterministically from the shard of {@code flagKey + "-" +
 *       subjectKey}.
 * </ol>
 */
public class BanditEvaluator {

  /** Size of the shard space used both for shuffling actions and for picking the draw point. */
  private static final int BANDIT_ASSIGNMENT_SHARDS = 10000; // hard-coded for now

  /**
   * Scores, weighs, and selects an action for the given subject.
   *
   * @param flagKey key of the flag being evaluated; part of the deterministic sharding input
   * @param subjectKey key of the subject; part of the deterministic sharding input
   * @param subjectAttributes numeric and categorical attributes of the subject
   * @param actions candidate actions keyed by action name, each with its own attributes
   * @param modelData model parameters: per-action coefficients, gamma, probability floor and the
   *     default score for actions the model does not know
   * @return the result for the selected action: its attributes, score and weight, the model's
   *     gamma, and the optimality gap (best score minus the selected action's score)
   * @throws IllegalArgumentException if {@code actions} is empty
   */
  public static BanditEvaluationResult evaluateBandit(
      String flagKey,
      String subjectKey,
      DiscriminableAttributes subjectAttributes,
      Actions actions,
      BanditModelData modelData) {
    Map<String, Double> actionScores = scoreActions(subjectAttributes, actions, modelData);
    if (actionScores.isEmpty()) {
      // With no candidates there is nothing to select. Without this guard the failure surfaces
      // later as an opaque NullPointerException when unboxing the missing selected score.
      throw new IllegalArgumentException(
          "Cannot evaluate bandit for flag '" + flagKey + "': no actions were provided");
    }

    Map<String, Double> actionWeights =
        weighActions(actionScores, modelData.getGamma(), modelData.getActionProbabilityFloor());
    String selectedActionKey = selectAction(flagKey, subjectKey, actionWeights);
    double selectedScore = actionScores.get(selectedActionKey);

    // Compute optimality gap in terms of score
    double topScore =
        actionScores.values().stream().mapToDouble(Double::doubleValue).max().orElse(0);
    double optimalityGap = topScore - selectedScore;

    return new BanditEvaluationResult(
        flagKey,
        subjectKey,
        subjectAttributes,
        selectedActionKey,
        actions.get(selectedActionKey),
        selectedScore,
        actionWeights.get(selectedActionKey),
        modelData.getGamma(),
        optimalityGap);
  }

  /** Scores every candidate action. Returns a map from action name to score. */
  private static Map<String, Double> scoreActions(
      DiscriminableAttributes subjectAttributes, Actions actions, BanditModelData modelData) {
    return actions.entrySet().stream()
        .collect(
            Collectors.toMap(
                Map.Entry::getKey,
                e -> scoreAction(e.getKey(), e.getValue(), subjectAttributes, modelData)));
  }

  /**
   * Computes the linear score of one action.
   *
   * <p>The score is the intercept plus the contributions of the action's and the subject's
   * numeric and categorical attributes. Actions the model has no coefficients for receive the
   * model's default action score.
   */
  private static double scoreAction(
      String actionName,
      DiscriminableAttributes actionAttributes,
      DiscriminableAttributes subjectAttributes,
      BanditModelData modelData) {
    // get all coefficients known to the model for this action
    BanditCoefficients banditCoefficients = modelData.getCoefficients().get(actionName);

    if (banditCoefficients == null) {
      // Unknown action; return the default action score
      return modelData.getDefaultActionScore();
    }

    // Score the action using the provided attributes
    double actionScore = banditCoefficients.getIntercept();
    actionScore +=
        scoreContextForCoefficients(
            actionAttributes.getNumericAttributes(),
            banditCoefficients.getActionNumericCoefficients());
    actionScore +=
        scoreContextForCoefficients(
            actionAttributes.getCategoricalAttributes(),
            banditCoefficients.getActionCategoricalCoefficients());
    actionScore +=
        scoreContextForCoefficients(
            subjectAttributes.getNumericAttributes(),
            banditCoefficients.getSubjectNumericCoefficients());
    actionScore +=
        scoreContextForCoefficients(
            subjectAttributes.getCategoricalAttributes(),
            banditCoefficients.getSubjectCategoricalCoefficients());

    return actionScore;
  }

  /**
   * Sums the score contributions of the given attributes against a set of coefficients.
   *
   * <p>Each coefficient looks up its attribute key in {@code attributes} and scores the
   * (possibly absent) value itself.
   */
  private static double scoreContextForCoefficients(
      Attributes attributes, Map<String, ? extends BanditAttributeCoefficients> coefficients) {

    double totalScore = 0.0;

    for (BanditAttributeCoefficients attributeCoefficients : coefficients.values()) {
      EppoValue contextValue = attributes.get(attributeCoefficients.getAttributeKey());
      // The coefficient implementation knows how to score
      double attributeScore = attributeCoefficients.scoreForAttributeValue(contextValue);
      totalScore += attributeScore;
    }

    return totalScore;
  }

  /**
   * Converts action scores into selection weights.
   *
   * <p>Every action other than the top-scoring one gets weight {@code max(1 / (n + gamma *
   * (topScore - score)), floor / n)}, where {@code n} is the number of actions. The top-scoring
   * action gets whatever mass remains, clamped at zero. Ties for the top score go to the
   * lexicographically smallest action name.
   */
  private static Map<String, Double> weighActions(
      Map<String, Double> actionScores, double gamma, double actionProbabilityFloor) {
    Double highestScore = null;
    String highestScoredAction = null;
    for (Map.Entry<String, Double> actionScore : actionScores.entrySet()) {
      if (highestScore == null
          || actionScore.getValue() > highestScore
          // note: we break ties for scores by action name
          || (actionScore.getValue().equals(highestScore)
              && actionScore.getKey().compareTo(highestScoredAction) < 0)) {
        highestScore = actionScore.getValue();
        highestScoredAction = actionScore.getKey();
      }
    }

    // Loop-invariant quantities
    int numActions = actionScores.size();
    double minimumProbability = actionProbabilityFloor / numActions;

    // Weigh all the actions using their score
    Map<String, Double> actionWeights = new HashMap<>();
    double totalNonHighestWeight = 0.0;
    for (Map.Entry<String, Double> actionScore : actionScores.entrySet()) {
      if (actionScore.getKey().equals(highestScoredAction)) {
        // The highest scored action is weighed at the end
        continue;
      }

      // Compute weight (probability)
      double unboundedProbability =
          1 / (numActions + (gamma * (highestScore - actionScore.getValue())));
      double boundedProbability = Math.max(unboundedProbability, minimumProbability);
      totalNonHighestWeight += boundedProbability;

      actionWeights.put(actionScore.getKey(), boundedProbability);
    }

    // Weigh the highest scoring action (defensively preventing a negative probability)
    double weightForHighestScore = Math.max(1 - totalNonHighestWeight, 0);
    actionWeights.put(highestScoredAction, weightForHighestScore);
    return actionWeights;
  }

  /**
   * Deterministically picks an action according to the given weights.
   *
   * <p>Actions are ordered by the shard of {@code flagKey-subjectKey-actionKey}, with ties broken
   * by action name. The first action whose cumulative weight exceeds {@code
   * shard(flagKey-subjectKey) / BANDIT_ASSIGNMENT_SHARDS} is chosen.
   *
   * @return the selected action key, or {@code null} if the cumulative weight never exceeds the
   *     threshold (for example, when there are no weights)
   */
  private static String selectAction(
      String flagKey, String subjectKey, Map<String, Double> actionWeights) {
    // Deterministically "shuffle" the actions
    // This way as action weights shift, a bunch of users who were on the edge of one action won't
    // all be shifted to the same new action at the same time.
    List<String> shuffledActionKeys =
        actionWeights.keySet().stream()
            .sorted(
                Comparator.comparingInt(
                        (String actionKey) ->
                            getShard(
                                flagKey + "-" + subjectKey + "-" + actionKey,
                                BANDIT_ASSIGNMENT_SHARDS))
                    .thenComparing(actionKey -> actionKey))
            .collect(Collectors.toList());

    // Select action from the shuffled actions, based on weight
    double assignedShard = getShard(flagKey + "-" + subjectKey, BANDIT_ASSIGNMENT_SHARDS);
    double assignmentWeightThreshold = assignedShard / (double) BANDIT_ASSIGNMENT_SHARDS;
    double cumulativeWeight = 0;
    for (String actionKey : shuffledActionKeys) {
      cumulativeWeight += actionWeights.get(actionKey);
      if (cumulativeWeight > assignmentWeightThreshold) {
        return actionKey;
      }
    }
    return null;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy