All downloads are free. The search and download functionality uses the official Maven repository.

com.arpnetworking.utility.ParallelLeastShardAllocationStrategy Maven / Gradle / Ivy

/*
 * Copyright 2015 Groupon.com
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.arpnetworking.utility;

import com.arpnetworking.steno.Logger;
import com.arpnetworking.steno.LoggerFactory;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.pekko.actor.ActorRef;
import org.apache.pekko.actor.ActorSelection;
import org.apache.pekko.cluster.sharding.ShardCoordinator;
import org.apache.pekko.dispatch.Futures;
import scala.collection.immutable.IndexedSeq;
import scala.concurrent.Future;
import scala.jdk.javaapi.CollectionConverters;

import java.io.Serializable;
import java.time.Instant;
import java.util.Collections;
import java.util.Comparator;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;

/**
 * Implementation of the least shard allocation strategy that seeks to parallelize shard rebalancing.
 *
 * @author Brandon Arp (brandon dot arp at inscopemetrics dot com)
 */
public final class ParallelLeastShardAllocationStrategy extends ShardCoordinator.AbstractShardAllocationStrategy {

    /**
     * Public constructor.
     *
     * @param maxParallel number of allocations to start in parallel
     * @param rebalanceThreshold difference in number of shards required to cause a rebalance
     * @param notify the {@link org.apache.pekko.actor.ActorSelection} selection to notify of changes
     */
    public ParallelLeastShardAllocationStrategy(
            final int maxParallel,
            final int rebalanceThreshold,
            final Optional notify) {
        _maxParallel = maxParallel;
        _rebalanceThreshold = rebalanceThreshold;
        _notify = notify;
    }

    @Override
    public Future allocateShard(
            final ActorRef requester,
            final String shardId,
            final Map> currentShardAllocations) {
        // If we already decided where this goes, return the destination
        if (_pendingRebalances.containsKey(shardId)) {
            return Futures.successful(_pendingRebalances.get(shardId));
        }

        // Otherwise default to giving it to the shard with the least amount of shards
        return Futures.successful(currentShardAllocations
                .entrySet()
                .stream()
                .min(Comparator.comparingInt(e -> e.getValue().size()))
                .get()
                .getKey());
    }

    @Override
    public Future> rebalance(
            final Map> currentShardAllocations,
            final Set rebalanceInProgress) {
        // Only keep the rebalances that are in progress
        _pendingRebalances.keySet().retainAll(rebalanceInProgress);

        // Build a friendly set of current allocations
        // Sort the set by "effective shards after rebalance"
        final TreeSet allocations =
                new TreeSet<>(Comparator.comparingInt(RegionShardAllocations::getEffectiveShardCount));

        for (final Map.Entry> entry : currentShardAllocations.entrySet()) {
            allocations.add(
                    new RegionShardAllocations(
                            entry.getKey(),
                            // Only count the shards that are not currently rebalancing
                            CollectionConverters.asJava(entry.getValue().toSet())
                                    .stream()
                                    .filter(e -> !rebalanceInProgress.contains(e))
                                    .collect(Collectors.toSet())));
        }

        final Set toRebalance = Sets.newHashSet();

        for (int x = 0; x < _maxParallel - rebalanceInProgress.size(); x++) {
            // Note: the poll* functions remove the item from the set
            final RegionShardAllocations leastShards = allocations.pollFirst();
            final RegionShardAllocations mostShards = allocations.pollLast();


            // Make sure that we have more than 1 region
            if (leastShards == null || mostShards == null) {
                LOGGER.trace()
                        .setMessage("Cannot rebalance shards, less than 2 shard regions found.")
                        .log();
                break;
            }

            // Make sure that the difference is enough to warrant a rebalance
            if (mostShards.getEffectiveShardCount() - leastShards.getEffectiveShardCount() < _rebalanceThreshold) {
                LOGGER.trace()
                        .setMessage("Not rebalancing any (more) shards, shard region with most shards already balanced with least")
                        .addData("most", mostShards.getEffectiveShardCount())
                        .addData("least", leastShards.getEffectiveShardCount())
                        .addData("rebalanceThreshold", _rebalanceThreshold)
                        .log();
                break;
            }

            final String rebalanceShard = Iterables.get(mostShards.getShards(), 0);

            // Now we take a shard from mostShards and give it to leastShards
            mostShards.removeShard(rebalanceShard);
            leastShards.incrementIncoming();
            toRebalance.add(rebalanceShard);
            _pendingRebalances.put(rebalanceShard, leastShards.getRegion());

            // Put them back in the list with their new counts
            allocations.add(mostShards);
            allocations.add(leastShards);
        }

        // Transform the currentShardAllocations to a Map> from the
        // Scala representation
        final Map> currentAllocations = Maps.transformValues(
                currentShardAllocations,
                e -> Sets.newHashSet(CollectionConverters.asJava(e)));

        final RebalanceNotification notification = new RebalanceNotification(
                currentAllocations,
                rebalanceInProgress,
                _pendingRebalances);
        if (_notify.isPresent()) {
            LOGGER.trace()
                    .setMessage("Broadcasting rebalance info")
                    .addData("target", _notify)
                    .addData("shardAllocations", notification)
                    .log();
            _notify.get().tell(notification, ActorRef.noSender());
        }
        return Futures.successful(toRebalance);
    }

    private Map _pendingRebalances = Maps.newHashMap();

    private final int _maxParallel;
    private final int _rebalanceThreshold;
    private final Optional _notify;

    private static final Logger LOGGER = LoggerFactory.getLogger(ParallelLeastShardAllocationStrategy.class);

    /**
     * Notification message that contains rebalance status.
     *
     * @author Brandon Arp (brandon dot arp at inscopemetrics dot com)
     */
    public static final class RebalanceNotification implements Serializable {
        /**
         * Public constructor.
         *
         * @param currentAllocations current allocations
         * @param inflightRebalances shards that are currently in the process of rebalancing
         * @param pendingRebalances current and pending rebalances and their destination
         */
        public RebalanceNotification(
                final Map> currentAllocations,
                final Set inflightRebalances,
                final Map pendingRebalances) {
            _currentAllocations = ImmutableMap.copyOf(currentAllocations);
            _inflightRebalances = ImmutableSet.copyOf(inflightRebalances);
            _pendingRebalances = ImmutableMap.copyOf(pendingRebalances);
            _timestamp = Instant.now();
        }

        public Map> getCurrentAllocations() {
            return _currentAllocations;
        }

        public Set getInflightRebalances() {
            return _inflightRebalances;
        }

        public Instant getTimestamp() {
            return _timestamp;
        }

        public Map getPendingRebalances() {
            return _pendingRebalances;
        }

        private final ImmutableMap> _currentAllocations;
        private final ImmutableSet _inflightRebalances;
        private final ImmutableMap _pendingRebalances;
        private final Instant _timestamp;

        private static final long serialVersionUID = 1L;
    }

    private static final class RegionShardAllocations {
        private RegionShardAllocations(final ActorRef region, final Set shards) {
            _region = region;
            _shards = Sets.newHashSet(shards);
        }

        public ActorRef getRegion() {
            return _region;
        }

        public Set getShards() {
            return Collections.unmodifiableSet(_shards);
        }

        public int getEffectiveShardCount() {
            return _shards.size() + _incomingShardsCount;
        }

        public void removeShard(final String shard) {
            _shards.remove(shard);
        }

        public void incrementIncoming() {
            _incomingShardsCount++;
        }

        private int _incomingShardsCount = 0;

        private final ActorRef _region;
        private final Set _shards;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy