
// in.ashwanthkumar.suuchi.router.ReplicationRouter.scala
package in.ashwanthkumar.suuchi.router

import java.util.concurrent.TimeUnit

import com.google.common.util.concurrent.{Futures, ListenableFuture}
import in.ashwanthkumar.suuchi.cluster.MemberAddress
import in.ashwanthkumar.suuchi.rpc.CachedChannelPool
import io.grpc.ServerCall.Listener
import io.grpc._
import io.grpc.stub.{ClientCalls, MetadataUtils}
import org.slf4j.LoggerFactory
/**
 * Replication Router picks the set of nodes to which this request needs to be sent (if not
 * already set), forwards the request to that list of nodes and waits for all of them to
 * complete.
 */
abstract class ReplicationRouter(nrReplicas: Int, self: MemberAddress) extends ServerInterceptor {
  me =>

  protected val log = LoggerFactory.getLogger(me.getClass)
  val channelPool = CachedChannelPool()

  override def interceptCall[ReqT, RespT](serverCall: ServerCall[ReqT, RespT],
                                          headers: Metadata,
                                          next: ServerCallHandler[ReqT, RespT]): Listener[ReqT] = {
    log.trace(
      "Intercepting " + serverCall.getMethodDescriptor.getFullMethodName + " method in " + self)
    val replicator = this

    new Listener[ReqT] {
      var forwarded = false
      val delegate = next.startCall(serverCall, headers)

      override def onReady(): Unit =
        wrapWithContextValue(Headers.PRIMARY_NODE_REQUEST_CTX, isPrimaryNode(headers)) {
          delegate.onReady()
        }
      override def onMessage(incomingRequest: ReqT): Unit = {
        log.trace("onMessage in replicator")
        wrapWithContextValue(Headers.PRIMARY_NODE_REQUEST_CTX, isPrimaryNode(headers)) {
          if (isReplicationRequest(headers)) {
            log.debug("Received replication request for {}, processing it", incomingRequest)
            delegate.onMessage(incomingRequest)
          } else if (headers.containsKey(Headers.ELIGIBLE_NODES_KEY)) {
            // this isn't a replication request, so replicate it to the list of nodes
            // defined in the ELIGIBLE_NODES header
            val nodes = headers.get(Headers.ELIGIBLE_NODES_KEY)
            // forwarded is false only when the eligible nodes contain self; in every other
            // case it is true. We need this because the default ServerCallHandler that wraps
            // the actual delegate invokes the method only in onHalfClose and not in onMessage
            // (for non-streaming requests).
            forwarded = !nodes.contains(self)
            log.trace("Going to replicate the request to {}", nodes)
            replicator.replicate(nodes, serverCall, headers, incomingRequest, delegate)
            log.trace("Replication complete for {}", incomingRequest)
          } else {
            log.warn("Ignoring the request since I don't know what to do")
          }
        }
      }
      override def onHalfClose(): Unit = {
        wrapWithContextValue(Headers.PRIMARY_NODE_REQUEST_CTX, isPrimaryNode(headers)) {
          // the default ServerCall listener holds state from onMessage, so calling
          // onHalfClose on it after we've forwarded the request makes the client fail
          // with "Half-closed without a request"; close the call ourselves instead
          if (forwarded) serverCall.close(Status.OK, headers) else delegate.onHalfClose()
        }
      }
      override def onCancel(): Unit =
        wrapWithContextValue(Headers.PRIMARY_NODE_REQUEST_CTX, isPrimaryNode(headers)) {
          delegate.onCancel()
        }

      override def onComplete(): Unit =
        wrapWithContextValue(Headers.PRIMARY_NODE_REQUEST_CTX, isPrimaryNode(headers)) {
          delegate.onComplete()
        }
    }
  }
  private def isPrimaryNode[RespT, ReqT](headers: Metadata) = {
    headers.containsKey(Headers.PRIMARY_NODE_KEY) && headers
      .get(Headers.PRIMARY_NODE_KEY)
      .equals(self)
  }
  private def wrapWithContextValue[T](ctxKey: Context.Key[T], value: T)(block: => Unit) = {
    val previous = Context.current().withValue(ctxKey, value).attach()
    try {
      block
    } finally {
      Context.current().detach(previous)
    }
  }
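
  // Downstream handlers can read the flag set above through the gRPC Context API; a
  // minimal sketch, assuming the handler runs within the listener callbacks wrapped above:
  //   val isPrimary = Headers.PRIMARY_NODE_REQUEST_CTX.get() // inside the service method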
  private def isReplicationRequest[RespT, ReqT](headers: Metadata) = {
    headers.containsKey(Headers.REPLICATION_REQUEST_KEY) && headers
      .get(Headers.REPLICATION_REQUEST_KEY)
      .equals(self.toString)
  }
  def forward[RespT, ReqT](methodDescriptor: MethodDescriptor[ReqT, RespT],
                           headers: Metadata,
                           incomingRequest: ReqT,
                           destination: MemberAddress): Any = {
    // Add HEADER to signify that this is a REPLICATION_REQUEST
    headers.put(Headers.REPLICATION_REQUEST_KEY, destination.toString)
    val channel = channelPool.get(destination, insecure = true)
    ClientCalls.blockingUnaryCall(
      ClientInterceptors.interceptForward(channel,
                                          MetadataUtils.newAttachHeadersInterceptor(headers)),
      methodDescriptor,
      CallOptions.DEFAULT
        .withDeadlineAfter(10, TimeUnit.MINUTES), // TODO (ashwanthkumar): Make this deadline configurable
      incomingRequest
    )
  }
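
  // Note: the REPLICATION_REQUEST_KEY header set above is what makes the receiving node's
  // interceptor take the isReplicationRequest branch in onMessage, so a forwarded request
  // is processed locally on that node instead of being replicated again.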
  def forwardAsync[RespT, ReqT](methodDescriptor: MethodDescriptor[ReqT, RespT],
                                headers: Metadata,
                                incomingRequest: ReqT,
                                destination: MemberAddress): ListenableFuture[RespT] = {
    // Add HEADER to signify that this is a REPLICATION_REQUEST
    headers.put(Headers.REPLICATION_REQUEST_KEY, destination.toString)
    val channel = channelPool.get(destination, insecure = true)
    val clientCall = ClientInterceptors
      .interceptForward(channel, MetadataUtils.newAttachHeadersInterceptor(headers))
      .newCall(methodDescriptor, CallOptions.DEFAULT.withDeadlineAfter(10, TimeUnit.MINUTES)) // TODO (ashwanthkumar): Make this deadline configurable
    ClientCalls.futureUnaryCall(clientCall, incomingRequest)
  }
  /**
   * Subclasses choose how they want to replicate.
   *
   * See [[SequentialReplicator]] for an example.
   */
  def replicate[ReqT, RespT](eligibleNodes: List[MemberAddress],
                             serverCall: ServerCall[ReqT, RespT],
                             headers: Metadata,
                             incomingRequest: ReqT,
                             delegate: ServerCall.Listener[ReqT]): Unit = {
    eligibleNodes match {
      case nodes if nodes.size < nrReplicas =>
        log.warn(
          "We don't have enough nodes to satisfy the replication factor. Not processing this request")
        serverCall.close(
          Status.FAILED_PRECONDITION.withDescription(
            "We don't have enough nodes to satisfy the replication factor. Not processing this request"),
          headers)
      case nodes if nodes.nonEmpty =>
        log.debug("Replication nodes for {} are {}", incomingRequest, nodes)
        doReplication(eligibleNodes, serverCall, headers, incomingRequest, delegate)
      case Nil =>
        log.error("This should never happen. No nodes found to place replica")
        serverCall.close(Status.INTERNAL.withDescription(
                           "This should never happen. No nodes found to place replica"),
                         headers)
    }
  }
  /**
   * Implement the actual replication logic, assuming you have the right set of nodes.
   * Just do it!
   *
   * Error handling and other scenarios are handled at [[ReplicationRouter.replicate]].
   **/
  def doReplication[ReqT, RespT](eligibleNodes: List[MemberAddress],
                                 serverCall: ServerCall[ReqT, RespT],
                                 headers: Metadata,
                                 incomingRequest: ReqT,
                                 delegate: Listener[ReqT]): Unit
}
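
/**
 * A minimal sketch of a custom replication strategy built on the contract above: forward to
 * every remote node and hand the message to the local delegate when `self` is eligible, but
 * tolerate individual replica failures instead of failing the whole call. The class name
 * (`BestEffortReplicator`) is hypothetical and only for illustration; it is not part of Suuchi.
 */
class BestEffortReplicator(nrReplicas: Int, self: MemberAddress)
    extends ReplicationRouter(nrReplicas, self) {

  override def doReplication[ReqT, RespT](eligibleNodes: List[MemberAddress],
                                          serverCall: ServerCall[ReqT, RespT],
                                          headers: Metadata,
                                          incomingRequest: ReqT,
                                          delegate: Listener[ReqT]): Unit = {
    eligibleNodes.filterNot(_.equals(self)).foreach { destination =>
      // best effort: log and move on if replication to a single node fails
      try {
        forward(serverCall.getMethodDescriptor, headers, incomingRequest, destination)
      } catch {
        case e: StatusRuntimeException =>
          log.warn("Replication to " + destination + " failed", e)
      }
    }
    if (eligibleNodes.exists(_.equals(self))) {
      delegate.onMessage(incomingRequest)
    }
  }
}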
/**
 * Sequential synchronous replication implementation. While replicating, we issue a forward
 * request to each of the candidate nodes one by one.
 *
 * @param nrReplicas Number of replicas to keep for the requests
 * @param self       Reference to the [[MemberAddress]] of the current node
 */
class SequentialReplicator(nrReplicas: Int, self: MemberAddress)
    extends ReplicationRouter(nrReplicas, self) {

  override def doReplication[ReqT, RespT](eligibleNodes: List[MemberAddress],
                                          serverCall: ServerCall[ReqT, RespT],
                                          headers: Metadata,
                                          incomingRequest: ReqT,
                                          delegate: Listener[ReqT]) = {
    log.debug("Sequentially sending out replication requests to the above set of nodes")
    val hasLocalMember = eligibleNodes.exists(_.equals(self))
    eligibleNodes.filterNot(_.equals(self)).foreach { destination =>
      forward(serverCall.getMethodDescriptor, headers, incomingRequest, destination)
    }
    // we need to do this after the forwarding, else we'd return to the client immediately
    // saying we're done
    if (hasLocalMember) {
      delegate.onMessage(incomingRequest)
    }
  }
}
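
/**
 * A minimal sketch of wiring the replicator into a gRPC server. `MyKVService` stands in for
 * any generated gRPC service implementation (an [[io.grpc.BindableService]]) and the
 * host/port values are made up; in a Suuchi deployment a routing interceptor is expected to
 * populate the ELIGIBLE_NODES header before this interceptor sees the call.
 */
object SequentialReplicationServerSketch {
  def main(args: Array[String]): Unit = {
    val self    = MemberAddress("localhost", 5051) // assumes MemberAddress(host, port)
    val service = new MyKVService                  // hypothetical BindableService
    val server = ServerBuilder
      .forPort(5051)
      .addService(ServerInterceptors.intercept(service, new SequentialReplicator(2, self)))
      .build()
      .start()
    server.awaitTermination()
  }
}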
/**
 * Parallel synchronous replication implementation. While replicating, we issue a forward
 * request to all the nodes in parallel. If even one node's request fails, the entire
 * operation is considered to have failed.
 *
 * @param nrReplicas Number of replicas to make
 * @param self       Reference to the [[MemberAddress]] of the current node
 */
class ParallelReplicator(nrReplicas: Int, self: MemberAddress)
    extends ReplicationRouter(nrReplicas, self) {

  override def doReplication[ReqT, RespT](eligibleNodes: List[MemberAddress],
                                          serverCall: ServerCall[ReqT, RespT],
                                          headers: Metadata,
                                          incomingRequest: ReqT,
                                          delegate: Listener[ReqT]): Unit = {
    log.debug("Sending out replication requests to the above set of nodes in parallel")
    val hasLocalMember = eligibleNodes.exists(_.equals(self))
    val replicationResponses = eligibleNodes.filterNot(_.equals(self)).map { destination =>
      forwardAsync(serverCall.getMethodDescriptor, headers, incomingRequest, destination)
    }
    // Future.sequence equivalent, plus a get() to ensure all operations complete
    log.debug("Waiting for replication response from replica nodes")
    Futures.allAsList(replicationResponses: _*).get()
    // we need to do this after the forwarding, else we'd return to the client immediately
    // saying we're done
    if (hasLocalMember) {
      delegate.onMessage(incomingRequest)
    }
  }
}
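
// The parallel strategy drops into the same wiring as the server sketch above, e.g.
//   ServerInterceptors.intercept(service, new ParallelReplicator(2, self))
// Since Futures.allAsList(...).get() rethrows the failure of any single forwarded call,
// one failed replica fails the whole client call, unlike the best-effort sketch earlier.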