All Downloads are FREE. Search and download functionalities are using the official Maven repository.

nl.topicus.jdbc.shaded.io.grpc.util.RoundRobinLoadBalancerFactory Maven / Gradle / Ivy

/*
 * Copyright 2016, gRPC Authors All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package nl.topicus.jdbc.shaded.io.grpc.util;

import static nl.topicus.jdbc.shaded.com.google.common.base.Preconditions.checkNotNull;
import static nl.topicus.jdbc.shaded.io.grpc.ConnectivityState.CONNECTING;
import static nl.topicus.jdbc.shaded.io.grpc.ConnectivityState.IDLE;
import static nl.topicus.jdbc.shaded.io.grpc.ConnectivityState.READY;
import static nl.topicus.jdbc.shaded.io.grpc.ConnectivityState.TRANSIENT_FAILURE;

import nl.topicus.jdbc.shaded.com.google.common.annotations.VisibleForTesting;
import nl.topicus.jdbc.shaded.io.grpc.Attributes;
import nl.topicus.jdbc.shaded.io.grpc.ConnectivityState;
import nl.topicus.jdbc.shaded.io.grpc.ConnectivityStateInfo;
import nl.topicus.jdbc.shaded.io.grpc.EquivalentAddressGroup;
import nl.topicus.jdbc.shaded.io.grpc.ExperimentalApi;
import nl.topicus.jdbc.shaded.io.grpc.LoadBalancer;
import nl.topicus.jdbc.shaded.io.grpc.LoadBalancer.PickResult;
import nl.topicus.jdbc.shaded.io.grpc.LoadBalancer.PickSubchannelArgs;
import nl.topicus.jdbc.shaded.io.grpc.LoadBalancer.Subchannel;
import nl.topicus.jdbc.shaded.io.grpc.LoadBalancer.SubchannelPicker;
import nl.topicus.jdbc.shaded.io.grpc.NameResolver;
import nl.topicus.jdbc.shaded.io.grpc.Status;
import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import nl.topicus.jdbc.shaded.javax.annotation.Nullable;

/**
 * A {@link LoadBalancer} that provides round-robin load balancing mechanism over the
 * addresses from the {@link NameResolver}.  The sub-lists received from the name resolver
 * are considered to be an {@link EquivalentAddressGroup} and each of these sub-lists is
 * what is then balanced across.
 */
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1771")
public final class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory {

  private static final RoundRobinLoadBalancerFactory INSTANCE =
      new RoundRobinLoadBalancerFactory();

  private RoundRobinLoadBalancerFactory() {}

  /**
   * A lighter weight Reference than AtomicReference.
   */
  @VisibleForTesting
  static final class Ref {
    T value;

    Ref(T value) {
      this.value = value;
    }
  }

  /**
   * Gets the singleton instance of this factory.
   */
  public static RoundRobinLoadBalancerFactory getInstance() {
    return INSTANCE;
  }

  @Override
  public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) {
    return new RoundRobinLoadBalancer(helper);
  }

  @VisibleForTesting
  static final class RoundRobinLoadBalancer extends LoadBalancer {
    @VisibleForTesting
    static final Attributes.Key> STATE_INFO =
        Attributes.Key.of("state-info");

    private final Helper helper;
    private final Map subchannels =
        new HashMap();

    RoundRobinLoadBalancer(Helper helper) {
      this.helper = checkNotNull(helper, "helper");
    }

    @Override
    public void handleResolvedAddressGroups(
        List servers, Attributes attributes) {
      Set currentAddrs = subchannels.keySet();
      Set latestAddrs = stripAttrs(servers);
      Set addedAddrs = setsDifference(latestAddrs, currentAddrs);
      Set removedAddrs = setsDifference(currentAddrs, latestAddrs);

      // Create new subchannels for new addresses.
      for (EquivalentAddressGroup addressGroup : addedAddrs) {
        // NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel
        // doesn't need them. They're describing the resolved server list but we're not taking
        // any action based on this information.
        Attributes subchannelAttrs = Attributes.newBuilder()
            // NB(lukaszx0): because attributes are immutable we can't set new value for the key
            // after creation but since we can mutate the values we leverge that and set
            // AtomicReference which will allow mutating state info for given channel.
            .set(
                STATE_INFO, new Ref(ConnectivityStateInfo.forNonError(IDLE)))
            .build();

        Subchannel subchannel =
            checkNotNull(helper.createSubchannel(addressGroup, subchannelAttrs), "subchannel");
        subchannels.put(addressGroup, subchannel);
        subchannel.requestConnection();
      }

      // Shutdown subchannels for removed addresses.
      for (EquivalentAddressGroup addressGroup : removedAddrs) {
        Subchannel subchannel = subchannels.remove(addressGroup);
        subchannel.shutdown();
      }

      updateBalancingState(getAggregatedState(), getAggregatedError());
    }

    @Override
    public void handleNameResolutionError(Status error) {
      updateBalancingState(TRANSIENT_FAILURE, error);
    }

    @Override
    public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) {
      if (subchannels.get(subchannel.getAddresses()) != subchannel) {
        return;
      }
      if (stateInfo.getState() == IDLE) {
        subchannel.requestConnection();
      }
      getSubchannelStateInfoRef(subchannel).value = stateInfo;
      updateBalancingState(getAggregatedState(), getAggregatedError());
    }

    @Override
    public void shutdown() {
      for (Subchannel subchannel : getSubchannels()) {
        subchannel.shutdown();
      }
    }

    /**
     * Updates picker with the list of active subchannels (state == READY).
     */
    private void updateBalancingState(ConnectivityState state, Status error) {
      List activeList = filterNonFailingSubchannels(getSubchannels());
      helper.updateBalancingState(state, new Picker(activeList, error));
    }

    /**
     * Filters out non-ready subchannels.
     */
    private static List filterNonFailingSubchannels(
        Collection subchannels) {
      List readySubchannels = new ArrayList(subchannels.size());
      for (Subchannel subchannel : subchannels) {
        if (getSubchannelStateInfoRef(subchannel).value.getState() == READY) {
          readySubchannels.add(subchannel);
        }
      }
      return readySubchannels;
    }

    /**
     * Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and
     * remove all attributes.
     */
    private static Set stripAttrs(List groupList) {
      Set addrs = new HashSet(groupList.size());
      for (EquivalentAddressGroup group : groupList) {
        addrs.add(new EquivalentAddressGroup(group.getAddresses()));
      }
      return addrs;
    }

    /**
     * If all subchannels are TRANSIENT_FAILURE, return the Status associated with an arbitrary
     * subchannel otherwise, return null.
     */
    @Nullable
    private Status getAggregatedError() {
      Status status = null;
      for (Subchannel subchannel : getSubchannels()) {
        ConnectivityStateInfo stateInfo = getSubchannelStateInfoRef(subchannel).value;
        if (stateInfo.getState() != TRANSIENT_FAILURE) {
          return null;
        }
        status = stateInfo.getStatus();
      }
      return status;
    }

    private ConnectivityState getAggregatedState() {
      Set states = EnumSet.noneOf(ConnectivityState.class);
      for (Subchannel subchannel : getSubchannels()) {
        states.add(getSubchannelStateInfoRef(subchannel).value.getState());
      }
      if (states.contains(READY)) {
        return READY;
      }
      if (states.contains(CONNECTING)) {
        return CONNECTING;
      }
      if (states.contains(IDLE)) {
        // This subchannel IDLE is not because of channel IDLE_TIMEOUT, in which case LB is already
        // shutdown.
        // RRLB will request connection immediately on subchannel IDLE.
        return CONNECTING;
      }
      return TRANSIENT_FAILURE;
    }

    @VisibleForTesting
    Collection getSubchannels() {
      return subchannels.values();
    }

    private static Ref getSubchannelStateInfoRef(
        Subchannel subchannel) {
      return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO");
    }

    private static  Set setsDifference(Set a, Set b) {
      Set aCopy = new HashSet(a);
      aCopy.removeAll(b);
      return aCopy;
    }
  }

  @VisibleForTesting
  static final class Picker extends SubchannelPicker {
    private static final AtomicIntegerFieldUpdater indexUpdater =
        AtomicIntegerFieldUpdater.newUpdater(Picker.class, "index");

    @Nullable
    private final Status status;
    private final List list;
    @SuppressWarnings("unused")
    private volatile int index = -1; // start off at -1 so the address on first use is 0.

    Picker(List list, @Nullable Status status) {
      this.list = list;
      this.status = status;
    }

    @Override
    public PickResult pickSubchannel(PickSubchannelArgs args) {
      if (list.size() > 0) {
        return PickResult.withSubchannel(nextSubchannel());
      }

      if (status != null) {
        return PickResult.withError(status);
      }

      return PickResult.withNoResult();
    }

    private Subchannel nextSubchannel() {
      if (list.isEmpty()) {
        throw new NoSuchElementException();
      }
      int size = list.size();

      int i = indexUpdater.incrementAndGet(this);
      if (i >= size) {
        int oldi = i;
        i %= size;
        indexUpdater.compareAndSet(this, oldi, i);
      }
      return list.get(i);
    }

    @VisibleForTesting
    List getList() {
      return list;
    }

    @VisibleForTesting
    Status getStatus() {
      return status;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy