in.ashwanthkumar.suuchi.router.ReplicationRouter.scala

package in.ashwanthkumar.suuchi.router

import java.util.concurrent.{Executor, Executors}

import com.google.common.util.concurrent.{Futures, ListenableFuture}
import in.ashwanthkumar.suuchi.membership.MemberAddress
import io.grpc.ServerCall.Listener
import io.grpc._
import io.grpc.netty.NettyChannelBuilder
import io.grpc.stub.{ClientCalls, MetadataUtils}
import org.slf4j.LoggerFactory
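
// ---------------------------------------------------------------------------
// Hedged sketch, not part of this file: the real `Headers` object referenced
// by ReplicationRouter below lives elsewhere in suuchi. The key names and the
// marshaller wire format here are assumptions; only the key types are implied
// by usage (REPLICATION_REQUEST_KEY carries a String, ELIGIBLE_NODES_KEY a
// List[MemberAddress]). It also assumes MemberAddress(host, port) exists.
// ---------------------------------------------------------------------------
object HeadersSketch {
  val REPLICATION_REQUEST_KEY: Metadata.Key[String] =
    Metadata.Key.of("suuchi-replication-request", Metadata.ASCII_STRING_MARSHALLER)

  // Hypothetical wire format: comma-separated host:port pairs
  val ELIGIBLE_NODES_KEY: Metadata.Key[List[MemberAddress]] =
    Metadata.Key.of("suuchi-eligible-nodes", new Metadata.AsciiMarshaller[List[MemberAddress]] {
      override def toAsciiString(nodes: List[MemberAddress]): String =
        nodes.map(n => s"${n.host}:${n.port}").mkString(",")
      override def parseAsciiString(serialized: String): List[MemberAddress] =
        serialized.split(",").toList.map { part =>
          val Array(host, port) = part.split(":")
          MemberAddress(host, port.toInt)
        }
    })
}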


/**
 * Replication Router picks the set of nodes to which a request needs to be sent (if not
 * already set), forwards the request to those nodes, sequentially or in parallel depending
 * on the subclass, and waits for all of them to complete.
 */
abstract class ReplicationRouter(nrReplicas: Int, self: MemberAddress) extends ServerInterceptor { me =>

  protected val log = LoggerFactory.getLogger(me.getClass)

  override def interceptCall[ReqT, RespT](serverCall: ServerCall[ReqT, RespT], headers: Metadata, next: ServerCallHandler[ReqT, RespT]): Listener[ReqT] = {
    log.trace("Intercepting " + serverCall.getMethodDescriptor.getFullMethodName + " method in " + self)
    val replicator = this
    new Listener[ReqT] {
      var forwarded = false
      val delegate = next.startCall(serverCall, headers)

      override def onReady(): Unit = delegate.onReady()
      override def onMessage(incomingRequest: ReqT): Unit = {
        log.trace("onMessage in replicator")
        if (headers.containsKey(Headers.REPLICATION_REQUEST_KEY) && headers.get(Headers.REPLICATION_REQUEST_KEY).equals(self.toString)) {
          log.info("Received replication request for {}, processing it", incomingRequest)
          delegate.onMessage(incomingRequest)
        } else if (headers.containsKey(Headers.ELIGIBLE_NODES_KEY)) {
          // this isn't a replication request, so replicate it to the list of nodes given in the ELIGIBLE_NODES header
          val nodes = headers.get(Headers.ELIGIBLE_NODES_KEY)
          // If the nodes to replicate to include self, leave forwarded as false; in all other cases it is true.
          // We need this because the default ServerCallHandler that the actual delegate is wrapped in
          // invokes the method only in onHalfClose and not in onMessage (for non-streaming requests).
          forwarded = !nodes.contains(self)
          log.trace("Going to replicate the request to {}", nodes)
          replicator.replicate(nodes, serverCall, headers, incomingRequest, delegate)
          log.trace("Replication complete for {}", incomingRequest)
        } else {
          log.trace("Ignoring the request since I don't know what to do")
        }
      }

      override def onHalfClose(): Unit = {
        // The default ServerCall listener apparently holds state from onMessage, so delegating
        // onHalfClose after we've forwarded fails the client with "Half-closed without a request"
        if (forwarded) serverCall.close(Status.OK, headers) else delegate.onHalfClose()
      }
      override def onCancel(): Unit = delegate.onCancel()
      override def onComplete(): Unit = delegate.onComplete()
    }
  }

  def forward[RespT, ReqT](methodDescriptor: MethodDescriptor[ReqT, RespT], headers: Metadata, incomingRequest: ReqT, destination: MemberAddress): RespT = {
    // Add HEADER to signify that this is a REPLICATION_REQUEST
    headers.put(Headers.REPLICATION_REQUEST_KEY, destination.toString)
    val nettyChannel = NettyChannelBuilder.forAddress(destination.host, destination.port).usePlaintext(true).build()

    val clientResponse = ClientCalls.blockingUnaryCall(
      ClientInterceptors.interceptForward(nettyChannel, MetadataUtils.newAttachHeadersInterceptor(headers)),
      methodDescriptor,
      CallOptions.DEFAULT,
      incomingRequest)

    nettyChannel.shutdown()
    clientResponse
  }

  def forwardAsync[RespT, ReqT](methodDescriptor: MethodDescriptor[ReqT, RespT], headers: Metadata,
                                incomingRequest: ReqT,
                                destination: MemberAddress)(implicit executor: Executor): ListenableFuture[RespT] = {
    // Add HEADER to signify that this is a REPLICATION_REQUEST
    headers.put(Headers.REPLICATION_REQUEST_KEY, destination.toString)
    val nettyChannel = NettyChannelBuilder.forAddress(destination.host, destination.port).usePlaintext(true).build()
    val clientCall = ClientInterceptors.interceptForward(nettyChannel, MetadataUtils.newAttachHeadersInterceptor(headers))
      .newCall(methodDescriptor, CallOptions.DEFAULT)
    val clientResponse = ClientCalls.futureUnaryCall(clientCall, incomingRequest)
    clientResponse.addListener(new Runnable {
      override def run(): Unit = nettyChannel.shutdown()
    }, executor)
    clientResponse
  }

  /**
   * Subclasses can choose how they want to replicate.
   *
   * See [[SequentialReplicator]] for usage.
   */
  def replicate[ReqT, RespT](eligibleNodes: List[MemberAddress], serverCall: ServerCall[ReqT, RespT], headers: Metadata, incomingRequest: ReqT, delegate: ServerCall.Listener[ReqT]): Unit = {
    eligibleNodes match {
      // match Nil before the size guard below; the guard would otherwise swallow the empty list
      case Nil =>
        log.error("This should never happen. No nodes found to place replica")
        serverCall.close(Status.INTERNAL, headers)
      case nodes if nodes.size < nrReplicas =>
        log.warn("We don't have enough nodes to satisfy the replication factor. Not processing this request")
        serverCall.close(Status.FAILED_PRECONDITION, headers)
      case nodes =>
        log.info("Replication nodes for {} are {}", incomingRequest, nodes)
        doReplication(nodes, serverCall, headers, incomingRequest, delegate)
    }
  }

  /**
   * Implement the actual replication logic, assuming you have the right set of nodes.
   * Just do it!
   *
   * Error handling and other scenarios are handled in [[ReplicationRouter.replicate]].
   */
  def doReplication[ReqT, RespT](eligibleNodes: List[MemberAddress], serverCall: ServerCall[ReqT, RespT], headers: Metadata, incomingRequest: ReqT, delegate: Listener[ReqT]): Unit
}
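
// ---------------------------------------------------------------------------
// Illustrative sketch, an addition rather than part of suuchi: the smallest
// meaningful doReplication override. It skips remote forwarding entirely and
// hands the message to the local delegate only when self is among the eligible
// nodes, which makes the extension contract of doReplication explicit.
// ---------------------------------------------------------------------------
class LocalOnlyReplicatorSketch(nrReplicas: Int, self: MemberAddress) extends ReplicationRouter(nrReplicas, self) {
  override def doReplication[ReqT, RespT](eligibleNodes: List[MemberAddress], serverCall: ServerCall[ReqT, RespT], headers: Metadata, incomingRequest: ReqT, delegate: Listener[ReqT]): Unit = {
    // No remote calls: only the local member processes the write
    if (eligibleNodes.contains(self)) delegate.onMessage(incomingRequest)
  }
}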

class SequentialReplicator(nrReplicas: Int, self: MemberAddress) extends ReplicationRouter(nrReplicas, self) {
  override def doReplication[ReqT, RespT](eligibleNodes: List[MemberAddress], serverCall: ServerCall[ReqT, RespT], headers: Metadata, incomingRequest: ReqT, delegate: Listener[ReqT]): Unit = {
    log.debug("Sequentially sending out replication requests to the above set of nodes")

    val hasLocalMember = eligibleNodes.exists(_.equals(self))

    eligibleNodes.filterNot(_.equals(self)).foreach { destination =>
      forward(serverCall.getMethodDescriptor, headers, incomingRequest, destination)
    }

    // Invoke the local delegate only after the forwards complete; otherwise we'd return to
    // the client saying we're done before the replicas have the data
    if (hasLocalMember) {
      delegate.onMessage(incomingRequest)
    }
  }
}

object ParallelReplicator {
  // Shared fixed-size pool on which the channel-shutdown callbacks registered in forwardAsync run
  implicit val PARALLEL_REPLICATION_EXECUTOR: Executor = Executors.newFixedThreadPool(3)
}
class ParallelReplicator(nrReplicas: Int, self: MemberAddress) extends ReplicationRouter(nrReplicas, self) {
  import ParallelReplicator._
  override def doReplication[ReqT, RespT](eligibleNodes: List[MemberAddress], serverCall: ServerCall[ReqT, RespT], headers: Metadata, incomingRequest: ReqT, delegate: Listener[ReqT]): Unit = {
    log.debug("Sending out replication requests to the above set of nodes in parallel")

    val hasLocalMember = eligibleNodes.exists(_.equals(self))

    val replicationResponses = eligibleNodes.filterNot(_.equals(self)).map { destination =>
      forwardAsync(serverCall.getMethodDescriptor, headers, incomingRequest, destination)
    }

    // Future.sequence equivalent + doing a get to ensure all operations complete
    log.debug("Waiting for replication response from replica nodes")
    Futures.allAsList(replicationResponses: _*).get()

    // Invoke the local delegate only after the parallel forwards complete; otherwise we'd
    // return to the client saying we're done before the replicas have the data
    if (hasLocalMember) {
      delegate.onMessage(incomingRequest)
    }
  }
}
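
// ---------------------------------------------------------------------------
// Illustrative wiring sketch, an addition rather than part of suuchi: a
// replicator attaches to a server like any other gRPC ServerInterceptor.
// `service` stands for any generated BindableService; NettyServerBuilder and
// ServerInterceptors are standard grpc-java API, and the replication factor
// of 2 is an arbitrary example value.
// ---------------------------------------------------------------------------
object ReplicatorWiringSketch {
  import io.grpc.netty.NettyServerBuilder

  def serve(service: BindableService, self: MemberAddress, nrReplicas: Int = 2): Server = {
    // Wrap the service definition so every incoming call passes through the replicator first
    val replicated = ServerInterceptors.intercept(service.bindService(), new SequentialReplicator(nrReplicas, self))
    NettyServerBuilder.forPort(self.port).addService(replicated).build().start()
  }
}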