All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.redislabs.provider.redis.RedisConfig.scala Maven / Gradle / Ivy

package com.redislabs.provider.redis

import java.net.URI

import org.apache.spark.SparkConf
import redis.clients.jedis.util.{JedisClusterCRC16, JedisURIHelper, SafeEncoder}
import redis.clients.jedis.{Jedis, Protocol}

import scala.collection.JavaConversions._


/**
  * RedisEndpoint represents a redis connection endpoint info: host, port, auth password,
  * db number, and timeout
  *
  * @param host    the redis host or ip
  * @param port    the redis port
  * @param auth    the authentication password (null when none is configured)
  * @param dbNum   database number (should be avoided in general)
  * @param timeout timeout setting carried with the endpoint; defaults to Protocol.DEFAULT_TIMEOUT
  */
case class RedisEndpoint(host: String = Protocol.DEFAULT_HOST,
                         port: Int = Protocol.DEFAULT_PORT,
                         auth: String = null,
                         dbNum: Int = Protocol.DEFAULT_DATABASE,
                         timeout: Int = Protocol.DEFAULT_TIMEOUT)
  extends Serializable {

  /**
    * Constructor from spark config. set params with spark.redis.host, spark.redis.port,
    * spark.redis.auth, spark.redis.db and spark.redis.timeout. Any key that is not set
    * falls back to the same default as the primary constructor.
    *
    * @param conf spark context config
    */
  def this(conf: SparkConf) = this(
    conf.get("spark.redis.host", Protocol.DEFAULT_HOST),
    conf.getInt("spark.redis.port", Protocol.DEFAULT_PORT),
    conf.get("spark.redis.auth", null),
    conf.getInt("spark.redis.db", Protocol.DEFAULT_DATABASE),
    conf.getInt("spark.redis.timeout", Protocol.DEFAULT_TIMEOUT)
  )

  /**
    * Constructor with Jedis URI. The timeout is not part of the URI and keeps its default.
    *
    * @param uri connection URI in the form of redis://:$password@$host:$port/[dbnum]
    */
  def this(uri: URI) = this(
    uri.getHost,
    // java.net.URI reports -1 when the URI carries no explicit port; use the
    // default redis port in that case instead of propagating an unusable -1.
    if (uri.getPort == -1) Protocol.DEFAULT_PORT else uri.getPort,
    JedisURIHelper.getPassword(uri),
    JedisURIHelper.getDBIndex(uri)
  )

  /**
    * Constructor with Jedis URI from String
    *
    * @param uri connection URI in the form of redis://:$password@$host:$port/[dbnum]
    */
  def this(uri: String) = this(URI.create(uri))

  /**
    * Obtains a connection to this endpoint by delegating to ConnectionPool.connect.
    * NOTE(review): authentication and db selection are presumably performed by
    * ConnectionPool, which is not visible here; confirm.
    *
    * @return a Jedis instance for this endpoint
    */
  def connect(): Jedis = ConnectionPool.connect(this)

  /**
    * @return a copy of this endpoint whose auth is replaced by the empty string.
    *         Used for logging.
    */
  def maskPassword(): RedisEndpoint = copy(auth = "")
}

/**
  * A redis server together with the slot range it serves.
  *
  * @param endpoint  connection info of the server
  * @param startSlot first slot of the range (inclusive)
  * @param endSlot   last slot of the range (inclusive)
  * @param idx       position of this server among the servers sharing the range;
  *                  RedisConfig filters on idx == 0 to select masters
  * @param total     number of servers sharing the range
  */
case class RedisNode(endpoint: RedisEndpoint,
                     startSlot: Int,
                     endSlot: Int,
                     idx: Int,
                     total: Int) {

  /**
    * @return a Jedis connection obtained from this node's endpoint
    */
  def connect(): Jedis = endpoint.connect()
}

/**
  * Tuning options for read and write operations.
  *
  * @param scanCount       count option of SCAN command
  * @param maxPipelineSize maximum number of commands per pipeline
  */
case class ReadWriteConfig(scanCount: Int, maxPipelineSize: Int)

object ReadWriteConfig {
  /** count option of SCAN command **/
  val ScanCountConfKey = "spark.redis.scan.count"
  val ScanCountDefault = 100

  /** maximum number of commands per pipeline **/
  val MaxPipelineSizeConfKey = "spark.redis.max.pipeline.size"
  val MaxPipelineSizeDefault = 100

  /** configuration made of the two default values above **/
  val Default: ReadWriteConfig =
    ReadWriteConfig(scanCount = ScanCountDefault, maxPipelineSize = MaxPipelineSizeDefault)

  /**
    * Reads both tuning options from the spark config, using the defaults of this
    * object for any key that is not set.
    *
    * @param conf spark context config
    * @return the resulting read/write configuration
    */
  def fromSparkConf(conf: SparkConf): ReadWriteConfig = {
    val scanCount = conf.getInt(ScanCountConfKey, ScanCountDefault)
    val maxPipelineSize = conf.getInt(MaxPipelineSizeConfKey, MaxPipelineSizeDefault)
    ReadWriteConfig(scanCount, maxPipelineSize)
  }
}

object RedisConfig {

  /**
    * Builds a RedisConfig whose initial host is the endpoint described by the
    * spark.redis.* entries of the given spark config.
    *
    * @param conf spark context config
    * @return redis config for that endpoint
    */
  def fromSparkConf(conf: SparkConf): RedisConfig = {
    val endpoint = new RedisEndpoint(conf)
    new RedisConfig(endpoint)
  }
}

/**
  * RedisConfig holds the state of the cluster nodes, and maps keys to nodes through
  * their hash slot (see JedisClusterCRC16.getSlot).
  *
  * Constructing an instance connects to `initialHost` in order to discover the nodes.
  *
  * @param initialHost any redis endpoint of a cluster or a single server
  */
class RedisConfig(val initialHost: RedisEndpoint) extends Serializable {

  val initialAddr: String = initialHost.host

  /** nodes with idx == 0, i.e. the masters */
  val hosts: Array[RedisNode] = getHosts(initialHost)
  /** every discovered node, masters and slaves */
  val nodes: Array[RedisNode] = getNodes(initialHost)

  /**
    * @return initialHost's auth
    */
  def getAuth: String = initialHost.auth

  /**
    * @return selected db number
    */
  def getDB: Int = initialHost.dbNum

  /**
    * @return a randomly chosen element of `hosts`
    */
  def getRandomNode(): RedisNode = {
    // A bounded nextInt is always in [0, hosts.length). Taking `.abs` of an unbounded
    // nextInt() is unsafe because Int.MinValue.abs is still negative, which would
    // produce a negative index.
    hosts(scala.util.Random.nextInt(hosts.length))
  }

  /**
    * @param sPos start slot number
    * @param ePos end slot number
    * @return the master nodes (idx == 0) whose slot range intersects [sPos, ePos]
    */
  def getNodesBySlots(sPos: Int, ePos: Int): Array[RedisNode] = {
    /* true when [sPos, ePos] and the node's [startSlot, endSlot] have a slot in common */
    def overlaps(node: RedisNode): Boolean =
      if (sPos <= node.startSlot) ePos >= node.startSlot else node.endSlot >= sPos

    nodes.filter(node => node.idx == 0 && overlaps(node)) // master only now
  }

  /**
    * @param key
    * *IMPORTANT* Please remember to close after using
    * @return jedis who is a connection for a given key
    */
  def connectionForKey(key: String): Jedis = {
    getHost(key).connect()
  }

  /**
    * @param initialHost any redis endpoint of a cluster or a single server
    * @return true if the target server is in cluster mode
    */
  private def clusterEnabled(initialHost: RedisEndpoint): Boolean = {
    val conn = initialHost.connect()
    // try/finally so the connection is released even when INFO fails or its
    // output cannot be parsed
    try {
      val info = conn.info.split("\n")
      val version = info.filter(_.contains("redis_version:"))(0)
      val clusterEnable = info.filter(_.contains("cluster_enabled:"))
      // 14 == "redis_version:".length, so this reads the major version number
      val mainVersion = version.substring(14, version.indexOf(".")).toInt
      mainVersion > 2 && clusterEnable.length > 0 && clusterEnable(0).contains("1")
    } finally {
      conn.close()
    }
  }

  /**
    * @param key
    * @return host whose slots should involve key
    */
  def getHost(key: String): RedisNode = {
    val slot = JedisClusterCRC16.getSlot(key)
    hosts.filter(host => {
      host.startSlot <= slot && host.endSlot >= slot
    })(0)
  }


  /**
    * @param initialHost any redis endpoint of a cluster or a single server
    * @return list of host nodes
    */
  private def getHosts(initialHost: RedisEndpoint): Array[RedisNode] = {
    getNodes(initialHost).filter(_.idx == 0)
  }

  /**
    * @param initialHost any redis endpoint of a single server
    * @return list of nodes: the master at idx 0 followed by its online slaves,
    *         each covering the full slot range [0, 16383]
    */
  private def getNonClusterNodes(initialHost: RedisEndpoint): Array[RedisNode] = {
    val master = (initialHost.host, initialHost.port)
    val conn = initialHost.connect()

    // try/finally so the connection is released even when INFO fails
    val replinfo = try {
      conn.info("Replication").split("\n")
    } finally {
      conn.close()
    }

    // If this node is a slave, we need to extract the slaves from its master
    if (replinfo.exists(_.contains("role:slave"))) {
      // 12 == "master_host:".length == "master_port:".length
      val host = replinfo.filter(_.contains("master_host:"))(0).trim.substring(12)
      val port = replinfo.filter(_.contains("master_port:"))(0).trim.substring(12).toInt

      // simply re-enter this function with the master host/port; copy() keeps the
      // configured auth, db number and timeout of the original endpoint
      getNonClusterNodes(initialHost = initialHost.copy(host = host, port = port))

    } else {
      // this is a master - take its slaves

      val slaves = replinfo.filter(x => x.contains("slave") && x.contains("online")).map(rl => {
        // the part after ':' is a comma separated list whose first two entries are
        // the ip and the port, each written as name=value
        val content = rl.substring(rl.indexOf(':') + 1).split(",")
        val ip = content(0)
        val port = content(1)
        (ip.substring(ip.indexOf('=') + 1), port.substring(port.indexOf('=') + 1).toInt)
      })

      val endpoints = master +: slaves
      val total = endpoints.length
      endpoints.zipWithIndex.map { case ((host, port), i) =>
        RedisNode(initialHost.copy(host = host, port = port), 0, 16383, i, total)
      }
    }
  }

  /**
    * @param initialHost any redis endpoint of a cluster server
    * @return list of nodes
    */
  private def getClusterNodes(initialHost: RedisEndpoint): Array[RedisNode] = {
    val conn = initialHost.connect()
    // try/finally so the connection is released even when CLUSTER SLOTS fails or
    // its reply cannot be decoded
    try {
      conn.clusterSlots().flatMap {
        slotInfoObj => {
          val slotInfo = slotInfoObj.asInstanceOf[java.util.List[java.lang.Object]]
          val sPos = slotInfo.get(0).toString.toInt
          val ePos = slotInfo.get(1).toString.toInt
          // entries 0 and 1 are the slot bounds, the remaining ones are the nodes
          val total = slotInfo.size - 2
          /*
           * We will get all the nodes with the slots range [sPos, ePos],
           * and create RedisNode for each nodes, the total field of all
           * RedisNode are the number of the nodes whose slots range is
           * as above, and the idx field is just an index for each node
           * which will be used for adding support for slaves and so on.
           * And the idx of a master is always 0, we rely on this fact to
           * filter master.
           */
          (0 until total).map(i => {
            val node = slotInfo.get(i + 2).asInstanceOf[java.util.List[java.lang.Object]]
            val host = SafeEncoder.encode(node.get(0).asInstanceOf[Array[scala.Byte]])
            val port = node.get(1).toString.toInt
            RedisNode(initialHost.copy(host = host, port = port), sPos, ePos, i, total)
          })
        }
      }.toArray
    } finally {
      conn.close()
    }
  }

  /**
    * @param initialHost any redis endpoint of a cluster or a single server
    * @return list of nodes
    */
  def getNodes(initialHost: RedisEndpoint): Array[RedisNode] = {
    if (clusterEnabled(initialHost)) {
      getClusterNodes(initialHost)
    } else {
      getNonClusterNodes(initialHost)
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy