package com.twitter.finagle.stress

import com.google.caliper.{Param, SimpleBenchmark}
import com.twitter.app.App
import com.twitter.concurrent.{BridgedThreadPoolScheduler, Scheduler, ThreadPoolScheduler}
import com.twitter.conversions.time._
import com.twitter.finagle.Service
import com.twitter.finagle.builder.ClientBuilder
import com.twitter.finagle.httpx.{Http, Request, Response}
import com.twitter.finagle.stats.OstrichStatsReceiver
import com.twitter.ostrich.stats.{Stats => OstrichStats, StatsCollection}
import com.twitter.util.{CountDownLatch, Duration, Return, Stopwatch, Throw}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.ArrayBuffer

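/**
 * Command-line stress test for the client load balancer. Each scenario drives
 * 10*n HTTP requests (shared across the test's concurrent dispatch chains)
 * against three embedded servers, injecting a failure at request n: a server
 * going offline, an application or connection becoming nonresponsive, or a
 * protocol error. Flags: -n (requests per client), -l (artificial server
 * latency), -sched (local, threadpool, or bridged scheduler).
 */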
object LoadBalancerTest extends App {
  val nreqsFlag = flag("n", 100000, "Number of reqs sent from each client")
  val latencyFlag = flag("l", 0.seconds, "req latency forced at the server")
  val schedFlag = flag("sched", "local", "Use the specified scheduler")

  val totalRequests = new AtomicInteger(0)
  val clientBuilder = ClientBuilder()
    .requestTimeout(100.milliseconds)
    .retries(10)

  def main() {
    schedFlag() match {
      case "local" => // keep the default local scheduler
      case "threadpool" =>
        println("Using threadpool scheduler")
        Scheduler.setUnsafe(new ThreadPoolScheduler("FINAGLE"))
      case "bridged" =>
        println("Using the bridged threadpool scheduler")
        Scheduler.setUnsafe(new BridgedThreadPoolScheduler("FINAGLE"))
      case unknown =>
        println("Unknown scheduler " + unknown)
        System.exit(1)
    }

    runSuite()
  }

  def doTest(latency: Duration, nreqs: Int,
             behavior: PartialFunction[(Int, Seq[EmbeddedServer]), Unit]) {
    new LoadBalancerTest(clientBuilder, latency, nreqs)(behavior).run()
  }

  def runSuite() {
    val latency = latencyFlag()
    val n = nreqsFlag()
    val N = 10 * n // total requests per scenario; the failure cases below fire at request n

    println("testing " + clientBuilder)
    println("\n== baseline (warmup) ==\n")
    doTest(latency, N, { case _ => })

    println("\n== baseline ==\n")
    doTest(latency, N, { case _ => })

    println("\n== 1 server goes offline ==\n")
    doTest(latency, N, { case (`n`, servers) => servers(1).stop() })

    println("\n== 1 application becomes nonresponsive ==\n")
    doTest(latency, N,
      { case (`n`, servers) => servers(1).becomeApplicationNonresponsive() })

    println("\n== 1 connection becomes nonresponsive ==\n")
    doTest(latency, N,
      { case (`n`, servers) => servers(1).becomeConnectionNonresponsive() })

    println("\n== 1 server has a protocol error ==\n")
    doTest(latency, N,
      { case (`n`, servers) => servers(1).becomeBelligerent() })
  }
}


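/**
 * Caliper benchmark wrapper around [[LoadBalancerTest.doTest]]. Each timed
 * method repeats one scenario for `reps` iterations: a pure baseline, or a
 * run in which servers(1) is stopped, made nonresponsive at the application
 * or connection level, or made to emit protocol errors once a tenth of the
 * requests have been dispatched. `latencyInMilliSec` and `nreqs` are Caliper
 * parameters.
 */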
class LoadBalancerBenchmark extends SimpleBenchmark {
  @Param(Array("0")) val latencyInMilliSec: Long = 0
  @Param(Array("10000")) val nreqs: Int = 10000

  def timeBaseline(reps: Int) {
    var i = 0
    while (i < reps) {
      LoadBalancerTest.doTest(
        Duration.fromMilliseconds(latencyInMilliSec),
        nreqs,
        { case _ => })
      i += 1
    }
  }

  def timeOneOffline(reps: Int) {
    var i = 0
    while (i < reps) {
      LoadBalancerTest.doTest(
        Duration.fromMilliseconds(latencyInMilliSec),
        nreqs,
        { case (n, servers) if n == nreqs/10 => servers(1).stop() })
      i += 1
    }
  }

  def timeOneAppNonResponsive(reps: Int) {
    var i = 0
    while (i < reps) {
      LoadBalancerTest.doTest(
        Duration.fromMilliseconds(latencyInMilliSec),
        nreqs,
        { case (n, servers) if n == nreqs/10 => servers(1).becomeApplicationNonresponsive() })
      i += 1
    }
  }

  def timeOneConnNonResponsive(reps: Int) {
    var i = 0
    while (i < reps) {
      LoadBalancerTest.doTest(
        Duration.fromMilliseconds(latencyInMilliSec),
        nreqs,
        { case (n, servers) if n == nreqs/10 => servers(1).becomeConnectionNonresponsive() })
      i += 1
    }
  }

  def timeOneProtocolError(reps: Int) {
    var i = 0
    while (i < reps) {
      LoadBalancerTest.doTest(
        Duration.fromMilliseconds(latencyInMilliSec),
        nreqs,
        { case (n, servers) if n == nreqs/10 => servers(1).becomeBelligerent() })
      i += 1
    }
  }
}

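/**
 * A single stress-test pass. Builds an HTTP client over three
 * [[EmbeddedServer]] instances, keeps `concurrency` request chains in flight
 * until `numRequests` requests have completed, and applies `behavior` to
 * (request number, servers) whenever it is defined, so callers can inject
 * failures mid-run. Throughput, success rate, and per-server stats are
 * printed to stdout when the run finishes.
 */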
class LoadBalancerTest(
  clientBuilder: ClientBuilder[_, _, _, _, _],
  serverLatency: Duration = 0.seconds,
  numRequests: Int = 100000,
  concurrency: Int = 20)(behavior: PartialFunction[(Int, Seq[EmbeddedServer]), Unit])
{
  private[this] val requestNumber = new AtomicInteger(0)
  private[this] val requestCount  = new AtomicInteger(numRequests)
  private[this] val latch         = new CountDownLatch(concurrency)
  private[this] val stats         = new StatsCollection
  private[this] val gaugeValues   = new ArrayBuffer[(Int, Map[String, Float])]

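  // Issues one request and, once its response (or failure) arrives, records
  // latency/success stats and immediately issues the next request, so each of
  // the `concurrency` chains keeps exactly one request outstanding. The
  // failure-injection partial function `f` is applied first when it is
  // defined for the current request number.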
  private[this] def dispatch(
      client: Service[Request, Response],
      servers: Seq[EmbeddedServer],
      f: PartialFunction[(Int, Seq[EmbeddedServer]), Unit]) {
    val num = requestNumber.incrementAndGet()
    LoadBalancerTest.totalRequests.incrementAndGet()
    if (f.isDefinedAt((num, servers)))
      f((num, servers))

    val elapsed = Stopwatch.start()

    client(Request("/")).respond { result =>
      result match {
        case Return(_) =>
          val duration = elapsed()
          stats.addMetric("request_msec", duration.inMilliseconds.toInt)
          stats.addMetric("request_usec", duration.inMicroseconds.toInt)
          stats.incr("success")
        case Throw(exc) =>
          stats.incr("fail")
          stats.incr("fail_%s".format(exc.getClass.getName.split('.').last))
      }

      if (requestCount.decrementAndGet() > 0)
        dispatch(client, servers, f)
      else
        latch.countDown()
    }
  }

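  // Brings up three embedded servers with the configured latency, builds the
  // client against them, runs the dispatch chains to completion, and then
  // prints the throughput/success report, gauge samples (if any were
  // collected), and per-server plus global Ostrich stats before shutting
  // everything down.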
  def run() {
    OstrichStats.clearAll()

    val servers = (0 until 3) map { _ =>
      val server = EmbeddedServer()
      server.setLatency(serverLatency)
      server
    }

    val client = clientBuilder
      .codec(Http())
      .hosts(servers map(_.boundAddress))
      .hostConnectionLimit(Int.MaxValue)
      .reportTo(new OstrichStatsReceiver)
      .build()

    val elapsed = Stopwatch.start()
    0 until concurrency foreach { _ => dispatch(client, servers, behavior) }
    latch.await()
    val duration = elapsed()
    val rps = numRequests.toDouble / duration.inMilliseconds.toDouble * 1000

    // TODO: produce a structured "report" here instead of raw println output,
    // so the results carry semantic information.

    println("> STATS")
    println("> rps: %.2f".format(rps))
    val succ = stats.getCounter("success")().toDouble
    val fail = stats.getCounter("fail")().toDouble
    println("> success rate: %.2f".format(100.0 * succ / (succ + fail)))
    println("> request rate: %.2f".format(rps))
    Stats.prettyPrint(stats)

    val allGaugeNames = {
      val unique = Set() ++ gaugeValues flatMap { case (_, gvs) => gvs map (_._1) }
      unique.toList.sorted
    }

    println("> %5s %s".format("time", allGaugeNames map("%-8s".format(_)) mkString(" ")))

    gaugeValues foreach { case (requestNum, values) =>
      val columns = allGaugeNames map { name =>
        val value = values.get(name)
        val formatted = value.map("%.2e".format(_)).getOrElse("n/a")
        formatted
      }
      println("> %05d %s".format(requestNum, columns.map("%8s".format(_)).mkString(" ")))
    }

    servers.zipWithIndex foreach { case (server, which) =>
      server.stop()
      println("> SERVER[%d] (%s)".format(which, server.boundAddress))
      Stats.prettyPrint(server.stats)
    }

    println("> OSTRICH counters")
    Stats.prettyPrint(OstrichStats)

    client.close()
    servers foreach { _.stop() }
  }
}
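// A minimal invocation sketch. The sbt project name and the exact flag syntax
// are assumptions (twitter-util flags generally accept "-flag value" and
// "-flag=value"), not something this file specifies:
//
//   ./sbt 'project finagle-stress' \
//     'run-main com.twitter.finagle.stress.LoadBalancerTest -n=50000 -l=5.milliseconds -sched=threadpool'
//
// LoadBalancerBenchmark, by contrast, is meant to be driven by Caliper's
// benchmark runner rather than invoked directly.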



