All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.outworkers.phantom.builder.query.QueryOptions.scala Maven / Gradle / Ivy

/*
 * Copyright 2013 - 2020 Outworkers Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.outworkers.phantom.builder.query

import java.nio.ByteBuffer
import java.util.{Collections, Map => JMap}

import scala.collection.JavaConverters._
import com.datastax.driver.core._
import com.datastax.driver.core.policies.TokenAwarePolicy
import com.outworkers.phantom.Manager
import com.outworkers.phantom.builder.ops.TokenizerKey
import com.outworkers.phantom.builder.primitives.Primitive
import shapeless.{::, Generic, HList, HNil, Lazy}

import scala.annotation.implicitNotFound

/**
 * A reusable transformation of a driver [[Statement]], applied before execution.
 * Implementations return the (possibly updated) statement.
 */
trait Modifier extends Function1[Statement, Statement]

/**
 * Wrapper around the custom payload entries attached to an outgoing statement.
 *
 * The wrapped map may be unmodifiable. For example, `Payload.empty` wraps
 * `Collections.emptyMap()`, and the companion `apply` overloads wrap immutable
 * Scala maps through `asJava`. The map may also be shared with other payloads.
 * `add` therefore never writes into `underlying`. It copies the entries into a
 * fresh map and returns a new [[Payload]], leaving the receiver unchanged.
 *
 * @param underlying payload entries, keyed by name, with already-serialized values.
 */
case class Payload(underlying: JMap[String, ByteBuffer]) {

  /** @return true when the payload carries no entries. */
  def isEmpty: Boolean = underlying.isEmpty

  /**
   * Returns a new payload containing all current entries plus `other`.
   * An existing entry with the same key is replaced.
   *
   * @param other the key and its already-serialized value.
   * @return a new payload; the receiver is not modified.
   */
  def add(other: (String, ByteBuffer)): Payload = {
    val (key, value) = other
    withEntry(key, value)
  }

  /**
   * Returns a new payload containing all current entries plus `other`, whose value
   * is serialized with its [[Primitive]] for protocol version `pv`.
   *
   * @param other the key and the raw value to serialize.
   * @return a new payload; the receiver is not modified.
   */
  def add[T](other: (String, T))(
    implicit ev: Primitive[T],
    pv: ProtocolVersion
  ): Payload = {
    val (key, value) = other
    withEntry(key, ev.serialize(value, pv))
  }

  // Copy-on-write: `underlying` may be unmodifiable (put would throw
  // UnsupportedOperationException) or shared, so write into a fresh copy instead.
  private def withEntry(key: String, value: ByteBuffer): Payload = {
    val updated = new java.util.HashMap[String, ByteBuffer](underlying)
    updated.put(key, value)
    new Payload(updated)
  }
}

/**
 * Type class that turns an HList of `(String, value)` entries into
 * serialized `(key, bytes)` payload pairs.
 *
 * @tparam HL the HList holding the entries to serialize.
 */
@implicitNotFound("Payloads are a sequence of (key, value) tuples where the key is always String, and the value is a primitive.")
trait PayloadSerializer[HL <: HList] {

  /**
   * Serializes the entries held by `input`.
   *
   * @param input the HList of entries.
   * @return one `(key, bytes)` pair per entry.
   */
  def apply(input: HL): Seq[(String, ByteBuffer)]
}

/** Implicit derivation of [[PayloadSerializer]] instances over HLists. */
object PayloadSerializer {

  /** Summons the implicit [[PayloadSerializer]] in scope for `HL`. */
  def apply[HL <: HList](implicit ev: PayloadSerializer[HL]): PayloadSerializer[HL] = ev

  /** Base case: an empty HList yields no payload entries. */
  implicit val hNilSerializer: PayloadSerializer[HNil] = new PayloadSerializer[HNil] {
    override def apply(input: HNil): Seq[(String, ByteBuffer)] = Seq.empty[(String, ByteBuffer)]
  }

  /**
   * Inductive case. The head element (viewed as `(String, B)` through `tpEv`) is
   * serialized with its [[Primitive]]. The result is then prepended to the entries
   * produced for the tail.
   *
   * NOTE(review): type parameter `A` is not referenced by the signature or body; it
   * looks unused but is kept for source compatibility.
   */
  implicit def hconsSerializer[H, HL <: HList, A, B](
    implicit tpEv: H <:< (String, B),
    ev: Primitive[B],
    pv: ProtocolVersion,
    ps: Lazy[PayloadSerializer[HL]]
  ): PayloadSerializer[H :: HL] = new PayloadSerializer[H :: HL] {
    override def apply(input: H :: HL): Seq[(String, ByteBuffer)] = {
      val (key, value) = tpEv(input.head)
      // Serialize the head before recursing into the tail, as the original ordering did.
      val headEntry = key -> ev.serialize(value, pv)
      headEntry +: ps.value(input.tail)
    }
  }
}

/** Factory methods for [[Payload]] instances. */
object Payload {

  /** A payload with no entries, backed by `Collections.emptyMap()`. */
  def empty: Payload = new Payload(Collections.emptyMap[String, ByteBuffer]())

  /** Wraps an already-serialized Scala map, exposed to Java through `asJava`. */
  def apply(map: Map[String, ByteBuffer]): Payload = new Payload(map.asJava)

  /** A payload holding exactly one already-serialized entry. */
  def apply(tp: (String, ByteBuffer)): Payload = apply(Map(tp))

  /** Builds a payload from already-serialized entries; for a repeated key the last entry wins. */
  def seq(tp: (String, ByteBuffer)*): Payload = apply(Map(tp: _*))

  /** A payload holding one entry whose value is serialized with its [[Primitive]]. */
  def apply[T](tp: (String, T))(implicit ev: Primitive[T], pv: ProtocolVersion): Payload =
    tp match {
      case (key, value) => apply(Map(key -> ev.serialize(value, pv)))
    }

  /**
   * Builds a payload from a product value (e.g. a tuple) whose elements are
   * `(String, value)` pairs. The value is converted to an HList via shapeless
   * `Generic`, and each element is serialized through [[PayloadSerializer]].
   */
  def apply[V1, HL <: HList](tp: V1)(
    implicit gen: Generic.Aux[V1, HL],
    pv: ProtocolVersion,
    ps: PayloadSerializer[HL]
  ): Payload = {
    val entries = ps(gen.to(tp))
    seq(entries: _*)
  }
}

/**
 * Sets the routing key and the session's logged keyspace on a [[SimpleStatement]].
 * This happens only when the cluster's load balancing policy is a [[TokenAwarePolicy]]
 * and at least one token is supplied; otherwise the statement is returned as-is.
 *
 * @param tokens  tokens whose serialized bytes form the routing key.
 * @param session session providing the cluster configuration and logged keyspace.
 */
case class RoutingKeyModifier(
  tokens: List[TokenizerKey]
)(
  implicit session: Session
) extends (SimpleStatement => SimpleStatement) {

  override def apply(st: SimpleStatement): SimpleStatement = {
    val loadBalancing = session.getCluster.getConfiguration.getPolicies.getLoadBalancingPolicy

    loadBalancing match {
      case _: TokenAwarePolicy if tokens.nonEmpty =>
        val keys = tokens.map(token => token.apply(session))

        Manager.logger.debug(s"Routing key tokens found. Settings routing key to ${keys.map(_.cql).mkString("(", ",", ")")}")

        val withRouting = st.setRoutingKey(keys.map(_.bytes): _*)
        withRouting.setKeyspace(session.getLoggedKeyspace)

      case _ =>
        st
    }
  }
}

/**
 * Sets the statement's consistency level to `level` when it is defined.
 * `None` leaves the statement unchanged.
 */
class ConsistencyLevelModifier(level: Option[ConsistencyLevel]) extends Modifier {
  override def apply(v1: Statement): Statement =
    level.fold(v1)(consistency => v1.setConsistencyLevel(consistency))
}

/**
 * Sets the statement's serial consistency level to `level` when it is defined.
 * `None` leaves the statement unchanged.
 */
class SerialConsistencyLevelModifier(level: Option[ConsistencyLevel]) extends Modifier {
  override def apply(v1: Statement): Statement =
    level.fold(v1)(serial => v1.setSerialConsistencyLevel(serial))
}


/**
 * Attaches `payload` as the statement's outgoing custom payload.
 * An empty payload leaves the statement unchanged.
 */
class PayloadModifier(payload: Payload) extends Modifier {
  override def apply(v1: Statement): Statement =
    if (!payload.isEmpty) v1.setOutgoingPayload(payload.underlying) else v1
}



/**
 * Sets the statement's paging state when one is supplied.
 *
 * @param level the paging state to apply; `None` leaves the statement unchanged.
 */
class PagingStateModifier(level: Option[PagingState]) extends Modifier {
  // Explicit lambda: `setPagingState` is overloaded, so avoid relying on eta-expansion.
  override def apply(v1: Statement): Statement =
    level.fold(v1)(state => v1.setPagingState(state))
}

/**
 * Toggles tracing on the statement.
 *
 * @param level `Some(true)` enables tracing, `Some(false)` disables it,
 *              and `None` leaves the statement unchanged.
 */
class EnableTracingModifier(level: Option[Boolean]) extends Modifier {
  override def apply(v1: Statement): Statement =
    level.fold(v1) { enabled =>
      if (enabled) v1.enableTracing() else v1.disableTracing()
    }
}

/**
 * Sets the statement's fetch size when one is supplied.
 * `None` leaves the statement unchanged.
 */
class FetchSizeModifier(level: Option[Int]) extends Modifier {
  override def apply(v1: Statement): Statement =
    level.fold(v1)(size => v1.setFetchSize(size))
}

/**
 * Per-query execution options, applied to a driver [[Statement]] through `apply`.
 * Options left as `None` do not touch the corresponding statement setting.
 *
 * @param consistencyLevel       consistency level to set on the statement.
 * @param serialConsistencyLevel serial consistency level to set on the statement.
 * @param pagingState            paging state to set on the statement.
 * @param enableTracing          `Some(true)` enables tracing, `Some(false)` disables it.
 * @param fetchSize              fetch size to set on the statement.
 * @param outgoingPayload        custom payload; attached only when non-empty.
 */
case class QueryOptions(
  consistencyLevel: Option[ConsistencyLevel],
  serialConsistencyLevel: Option[ConsistencyLevel],
  pagingState: Option[PagingState] = None,
  enableTracing: Option[Boolean] = None,
  fetchSize: Option[Int] = None,
  outgoingPayload: Payload = Payload.empty
) {

  /**
   * Applies every option to `st`, chaining the modifiers in this order: consistency
   * level, serial consistency level, paging state, tracing, fetch size, outgoing payload.
   *
   * @param st the statement to configure.
   * @return the statement returned by the last modifier in the chain.
   */
  def apply(st: Statement): Statement = {
    val applier = List[Statement => Statement](
      new ConsistencyLevelModifier(consistencyLevel),
      new SerialConsistencyLevelModifier(serialConsistencyLevel),
      new PagingStateModifier(pagingState),
      new EnableTracingModifier(enableTracing),
      new FetchSizeModifier(fetchSize),
      new PayloadModifier(outgoingPayload)
    ) reduce(_ andThen _)

    applier(st)
  }

  /**
   * Builds a driver-level `QueryOptions` that carries the consistency level, serial
   * consistency level and fetch size when they are defined. Paging state, tracing and
   * the outgoing payload are not transferred.
   */
  def options: com.datastax.driver.core.QueryOptions = {
    val opt = new com.datastax.driver.core.QueryOptions()

    // `foreach`, not `map`: the setters are called only for their side effect on `opt`.
    consistencyLevel.foreach(level => opt.setConsistencyLevel(level))
    serialConsistencyLevel.foreach(level => opt.setSerialConsistencyLevel(level))
    fetchSize.foreach(size => opt.setFetchSize(size))

    opt
  }

  /** @return a copy of these options with `payload` as the outgoing payload. */
  def outgoingPayload_=(payload: Payload): QueryOptions = {
    this.copy(outgoingPayload = payload)
  }

  /** @return a copy of these options with the consistency level set to `level`. */
  def consistencyLevel_=(level: ConsistencyLevel): QueryOptions = {
    this.copy(consistencyLevel = Some(level))
  }

  /** @return a copy of these options with the serial consistency level set to `level`. */
  def serialConsistencyLevel_=(level: ConsistencyLevel): QueryOptions = {
    this.copy(serialConsistencyLevel = Some(level))
  }

  /** @return a copy of these options with tracing explicitly set to `flag`. */
  def enableTracing_=(flag: Boolean): QueryOptions = {
    this.copy(enableTracing = Some(flag))
  }

  /** @return a copy of these options with the fetch size set to `size`. */
  def fetchSize_=(size: Int): QueryOptions = {
    this.copy(fetchSize = Some(size))
  }
}

/** Companion for [[QueryOptions]]. */
object QueryOptions {

  /** Options with every setting unset and the default (empty) outgoing payload. */
  def empty: QueryOptions =
    QueryOptions(consistencyLevel = None, serialConsistencyLevel = None)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy