// com.arpnetworking.utility.ParallelLeastShardAllocationStrategy Maven / Gradle / Ivy
/*
* Copyright 2015 Groupon.com
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.arpnetworking.utility;
import com.arpnetworking.steno.Logger;
import com.arpnetworking.steno.LoggerFactory;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.pekko.actor.ActorRef;
import org.apache.pekko.actor.ActorSelection;
import org.apache.pekko.cluster.sharding.ShardCoordinator;
import org.apache.pekko.dispatch.Futures;
import scala.collection.immutable.IndexedSeq;
import scala.concurrent.Future;
import scala.jdk.javaapi.CollectionConverters;
import java.io.Serializable;
import java.time.Instant;
import java.util.Collections;
import java.util.Comparator;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
/**
* Implementation of the least shard allocation strategy that seeks to parallelize shard rebalancing.
*
* @author Brandon Arp (brandon dot arp at inscopemetrics dot com)
*/
public final class ParallelLeastShardAllocationStrategy extends ShardCoordinator.AbstractShardAllocationStrategy {
/**
* Public constructor.
*
* @param maxParallel number of allocations to start in parallel
* @param rebalanceThreshold difference in number of shards required to cause a rebalance
* @param notify the {@link org.apache.pekko.actor.ActorSelection} selection to notify of changes
*/
public ParallelLeastShardAllocationStrategy(
final int maxParallel,
final int rebalanceThreshold,
final Optional notify) {
_maxParallel = maxParallel;
_rebalanceThreshold = rebalanceThreshold;
_notify = notify;
}
@Override
public Future allocateShard(
final ActorRef requester,
final String shardId,
final Map> currentShardAllocations) {
// If we already decided where this goes, return the destination
if (_pendingRebalances.containsKey(shardId)) {
return Futures.successful(_pendingRebalances.get(shardId));
}
// Otherwise default to giving it to the shard with the least amount of shards
return Futures.successful(currentShardAllocations
.entrySet()
.stream()
.min(Comparator.comparingInt(e -> e.getValue().size()))
.get()
.getKey());
}
@Override
public Future> rebalance(
final Map> currentShardAllocations,
final Set rebalanceInProgress) {
// Only keep the rebalances that are in progress
_pendingRebalances.keySet().retainAll(rebalanceInProgress);
// Build a friendly set of current allocations
// Sort the set by "effective shards after rebalance"
final TreeSet allocations =
new TreeSet<>(Comparator.comparingInt(RegionShardAllocations::getEffectiveShardCount));
for (final Map.Entry> entry : currentShardAllocations.entrySet()) {
allocations.add(
new RegionShardAllocations(
entry.getKey(),
// Only count the shards that are not currently rebalancing
CollectionConverters.asJava(entry.getValue().toSet())
.stream()
.filter(e -> !rebalanceInProgress.contains(e))
.collect(Collectors.toSet())));
}
final Set toRebalance = Sets.newHashSet();
for (int x = 0; x < _maxParallel - rebalanceInProgress.size(); x++) {
// Note: the poll* functions remove the item from the set
final RegionShardAllocations leastShards = allocations.pollFirst();
final RegionShardAllocations mostShards = allocations.pollLast();
// Make sure that we have more than 1 region
if (leastShards == null || mostShards == null) {
LOGGER.trace()
.setMessage("Cannot rebalance shards, less than 2 shard regions found.")
.log();
break;
}
// Make sure that the difference is enough to warrant a rebalance
if (mostShards.getEffectiveShardCount() - leastShards.getEffectiveShardCount() < _rebalanceThreshold) {
LOGGER.trace()
.setMessage("Not rebalancing any (more) shards, shard region with most shards already balanced with least")
.addData("most", mostShards.getEffectiveShardCount())
.addData("least", leastShards.getEffectiveShardCount())
.addData("rebalanceThreshold", _rebalanceThreshold)
.log();
break;
}
final String rebalanceShard = Iterables.get(mostShards.getShards(), 0);
// Now we take a shard from mostShards and give it to leastShards
mostShards.removeShard(rebalanceShard);
leastShards.incrementIncoming();
toRebalance.add(rebalanceShard);
_pendingRebalances.put(rebalanceShard, leastShards.getRegion());
// Put them back in the list with their new counts
allocations.add(mostShards);
allocations.add(leastShards);
}
// Transform the currentShardAllocations to a Map> from the
// Scala representation
final Map> currentAllocations = Maps.transformValues(
currentShardAllocations,
e -> Sets.newHashSet(CollectionConverters.asJava(e)));
final RebalanceNotification notification = new RebalanceNotification(
currentAllocations,
rebalanceInProgress,
_pendingRebalances);
if (_notify.isPresent()) {
LOGGER.trace()
.setMessage("Broadcasting rebalance info")
.addData("target", _notify)
.addData("shardAllocations", notification)
.log();
_notify.get().tell(notification, ActorRef.noSender());
}
return Futures.successful(toRebalance);
}
private Map _pendingRebalances = Maps.newHashMap();
private final int _maxParallel;
private final int _rebalanceThreshold;
private final Optional _notify;
private static final Logger LOGGER = LoggerFactory.getLogger(ParallelLeastShardAllocationStrategy.class);
/**
* Notification message that contains rebalance status.
*
* @author Brandon Arp (brandon dot arp at inscopemetrics dot com)
*/
public static final class RebalanceNotification implements Serializable {
/**
* Public constructor.
*
* @param currentAllocations current allocations
* @param inflightRebalances shards that are currently in the process of rebalancing
* @param pendingRebalances current and pending rebalances and their destination
*/
public RebalanceNotification(
final Map> currentAllocations,
final Set inflightRebalances,
final Map pendingRebalances) {
_currentAllocations = ImmutableMap.copyOf(currentAllocations);
_inflightRebalances = ImmutableSet.copyOf(inflightRebalances);
_pendingRebalances = ImmutableMap.copyOf(pendingRebalances);
_timestamp = Instant.now();
}
public Map> getCurrentAllocations() {
return _currentAllocations;
}
public Set getInflightRebalances() {
return _inflightRebalances;
}
public Instant getTimestamp() {
return _timestamp;
}
public Map getPendingRebalances() {
return _pendingRebalances;
}
private final ImmutableMap> _currentAllocations;
private final ImmutableSet _inflightRebalances;
private final ImmutableMap _pendingRebalances;
private final Instant _timestamp;
private static final long serialVersionUID = 1L;
}
private static final class RegionShardAllocations {
private RegionShardAllocations(final ActorRef region, final Set shards) {
_region = region;
_shards = Sets.newHashSet(shards);
}
public ActorRef getRegion() {
return _region;
}
public Set getShards() {
return Collections.unmodifiableSet(_shards);
}
public int getEffectiveShardCount() {
return _shards.size() + _incomingShardsCount;
}
public void removeShard(final String shard) {
_shards.remove(shard);
}
public void incrementIncoming() {
_incomingShardsCount++;
}
private int _incomingShardsCount = 0;
private final ActorRef _region;
private final Set _shards;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy