// com.devsisters.shardcake.ShardManager.scala Maven / Gradle / Ivy
// The newest version!
package com.devsisters.shardcake
import com.devsisters.shardcake.ShardManager._
import com.devsisters.shardcake.interfaces._
import java.time.OffsetDateTime
import scala.util.Random
import zio._
import zio.stream.ZStream
import scala.annotation.tailrec
import scala.collection.compat._
/**
* A component in charge of assigning and unassigning shards to/from pods
*/
class ShardManager(
stateRef: Ref.Synchronized[ShardManagerState],
rebalanceSemaphore: Semaphore,
eventsHub: Hub[ShardingEvent],
healthApi: PodsHealth,
podApi: Pods,
stateRepository: Storage,
config: ManagerConfig
) {
/**
 * Returns the current shard table held in the manager state: for every shard id, the address of the pod
 * it is assigned to, or `None` when the shard is unassigned.
 */
def getAssignments: UIO[Map[ShardId, Option[PodAddress]]] =
stateRef.get.map(_.shards)
/**
 * A stream of the sharding events published on the manager's events hub
 * (pod registration/unregistration, health checks, shard assignments/unassignments).
 */
def getShardingEvents: ZStream[Any, Nothing, ShardingEvent] =
ZStream.fromHub(eventsHub)
/**
 * Registers a pod with the shard manager.
 *
 * The pod is only accepted when `healthApi.isAlive` reports it alive; otherwise a warning is logged and the
 * effect fails with a `RuntimeException`. On success the pod is stored in the manager state together with its
 * registration time, a `PodRegistered` event is published, a rebalance is run when some shards are currently
 * unassigned, and the pod list is persisted in the background.
 *
 * Registering an address that is already known replaces its entry (refreshing the registration time) but does
 * not increment the pods metric again: `unregister` only decrements it once per address, so counting
 * re-registrations would make the metric drift away from the real number of registered pods.
 *
 * @param pod the pod asking to be registered
 */
def register(pod: Pod): Task[Unit] =
  ZIO.ifZIO(healthApi.isAlive(pod.address))(
    onTrue = for {
      _                         <- ZIO.logInfo(s"Registering $pod")
      // atomically insert the pod and remember whether its address was already registered
      result                    <- stateRef.modifyZIO(state =>
                                     ZIO
                                       .succeed(OffsetDateTime.now())
                                       .map { cdt =>
                                         val updated =
                                           state.copy(pods = state.pods.updated(pod.address, PodWithMetadata(pod, cdt)))
                                         ((state.pods.contains(pod.address), updated), updated)
                                       }
                                   )
      (alreadyRegistered, state) = result
      _                         <- ManagerMetrics.pods.increment.unless(alreadyRegistered)
      _                         <- eventsHub.publish(ShardingEvent.PodRegistered(pod.address))
      _                         <- ZIO.when(state.unassignedShards.nonEmpty)(rebalance(rebalanceImmediately = false))
      _                         <- persistPods.forkDaemon
    } yield (),
    onFalse = ZIO.logWarning(s"Pod $pod requested to register but is not alive, ignoring") *>
      ZIO.fail(new RuntimeException(s"Pod $pod is not healthy, refusing to register"))
  )
/**
 * Re-checks the health of a pod that was reported as possibly unhealthy.
 *
 * Nothing happens when the address is not registered. Otherwise the health-check metric is incremented
 * (unless `ignoreMetric` is set), a `PodHealthChecked` event is published, and the pod is unregistered
 * when `healthApi.isAlive` reports it as not alive.
 *
 * @param podAddress   address of the pod to check
 * @param ignoreMetric when true, the health-check metric is left untouched
 */
def notifyUnhealthyPod(podAddress: PodAddress, ignoreMetric: Boolean = false): UIO[Unit] = {
  val isRegistered = stateRef.get.map(_.pods.contains(podAddress))

  ZIO.whenZIODiscard(isRegistered) {
    for {
      _ <- ManagerMetrics.podHealthChecked.tagged("pod_address", podAddress.toString).increment.unless(ignoreMetric)
      _ <- eventsHub.publish(ShardingEvent.PodHealthChecked(podAddress))
      _ <- ZIO.unlessZIO(healthApi.isAlive(podAddress)) {
             ZIO.logWarning(s"Pod $podAddress is not alive, unregistering") *> unregister(podAddress)
           }
    } yield ()
  }
}
/**
 * Runs `notifyUnhealthyPod` for every registered pod, at most 4 at a time,
 * without touching the health-check metric.
 */
def checkAllPodsHealth: UIO[Unit] =
  stateRef.get.flatMap { state =>
    ZIO
      .foreachParDiscard(state.pods.keySet)(address => notifyUnhealthyPod(address, ignoreMetric = true))
      .withParallelism(4)
  }
/**
 * Removes a pod from the manager state and frees every shard that was assigned to it.
 *
 * Nothing happens when the address is not registered. Otherwise metrics are updated, a `PodUnregistered`
 * event is published (followed by `ShardsUnassigned` when the pod owned shards), and both the persistence
 * of the pod list and an immediate rebalance are started in the background.
 *
 * @param podAddress address of the pod to remove
 */
def unregister(podAddress: PodAddress): UIO[Unit] = {
  val isRegistered = stateRef.get.map(_.pods.contains(podAddress))

  val removePod =
    for {
      _           <- ZIO.logInfo(s"Unregistering $podAddress")
      // atomically drop the pod and clear its shards, returning the shards that were freed
      freedShards <- stateRef.modify { current =>
                       val owned   = current.shards.collect {
                                       case (shardId, Some(owner)) if owner == podAddress => shardId
                                     }.toSet
                       val cleared = current.shards.map { case (shardId, owner) =>
                                       shardId -> (if (owner.contains(podAddress)) None else owner)
                                     }
                       (owned, current.copy(pods = current.pods - podAddress, shards = cleared))
                     }
      _           <- ManagerMetrics.pods.decrement
      _           <- ManagerMetrics.assignedShards.tagged("pod_address", podAddress.toString).decrementBy(freedShards.size)
      _           <- ManagerMetrics.unassignedShards.incrementBy(freedShards.size)
      _           <- eventsHub.publish(ShardingEvent.PodUnregistered(podAddress))
      _           <- eventsHub
                       .publish(ShardingEvent.ShardsUnassigned(podAddress, freedShards))
                       .when(freedShards.nonEmpty)
      _           <- persistPods.forkDaemon
      _           <- rebalance(rebalanceImmediately = true).forkDaemon
    } yield ()

  ZIO.whenZIO(isRegistered)(removePod).unit
}
/**
 * Computes and applies one round of shard unassignments and assignments. Only one rebalance runs at a time
 * (guarded by `rebalanceSemaphore`).
 *
 * Steps:
 *  1. decide what to move: unassigned shards are handled when `rebalanceImmediately` is set or when some
 *     shards are unassigned, otherwise unbalanced shards are moved according to `config.rebalanceRate`;
 *  2. ping every involved pod and drop the changes of the pods that did not answer within `config.pingTimeout`;
 *  3. apply the unassignments, then the assignments that do not depend on a failed unassignment;
 *  4. re-check the health of every pod that failed a step, schedule a retry after
 *     `config.rebalanceRetryInterval` when this was an immediate rebalance, and persist the assignments.
 *
 * @param rebalanceImmediately when true, only unassigned shards are handled, and a failure schedules a retry
 */
private def rebalance(rebalanceImmediately: Boolean): UIO[Unit] =
  rebalanceSemaphore.withPermit {
    for {
      state                                         <- stateRef.get
      // find which shards to assign and unassign
      (assignments, unassignments)                   = if (rebalanceImmediately || state.unassignedShards.nonEmpty)
                                                         decideAssignmentsForUnassignedShards(state)
                                                       else decideAssignmentsForUnbalancedShards(state, config.rebalanceRate)
      areChanges                                     = assignments.nonEmpty || unassignments.nonEmpty
      _                                             <- (ZIO.logDebug(s"Rebalancing (rebalanceImmediately=$rebalanceImmediately)") *>
                                                         ManagerMetrics.rebalances.increment).when(areChanges)
      // ping pods first to make sure they are ready and remove those who aren't
      failedPingedPods                              <- ZIO
                                                         .foreachPar(assignments.keySet ++ unassignments.keySet)(pod =>
                                                           podApi
                                                             .ping(pod)
                                                             .timeout(config.pingTimeout)
                                                             .someOrFailException
                                                             .fold(_ => Set(pod), _ => Set.empty[PodAddress])
                                                         )
                                                         .map(_.flatten)
      // every shard that involves an unreachable pod, either as source or as destination, is left untouched
      shardsToRemove                                 =
        assignments.collect { case (pod, shards) if failedPingedPods.contains(pod) => shards }.toSet.flatten ++
          unassignments.collect { case (pod, shards) if failedPingedPods.contains(pod) => shards }.toSet.flatten
      readyAssignments                               = assignments.view.mapValues(_ diff shardsToRemove).filterNot(_._2.isEmpty).toMap
      readyUnassignments                             = unassignments.view.mapValues(_ diff shardsToRemove).filterNot(_._2.isEmpty).toMap
      // do the unassignments first
      failed                                        <- ZIO
                                                         .foreachPar(readyUnassignments.toList) { case (pod, shards) =>
                                                           (podApi.unassignShards(pod, shards) *> updateShardsState(shards, None)).foldZIO(
                                                             _ => ZIO.succeed((Set(pod), shards)),
                                                             _ =>
                                                               ManagerMetrics.assignedShards.tagged("pod_address", pod.toString).decrementBy(shards.size) *>
                                                                 ManagerMetrics.unassignedShards.incrementBy(shards.size) *>
                                                                 eventsHub
                                                                   .publish(ShardingEvent.ShardsUnassigned(pod, shards))
                                                                   .as((Set.empty, Set.empty))
                                                           )
                                                         }
                                                         .map(_.unzip)
                                                         .map { case (pods, shards) => (pods.flatten[PodAddress].toSet, shards.flatten[ShardId].toSet) }
      (failedUnassignedPods, failedUnassignedShards) = failed
      // remove assignments of shards that couldn't be unassigned, as well as faulty pods;
      // pods left without any shard are dropped so that no empty assignment is sent or published
      filteredAssignments                            = (readyAssignments -- failedUnassignedPods)
                                                         .map { case (pod, shards) => pod -> (shards diff failedUnassignedShards) }
                                                         .filterNot(_._2.isEmpty)
      // then do the assignments
      failedAssignedPods                            <- ZIO
                                                         .foreachPar(filteredAssignments.toList) { case (pod, shards) =>
                                                           (podApi.assignShards(pod, shards) *> updateShardsState(shards, Some(pod))).foldZIO(
                                                             _ => ZIO.succeed(Set(pod)),
                                                             _ =>
                                                               ManagerMetrics.assignedShards
                                                                 .tagged("pod_address", pod.toString)
                                                                 .incrementBy(shards.size) *>
                                                                 ManagerMetrics.unassignedShards.decrementBy(shards.size) *>
                                                                 eventsHub.publish(ShardingEvent.ShardsAssigned(pod, shards)).as(Set.empty)
                                                           )
                                                         }
                                                         .map(_.flatten[PodAddress].toSet)
      failedPods                                     = failedPingedPods ++ failedUnassignedPods ++ failedAssignedPods
      // check if failing pods are still up
      _                                             <- ZIO.foreachDiscard(failedPods)(notifyUnhealthyPod(_)).forkDaemon
      _                                             <- ZIO.logWarning(s"Failed to rebalance pods: $failedPods").when(failedPods.nonEmpty)
      // retry rebalancing later if there was any failure
      _                                             <- (Clock.sleep(config.rebalanceRetryInterval) *> rebalance(rebalanceImmediately)).forkDaemon
                                                         .when(failedPods.nonEmpty && rebalanceImmediately)
      // persist state changes to the storage backend
      _                                             <- persistAssignments.forkDaemon.when(areChanges)
    } yield ()
  }
/**
 * Runs the given effect, retrying it on failure every `config.persistRetryInterval`, at most
 * `config.persistRetryCount` times. Both the result and any remaining failure are discarded.
 */
private def withRetry[E, A](zio: IO[E, A]): UIO[Unit] = {
  val retryPolicy = Schedule.spaced(config.persistRetryInterval) && Schedule.recurs(config.persistRetryCount)
  zio.retry[Any, Any](retryPolicy).ignore
}
/** Saves the current shard assignments to the storage, retrying on failure (see `withRetry`). */
private def persistAssignments: UIO[Unit] = {
  val saveCurrentAssignments =
    for {
      state <- stateRef.get
      _     <- stateRepository.saveAssignments(state.shards)
    } yield ()

  withRetry(saveCurrentAssignments)
}
/** Saves the currently registered pods (without their metadata) to the storage, retrying on failure. */
private def persistPods: UIO[Unit] = {
  val saveCurrentPods =
    for {
      state <- stateRef.get
      pods   = state.pods.map { case (address, withMetadata) => address -> withMetadata.pod }
      _     <- stateRepository.savePods(pods)
    } yield ()

  withRetry(saveCurrentPods)
}
/**
 * Points the given shards at `pod` (or marks them unassigned when `pod` is `None`) in the manager state.
 * Fails, leaving the state untouched, when the target pod is not registered anymore.
 *
 * @param shards the shards to update; shards unknown to the state are ignored
 * @param pod    the new owner of the shards, or `None` to unassign them
 */
private def updateShardsState(shards: Set[ShardId], pod: Option[PodAddress]): Task[Unit] =
  stateRef.updateZIO { state =>
    pod match {
      case Some(pod) if !state.pods.contains(pod) =>
        ZIO.fail(new Exception(s"Pod $pod is no longer registered"))
      case _                                      =>
        val reassigned = state.shards.map { case (shardId, owner) =>
          shardId -> (if (shards(shardId)) pod else owner)
        }
        ZIO.succeed(state.copy(shards = reassigned))
    }
  }
}
object ShardManager {
/**
 * A layer that starts the Shard Manager process.
 *
 * On startup it loads the pods and assignments saved in `Storage`, discards the pods that `PodsHealth`
 * reports as not alive along with the assignments that point to a pod that was not kept, and builds the
 * initial state for shards `1` to `config.numberOfShards`. It then initializes the metrics, registers a
 * finalizer that persists assignments and pods, and forks the background work: persisting the pods,
 * an initial rebalance, a periodic rebalance, logging of sharding events and a periodic pod health check.
 */
val live: ZLayer[PodsHealth with Pods with Storage with ManagerConfig, Throwable, ShardManager] =
ZLayer.scoped {
for {
config <- ZIO.service[ManagerConfig]
stateRepository <- ZIO.service[Storage]
healthApi <- ZIO.service[PodsHealth]
podApi <- ZIO.service[Pods]
// load the last persisted pods and shard assignments
pods <- stateRepository.getPods
assignments <- stateRepository.getAssignments
// remove unhealthy pods on startup
failedFilteredPods <-
ZIO.partitionPar(pods) { addrPod =>
ZIO.ifZIO(healthApi.isAlive(addrPod._1))(ZIO.succeed(addrPod), ZIO.fail(addrPod._2))
}
(failedPods, filtered) = failedFilteredPods
_ <- ZIO.when(failedPods.nonEmpty)(
ZIO.logInfo(s"Ignoring pods that are no longer alive ${failedPods.mkString("[", ", ", "]")}")
)
filteredPods = filtered.toMap
// keep only the assignments that point to a pod that is still alive; everything else goes to the left side
failedFilteredAssignments = partitionMap(assignments) {
case assignment @ (_, Some(address)) if filteredPods.contains(address) =>
Right(assignment)
case assignment => Left(assignment)
}
(failed, filteredAssignments) = failedFilteredAssignments
// shards that were already unassigned are not worth a warning
failedAssignments = failed.collect { case (shard, Some(addr)) => shard -> addr }
_ <- ZIO.when(failedAssignments.nonEmpty)(
ZIO.logWarning(
s"Ignoring assignments for pods that are no longer alive ${failedAssignments.mkString("[", ", ", "]")}"
)
)
cdt <- ZIO.succeed(OffsetDateTime.now())
// every shard from 1 to numberOfShards starts unassigned, then the recovered assignments are applied on top
initialState = ShardManagerState(
filteredPods.map { case (k, v) => k -> PodWithMetadata(v, cdt) },
(1 to config.numberOfShards).map(_ -> None).toMap ++ filteredAssignments
)
_ <- ZIO.logInfo(
s"Recovered pods ${filteredPods
.mkString("[", ", ", "]")} and assignments ${filteredAssignments.mkString("[", ", ", "]")}"
)
// initialize the metrics from the recovered state
_ <- ManagerMetrics.pods.incrementBy(initialState.pods.size)
_ <- ZIO.foreachDiscard(initialState.shards) { case (_, podAddressOpt) =>
podAddressOpt match {
case Some(podAddress) =>
ManagerMetrics.assignedShards.tagged("pod_address", podAddress.toString).increment
case None =>
ManagerMetrics.unassignedShards.increment
}
}
state <- Ref.Synchronized.make(initialState)
// a single permit: only one rebalance runs at a time
rebalanceSemaphore <- Semaphore.make(1)
eventsHub <- Hub.unbounded[ShardingEvent]
shardManager =
new ShardManager(state, rebalanceSemaphore, eventsHub, healthApi, podApi, stateRepository, config)
// persist assignments, then pods, when the layer's scope is closed
_ <- ZIO.addFinalizer {
shardManager.persistAssignments.catchAllCause(cause =>
ZIO.logWarningCause("Failed to persist assignments on shutdown", cause)
) *>
shardManager.persistPods.catchAllCause(cause =>
ZIO.logWarningCause("Failed to persist pods on shutdown", cause)
)
}
_ <- shardManager.persistPods.forkDaemon
// rebalance immediately if there are unassigned shards
_ <- shardManager.rebalance(rebalanceImmediately = initialState.unassignedShards.nonEmpty).forkDaemon
// start a regular rebalance at the given interval
_ <- shardManager
.rebalance(rebalanceImmediately = false)
.repeat(Schedule.spaced(config.rebalanceInterval))
.forkDaemon
// log every sharding event
_ <- shardManager.getShardingEvents.mapZIO(event => ZIO.logInfo(event.toString)).runDrain.forkDaemon
// check the health of all registered pods at the configured interval
_ <- shardManager.checkAllPodsHealth.repeat(Schedule.spaced(config.podHealthCheckInterval)).forkDaemon
_ <- ZIO.logInfo("Shard Manager loaded")
} yield shardManager
}
// reimplement Map.partitionMap because it does not exist in 2.12
// Splits `map` in two maps: entries for which `partition` returns a Left, and entries for which it returns a Right.
private def partitionMap[K, V, VL <: V, VR <: V](map: Map[K, V])(partition: ((K, V)) => Either[(K, VL), (K, VR)]) =
  map.foldLeft((Map.empty[K, VL], Map.empty[K, VR])) { case ((lefts, rights), entry) =>
    partition(entry) match {
      case Left(leftEntry)   => (lefts + leftEntry, rights)
      case Right(rightEntry) => (lefts, rights + rightEntry)
    }
  }
/**
 * Lexicographic ordering of lists: elements are compared pairwise with `ev` until one pair differs;
 * when one list is a prefix of the other, the shorter list comes first.
 */
implicit def listOrder[A](implicit ev: Ordering[A]): Ordering[List[A]] =
  new Ordering[List[A]] {
    def compare(xs: List[A], ys: List[A]): Int = {
      @tailrec def go(left: List[A], right: List[A]): Int =
        (left, right) match {
          case (Nil, Nil)                 => 0
          case (Nil, _)                   => -1
          case (_, Nil)                   => 1
          case (l :: lRest, r :: rRest) =>
            val result = ev.compare(l, r)
            if (result == 0) go(lRest, rRest) else result
        }

      // same instance: no need to walk the lists
      if (xs eq ys) 0 else go(xs, ys)
    }
  }
/**
 * The state of the shard manager.
 *
 * @param pods   the registered pods with their metadata, by address
 * @param shards for every shard, the pod it is assigned to, or `None` when it is unassigned
 */
case class ShardManagerState(pods: Map[PodAddress, PodWithMetadata], shards: Map[ShardId, Option[PodAddress]]) {

  /** The shards that are not assigned to any pod. */
  lazy val unassignedShards: Set[ShardId] = shards.collect { case (k, None) => k }.toSet

  /** The number of shards divided by the number of pods (integer division), or 0 when there is no pod. */
  // this is a count, not a shard id, hence `Int` rather than the `ShardId` alias
  lazy val averageShardsPerPod: Int = if (pods.nonEmpty) shards.size / pods.size else 0

  // version of every registered pod, as the list of its numeric components
  private lazy val podVersions: List[List[Int]] = pods.values.toList.map(extractVersion)

  /** The highest version among the registered pods, or `None` when there is no pod. */
  // version components are plain integers (see `extractVersion`), not shard ids
  lazy val maxVersion: Option[List[Int]] = podVersions.maxOption

  /** True when every registered pod runs the highest version (also true when there is no pod). */
  lazy val allPodsHaveMaxVersion: Boolean = podVersions.forall(maxVersion.contains)

  /** The shards assigned to each registered pod; pods without any shard map to an empty set. */
  lazy val shardsPerPod: Map[PodAddress, Set[ShardId]] =
    pods.map { case (k, _) => k -> Set.empty[ShardId] } ++
      shards.groupBy(_._2).collect { case (Some(address), shards) => address -> shards.keySet }
}
/**
 * A pod together with the time at which it was registered with (or recovered by) the shard manager.
 */
case class PodWithMetadata(pod: Pod, registered: OffsetDateTime)
/**
 * The events published by the shard manager on its events hub.
 */
sealed trait ShardingEvent
object ShardingEvent {
/** The given shards were assigned to the pod. */
case class ShardsAssigned(pod: PodAddress, shards: Set[ShardId]) extends ShardingEvent {
// render the shard ids with `renderShardIds` instead of the default Set rendering
override def toString: String = s"ShardsAssigned(pod=$pod, shards=${renderShardIds(shards)})"
}
/** The given shards were unassigned from the pod. */
case class ShardsUnassigned(pod: PodAddress, shards: Set[ShardId]) extends ShardingEvent {
// render the shard ids with `renderShardIds` instead of the default Set rendering
override def toString: String = s"ShardsUnassigned(pod=$pod, shards=${renderShardIds(shards)})"
}
/** The pod was registered. */
case class PodRegistered(pod: PodAddress) extends ShardingEvent
/** The pod was unregistered. */
case class PodUnregistered(pod: PodAddress) extends ShardingEvent
/** The health of the (registered) pod was checked. */
case class PodHealthChecked(pod: PodAddress) extends ShardingEvent
}
/**
 * Decides which pod should receive each shard that is currently unassigned.
 *
 * @param state the current state of the shard manager
 * @return the shards to assign per pod, and the shards to unassign per pod
 */
def decideAssignmentsForUnassignedShards(
  state: ShardManagerState
): (Map[PodAddress, Set[ShardId]], Map[PodAddress, Set[ShardId]]) = {
  val shardsWithoutPod = state.unassignedShards.toList
  pickNewPods(
    shardsToRebalance = shardsWithoutPod,
    state = state,
    rebalanceImmediately = true,
    rebalanceRate = 1.0
  )
}
/**
 * Decides which shards should move from the pods that hold more shards than the average to less loaded pods.
 *
 * For every pod, as many shards as it holds above `state.averageShardsPerPod` are picked at random as
 * candidates; `pickNewPods` then decides which of them actually move. Nothing is proposed while the
 * registered pods run different versions.
 *
 * @param state         the current state of the shard manager
 * @param rebalanceRate used by `pickNewPods` to limit how many shards a single pod may receive in one round
 * @return the shards to assign per pod, and the shards to unassign per pod
 */
def decideAssignmentsForUnbalancedShards(
  state: ShardManagerState,
  rebalanceRate: Double
): (Map[PodAddress, Set[ShardId]], Map[PodAddress, Set[ShardId]]) = {
  val extraShardsToAllocate: Set[ShardId] =
    if (state.allPodsHaveMaxVersion) { // don't do regular rebalance in the middle of a rolling update
      state.shardsPerPod.flatMap { case (_, shards) =>
        // count how many extra shards compared to the average
        val extraShards = (shards.size - state.averageShardsPerPod).max(0)
        // shuffle a List, not the Set itself: shuffling a Set rebuilds a Set, whose iteration order does not
        // depend on the shuffled order, so `take` would not pick random shards
        Random.shuffle(shards.toList).take(extraShards)
      }.toSet
    } else Set.empty
  val sortedShardsToRebalance = extraShardsToAllocate.toList.sortBy { shard =>
    // handle unassigned shards first, then shards on the pods with most shards, then shards on old pods
    state.shards.get(shard).flatten.fold((Int.MinValue, OffsetDateTime.MIN)) { pod =>
      (
        state.shardsPerPod.get(pod).fold(Int.MinValue)(-_.size),
        state.pods.get(pod).fold(OffsetDateTime.MIN)(_.registered)
      )
    }
  }
  pickNewPods(sortedShardsToRebalance, state, rebalanceImmediately = false, rebalanceRate)
}
/**
 * Picks a new pod for each of the given shards, processing them in order and keeping track of the shards
 * each pod would hold after the moves decided so far.
 *
 * @param shardsToRebalance    the candidate shards, in the order in which they should be handled
 * @param state                the current state of the shard manager
 * @param rebalanceImmediately when true, the per-pod limit derived from `rebalanceRate` is not applied
 * @param rebalanceRate        a pod may receive fewer than `state.shards.size * rebalanceRate` shards in one call
 * @return the shards to assign, grouped by their new pod, and the shards to unassign, grouped by the pod
 *         that owns them in `state`
 */
private def pickNewPods(
shardsToRebalance: List[ShardId],
state: ShardManagerState,
rebalanceImmediately: Boolean,
rebalanceRate: Double
): (Map[PodAddress, Set[ShardId]], Map[PodAddress, Set[ShardId]]) = {
// accumulator: (shards per pod after the moves decided so far, list of decided (shard, new pod) moves)
val (_, assignments) = shardsToRebalance.foldLeft((state.shardsPerPod, List.empty[(ShardId, PodAddress)])) {
case ((shardsPerPod, assignments), shard) =>
// pods that, according to `state`, own a shard that was already moved away in this call
val unassignedPods = assignments.flatMap { case (shard, _) =>
state.shards.get(shard).flatten[PodAddress]
}.toSet
// find pod with least amount of shards
shardsPerPod
// keep only pods with the max version
.filter { case (pod, _) =>
state.maxVersion.forall(max => state.pods.get(pod).map(extractVersion).forall(_ == max))
}
// don't assign too many shards to the same pods, unless we need rebalance immediately
.filter { case (pod, _) =>
rebalanceImmediately || assignments.count { case (_, p) => p == pod } < state.shards.size * rebalanceRate
}
// don't assign to a pod that was unassigned in the same rebalance
.filterNot { case (pod, _) => unassignedPods.contains(pod) }
.minByOption(_._2.size) match {
case Some((pod, shards)) =>
val oldPod = state.shards.get(shard).flatten
// if old pod is same as new pod, don't change anything
if (oldPod.contains(pod))
(shardsPerPod, assignments)
// if the new pod has more, as much, or only 1 less shard than the old pod, don't change anything
// (a shard without an old pod compares against Int.MaxValue, so this check does not hold it back)
else if (
shardsPerPod.get(pod).fold(0)(_.size) + 1 >= oldPod.fold(Int.MaxValue)(
shardsPerPod.getOrElse(_, Nil).size
)
)
(shardsPerPod, assignments)
// otherwise, create a new assignment
else {
// take the shard away from its old pod (if any) before giving it to the new one
val unassigned = oldPod.fold(shardsPerPod)(oldPod => shardsPerPod.updatedWith(oldPod)(_.map(_ - shard)))
(unassigned.updated(pod, shards + shard), (shard, pod) :: assignments)
}
// no eligible pod for this shard: leave it where it is
case None => (shardsPerPod, assignments)
}
}
// every moved shard that has an owner in `state` must first be unassigned from that owner
val unassignments = assignments.flatMap { case (shard, _) => state.shards.get(shard).flatten.map(shard -> _) }
val assignmentsPerPod = assignments.groupBy(_._2).map { case (k, v) => k -> v.map(_._1).toSet }
val unassignmentsPerPod = unassignments.groupBy(_._2).map { case (k, v) => k -> v.map(_._1).toSet }
(assignmentsPerPod, unassignmentsPerPod)
}
/**
 * Splits the pod's version string on `.` and `-` and keeps the components that are valid integers,
 * in their original order. Components that are not integers are dropped.
 */
private def extractVersion(pod: PodWithMetadata): List[Int] =
  for {
    component <- pod.pod.version.split("[.-]").toList
    number    <- component.toIntOption
  } yield number
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy