All downloads are free. Search and download functionality uses the official Maven repository.

nl.topicus.jdbc.shaded.io.grpc.util.RoundRobinLoadBalancerFactory Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package nl.topicus.jdbc.shaded.io.grpc.util;

import static nl.topicus.jdbc.shaded.com.google.common.base.Preconditions.checkNotNull;
import static nl.topicus.jdbc.shaded.io.grpc.ConnectivityState.CONNECTING;
import static nl.topicus.jdbc.shaded.io.grpc.ConnectivityState.IDLE;
import static nl.topicus.jdbc.shaded.io.grpc.ConnectivityState.READY;
import static nl.topicus.jdbc.shaded.io.grpc.ConnectivityState.SHUTDOWN;
import static nl.topicus.jdbc.shaded.io.grpc.ConnectivityState.TRANSIENT_FAILURE;

import nl.topicus.jdbc.shaded.com.google.common.annotations.VisibleForTesting;
import nl.topicus.jdbc.shaded.io.grpc.Attributes;
import nl.topicus.jdbc.shaded.io.grpc.ConnectivityState;
import nl.topicus.jdbc.shaded.io.grpc.ConnectivityStateInfo;
import nl.topicus.jdbc.shaded.io.grpc.EquivalentAddressGroup;
import nl.topicus.jdbc.shaded.io.grpc.ExperimentalApi;
import nl.topicus.jdbc.shaded.io.grpc.LoadBalancer;
import nl.topicus.jdbc.shaded.io.grpc.LoadBalancer.PickResult;
import nl.topicus.jdbc.shaded.io.grpc.LoadBalancer.PickSubchannelArgs;
import nl.topicus.jdbc.shaded.io.grpc.LoadBalancer.Subchannel;
import nl.topicus.jdbc.shaded.io.grpc.LoadBalancer.SubchannelPicker;
import nl.topicus.jdbc.shaded.io.grpc.Metadata;
import nl.topicus.jdbc.shaded.io.grpc.Metadata.Key;
import nl.topicus.jdbc.shaded.io.grpc.NameResolver;
import nl.topicus.jdbc.shaded.io.grpc.Status;
import nl.topicus.jdbc.shaded.io.grpc.internal.GrpcAttributes;
import nl.topicus.jdbc.shaded.io.grpc.internal.ServiceConfigUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.logging.Level;
import java.util.logging.Logger;
import nl.topicus.jdbc.shaded.javax.annotation.Nonnull;
import nl.topicus.jdbc.shaded.javax.annotation.Nullable;

/**
 * A {@link LoadBalancer} that provides round-robin load balancing mechanism over the
 * addresses from the {@link NameResolver}.  The sub-lists received from the name resolver
 * are considered to be an {@link EquivalentAddressGroup} and each of these sub-lists is
 * what is then balanced across.
 */
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1771")
public final class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory {

  private static final RoundRobinLoadBalancerFactory INSTANCE =
      new RoundRobinLoadBalancerFactory();

  private RoundRobinLoadBalancerFactory() {}

  /**
   * A lighter weight Reference than AtomicReference.
   */
  @VisibleForTesting
  static final class Ref {
    T value;

    Ref(T value) {
      this.value = value;
    }
  }

  /**
   * Gets the singleton instance of this factory.
   */
  public static RoundRobinLoadBalancerFactory getInstance() {
    return INSTANCE;
  }

  @Override
  public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) {
    return new RoundRobinLoadBalancer(helper);
  }

  @VisibleForTesting
  static final class RoundRobinLoadBalancer extends LoadBalancer {
    @VisibleForTesting
    static final Attributes.Key> STATE_INFO =
        Attributes.Key.create("state-info");

    private static final Logger logger = Logger.getLogger(RoundRobinLoadBalancer.class.getName());

    private final Helper helper;
    private final Map subchannels =
        new HashMap();

    @Nullable
    private StickinessState stickinessState;

    RoundRobinLoadBalancer(Helper helper) {
      this.helper = checkNotNull(helper, "helper");
    }

    @Override
    public void handleResolvedAddressGroups(
        List servers, Attributes attributes) {
      Set currentAddrs = subchannels.keySet();
      Set latestAddrs = stripAttrs(servers);
      Set addedAddrs = setsDifference(latestAddrs, currentAddrs);
      Set removedAddrs = setsDifference(currentAddrs, latestAddrs);

      Map serviceConfig =
          attributes.get(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG);
      if (serviceConfig != null) {
        String stickinessMetadataKey =
            ServiceConfigUtil.getStickinessMetadataKeyFromServiceConfig(serviceConfig);
        if (stickinessMetadataKey != null) {
          if (stickinessMetadataKey.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
            logger.log(
                Level.FINE,
                "Binary stickiness header is not supported. The header '{0}' will be ignored",
                stickinessMetadataKey);
          } else if (stickinessState == null
              || !stickinessState.key.name().equals(stickinessMetadataKey)) {
            stickinessState = new StickinessState(stickinessMetadataKey);
          }
        }
      }

      // Create new subchannels for new addresses.
      for (EquivalentAddressGroup addressGroup : addedAddrs) {
        // NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel
        // doesn't need them. They're describing the resolved server list but we're not taking
        // any action based on this information.
        Attributes subchannelAttrs = Attributes.newBuilder()
            // NB(lukaszx0): because attributes are immutable we can't set new value for the key
            // after creation but since we can mutate the values we leverge that and set
            // AtomicReference which will allow mutating state info for given channel.
            .set(
                STATE_INFO, new Ref(ConnectivityStateInfo.forNonError(IDLE)))
            .build();

        Subchannel subchannel =
            checkNotNull(helper.createSubchannel(addressGroup, subchannelAttrs), "subchannel");
        subchannels.put(addressGroup, subchannel);
        subchannel.requestConnection();
      }

      // Shutdown subchannels for removed addresses.
      for (EquivalentAddressGroup addressGroup : removedAddrs) {
        Subchannel subchannel = subchannels.remove(addressGroup);
        subchannel.shutdown();
      }

      updateBalancingState(getAggregatedState(), getAggregatedError());
    }

    @Override
    public void handleNameResolutionError(Status error) {
      updateBalancingState(TRANSIENT_FAILURE, error);
    }

    @Override
    public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) {
      if (stateInfo.getState() == SHUTDOWN && stickinessState != null) {
        stickinessState.remove(subchannel);
      }
      if (subchannels.get(subchannel.getAddresses()) != subchannel) {
        return;
      }
      if (stateInfo.getState() == IDLE) {
        subchannel.requestConnection();
      }
      getSubchannelStateInfoRef(subchannel).value = stateInfo;
      updateBalancingState(getAggregatedState(), getAggregatedError());
    }

    @Override
    public void shutdown() {
      for (Subchannel subchannel : getSubchannels()) {
        subchannel.shutdown();
      }
    }

    /**
     * Updates picker with the list of active subchannels (state == READY).
     */
    private void updateBalancingState(ConnectivityState state, Status error) {
      List activeList = filterNonFailingSubchannels(getSubchannels());
      helper.updateBalancingState(state, new Picker(activeList, error, stickinessState));
    }

    /**
     * Filters out non-ready subchannels.
     */
    private static List filterNonFailingSubchannels(
        Collection subchannels) {
      List readySubchannels = new ArrayList(subchannels.size());
      for (Subchannel subchannel : subchannels) {
        if (getSubchannelStateInfoRef(subchannel).value.getState() == READY) {
          readySubchannels.add(subchannel);
        }
      }
      return readySubchannels;
    }

    /**
     * Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and
     * remove all attributes.
     */
    private static Set stripAttrs(List groupList) {
      Set addrs = new HashSet(groupList.size());
      for (EquivalentAddressGroup group : groupList) {
        addrs.add(new EquivalentAddressGroup(group.getAddresses()));
      }
      return addrs;
    }

    /**
     * If all subchannels are TRANSIENT_FAILURE, return the Status associated with an arbitrary
     * subchannel otherwise, return null.
     */
    @Nullable
    private Status getAggregatedError() {
      Status status = null;
      for (Subchannel subchannel : getSubchannels()) {
        ConnectivityStateInfo stateInfo = getSubchannelStateInfoRef(subchannel).value;
        if (stateInfo.getState() != TRANSIENT_FAILURE) {
          return null;
        }
        status = stateInfo.getStatus();
      }
      return status;
    }

    private ConnectivityState getAggregatedState() {
      Set states = EnumSet.noneOf(ConnectivityState.class);
      for (Subchannel subchannel : getSubchannels()) {
        states.add(getSubchannelStateInfoRef(subchannel).value.getState());
      }
      if (states.contains(READY)) {
        return READY;
      }
      if (states.contains(CONNECTING)) {
        return CONNECTING;
      }
      if (states.contains(IDLE)) {
        // This subchannel IDLE is not because of channel IDLE_TIMEOUT, in which case LB is already
        // shutdown.
        // RRLB will request connection immediately on subchannel IDLE.
        return CONNECTING;
      }
      return TRANSIENT_FAILURE;
    }

    @VisibleForTesting
    Collection getSubchannels() {
      return subchannels.values();
    }

    private static Ref getSubchannelStateInfoRef(
        Subchannel subchannel) {
      return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO");
    }

    private static  Set setsDifference(Set a, Set b) {
      Set aCopy = new HashSet(a);
      aCopy.removeAll(b);
      return aCopy;
    }

    Map> getStickinessMapForTest() {
      if (stickinessState == null) {
        return null;
      }
      return stickinessState.stickinessMap;
    }

    /**
     * Holds stickiness related states: The stickiness key, a registry mapping stickiness values to
     * the associated Subchannel Ref, and a map from Subchannel to Subchannel Ref.
     */
    private static final class StickinessState {
      static final int MAX_ENTRIES = 1000;

      final Key key;
      final Map> stickinessMap =
          new LinkedHashMap>() {
            @Override
            protected boolean removeEldestEntry(Map.Entry> eldest) {
              return size() > MAX_ENTRIES;
            }
          };

      final Map> subchannelRefs =
          new HashMap>();

      StickinessState(@Nonnull String stickinessKey) {
        this.key = Key.of(stickinessKey, Metadata.ASCII_STRING_MARSHALLER);
      }

      /**
       * Returns the subchannel asscoicated to the stickiness value if available in both the
       * registry and the round robin list, otherwise associates the given subchannel with the
       * stickiness key in the registry and returns the given subchannel.
       */
      @Nonnull
      synchronized Subchannel maybeRegister(
          String stickinessValue, @Nonnull Subchannel subchannel, List rrList) {
        Subchannel existingSubchannel = getSubchannel(stickinessValue);
        if (existingSubchannel != null && rrList.contains(existingSubchannel)) {
          return existingSubchannel;
        }

        Ref subchannelRef = subchannelRefs.get(subchannel);
        if (subchannelRef == null) {
          subchannelRef = new Ref(subchannel);
          subchannelRefs.put(subchannel, subchannelRef);
        }
        stickinessMap.put(stickinessValue, subchannelRef);
        return subchannel;
      }

      /**
       * Unregister the subchannel from StickinessState.
       */
      synchronized void remove(Subchannel subchannel) {
        if (subchannelRefs.containsKey(subchannel)) {
          subchannelRefs.get(subchannel).value = null;
          subchannelRefs.remove(subchannel);
        }
      }

      /**
       * Gets the subchannel associated with the stickiness value if there is.
       */
      @Nullable
      synchronized Subchannel getSubchannel(String stickinessValue) {
        Ref subchannelRef = stickinessMap.get(stickinessValue);
        if (subchannelRef != null) {
          return subchannelRef.value;
        }
        return null;
      }
    }
  }

  @VisibleForTesting
  static final class Picker extends SubchannelPicker {
    private static final AtomicIntegerFieldUpdater indexUpdater =
        AtomicIntegerFieldUpdater.newUpdater(Picker.class, "index");

    @Nullable
    private final Status status;
    private final List list;
    @Nullable
    private final RoundRobinLoadBalancer.StickinessState stickinessState;
    @SuppressWarnings("unused")
    private volatile int index = -1; // start off at -1 so the address on first use is 0.

    Picker(
        List list, @Nullable Status status,
        @Nullable RoundRobinLoadBalancer.StickinessState stickinessState) {
      this.list = list;
      this.status = status;
      this.stickinessState = stickinessState;
    }

    @Override
    public PickResult pickSubchannel(PickSubchannelArgs args) {
      if (list.size() > 0) {
        if (stickinessState != null && args.getHeaders().containsKey(stickinessState.key)) {
          String stickinessValue = args.getHeaders().get(stickinessState.key);
          Subchannel subchannel = stickinessState.getSubchannel(stickinessValue);
          if (subchannel == null || !list.contains(subchannel)) {
            subchannel = stickinessState.maybeRegister(stickinessValue, nextSubchannel(), list);
          }
          return PickResult.withSubchannel(subchannel);
        }

        return PickResult.withSubchannel(nextSubchannel());
      }

      if (status != null) {
        return PickResult.withError(status);
      }

      return PickResult.withNoResult();
    }

    private Subchannel nextSubchannel() {
      if (list.isEmpty()) {
        throw new NoSuchElementException();
      }
      int size = list.size();

      int i = indexUpdater.incrementAndGet(this);
      if (i >= size) {
        int oldi = i;
        i %= size;
        indexUpdater.compareAndSet(this, oldi, i);
      }
      return list.get(i);
    }

    @VisibleForTesting
    List getList() {
      return list;
    }

    @VisibleForTesting
    Status getStatus() {
      return status;
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy