// All downloads are free. Search and download functionality uses the official Maven repository.
//
// sttp.tapir.Validator.scala (Maven / Gradle / Ivy)
//
// There is a newer version: 1.11.11
// Show newest version
package sttp.tapir

import sttp.tapir.Schema.SName
import sttp.tapir.macros.ValidatorMacros

import scala.collection.immutable

/** A validator of values of type `T`.
  *
  * Validators can be combined with [[and]] / [[or]], adapted to other types with [[contramap]], and rendered as a
  * textual description with [[show]].
  */
sealed trait Validator[T] {

  /** Validates `t`, returning the validation errors that were found. */
  def apply(t: T): List[ValidationError[_]]

  /** Adapts this validator to values of type `TT`, which are converted using `g` before being validated. */
  def contramap[TT](g: TT => T): Validator[TT] = Validator.Mapped[TT, T](this, g)

  /** Combines this validator with `other` using [[Validator.all]]. */
  def and(other: Validator[T]): Validator[T] = Validator.all[T](this, other)

  /** Combines this validator with `other` using [[Validator.any]]. */
  def or(other: Validator[T]): Validator[T] = Validator.any[T](this, other)

  /** The textual description of this validator, as computed by [[Validator.show]]. */
  def show: Option[String] = Validator.show[T](this)
}

object Validator extends ValidatorMacros {

  // Captures the encoding of a value to a raw format, which is then rendered directly as a string in
  // documentation. Needed because codecs for nested types aren't available.
  type EncodeToRaw[T] = Function1[T, Option[scala.Any]]

  // Shared instances behind `pass` and `reject`: a conjunction and a disjunction with no members.
  private val _pass: Validator[Nothing] = All[Nothing](Nil)
  private val _reject: Validator[Nothing] = Any[Nothing](Nil)

  // -------------------------------------- SMART CONSTRUCTORS --------------------------------------
  /** Combines the given validators into a single one wrapping them in [[All]]. A single validator is returned as-is. */
  def all[T](v: Validator[T]*): Validator[T] = v.toList match {
    case single :: Nil => single
    case validators    => All[T](validators)
  }

  /** Combines the given validators into a single one wrapping them in [[Any]]. A single validator is returned as-is. */
  def any[T](v: Validator[T]*): Validator[T] = v.toList match {
    case single :: Nil => single
    case validators    => Any[T](validators)
  }

  /** A validator instance that always passes. */
  def pass[T]: Validator[T] = {
    // The shared instance has no members, so it can be safely viewed as a validator of any type.
    _pass.asInstanceOf[Validator[T]]
  }

  /** A validator instance that always rejects. */
  def reject[T]: Validator[T] = {
    // The shared instance has no members, so it can be safely viewed as a validator of any type.
    _reject.asInstanceOf[Validator[T]]
  }

  /** Accepts numbers greater than or equal to `value`; when `exclusive`, only numbers strictly greater. */
  def min[T: Numeric](value: T, exclusive: Boolean = false): Validator.Primitive[T] = Min[T](value, exclusive)

  /** Accepts numbers less than or equal to `value`; when `exclusive`, only numbers strictly less. */
  def max[T: Numeric](value: T, exclusive: Boolean = false): Validator.Primitive[T] = Max[T](value, exclusive)

  /** Accepts numbers strictly greater than zero. */
  def positive[T: Numeric]: Validator.Primitive[T] = {
    val zero = implicitly[Numeric[T]].zero
    Min[T](zero, exclusive = true)
  }

  /** Accepts numbers greater than or equal to zero. */
  def positiveOrZero[T: Numeric]: Validator.Primitive[T] = {
    val zero = implicitly[Numeric[T]].zero
    Min[T](zero, exclusive = false)
  }

  /** Accepts numbers strictly less than zero. */
  def negative[T: Numeric]: Validator.Primitive[T] = {
    val zero = implicitly[Numeric[T]].zero
    Max[T](zero, exclusive = true)
  }

  /** Accepts numbers between `min` and `max`; each bound is inclusive unless the matching `*Exclusive` flag is set. */
  def inRange[T: Numeric](min: T, max: T, minExclusive: Boolean = false, maxExclusive: Boolean = false): Validator[T] =
    all[T](Min[T](min, minExclusive), Max[T](max, maxExclusive))

  // string
  /** Create a validator which accepts strings matching the regular expression `value`. */
  def pattern[T <: String](value: String): Validator.Primitive[T] = Pattern[T](value)

  /** Create a validator which accepts strings of at least the given length, counted in `Char` values.
    *
    * @param value
    *   The minimum allowed length.
    */
  def minLength[T <: String](value: Int): Validator.Primitive[T] = MinLength[T](value, countCodePoints = false)

  /** Create a validator which accepts strings of at least the given length.
    *
    * @param value
    *   The minimum allowed length.
    * @param countCodePoints
    *   Selects the unit in which length is measured. When `true`, code points are counted, so that a character represented by two code
    *   units in UTF-16 (a surrogate pair, e.g. an emoji) contributes 1 to the length. When `false`, the number of `Char` values in the
    *   string is used.
    */
  def minLength[T <: String](value: Int, countCodePoints: Boolean): Validator.Primitive[T] =
    MinLength[T](value, countCodePoints = countCodePoints)

  /** Create a validator which accepts strings of at most the given length, counted in `Char` values.
    *
    * @param value
    *   The maximum allowed length.
    */
  def maxLength[T <: String](value: Int): Validator.Primitive[T] = MaxLength[T](value, countCodePoints = false)

  /** Create a validator which accepts strings of at most the given length.
    *
    * @param value
    *   The maximum allowed length.
    * @param countCodePoints
    *   Selects the unit in which length is measured. When `true`, code points are counted, so that a character represented by two code
    *   units in UTF-16 (a surrogate pair, e.g. an emoji) contributes 1 to the length. When `false`, the number of `Char` values in the
    *   string is used.
    */
  def maxLength[T <: String](value: Int, countCodePoints: Boolean): Validator.Primitive[T] =
    MaxLength[T](value, countCodePoints = countCodePoints)

  /** Create a validator which accepts strings of exactly the given length, counted in `Char` values.
    *
    * @param value
    *   The fixed allowed length.
    */
  def fixedLength[T <: String](value: Int): Validator[T] = fixedLength[T](value, countCodePoints = false)

  /** Create a validator which accepts strings of exactly the given length, expressed as a minimum-length and a maximum-length constraint
    * with the same bound.
    *
    * @param value
    *   The fixed allowed length.
    * @param countCodePoints
    *   Selects the unit in which length is measured. When `true`, code points are counted, so that a character represented by two code
    *   units in UTF-16 (a surrogate pair, e.g. an emoji) contributes 1 to the length. When `false`, the number of `Char` values in the
    *   string is used.
    */
  def fixedLength[T <: String](value: Int, countCodePoints: Boolean): Validator[T] = {
    val lowerBound = MinLength[T](value, countCodePoints)
    val upperBound = MaxLength[T](value, countCodePoints)
    all[T](lowerBound, upperBound)
  }

  /** Create a validator which accepts strings containing at least one `Char`. */
  def nonEmptyString[T <: String]: Validator.Primitive[T] = minLength[T](1)

  // iterable
  /** Create a validator which accepts collections with at least `value` elements. */
  def minSize[T, C[_] <: Iterable[_]](value: Int): Validator.Primitive[C[T]] = MinSize[T, C](value)

  /** Create a validator which accepts collections with at most `value` elements. */
  def maxSize[T, C[_] <: Iterable[_]](value: Int): Validator.Primitive[C[T]] = MaxSize[T, C](value)

  /** Create a validator which accepts collections with at least one element. */
  def nonEmpty[T, C[_] <: Iterable[_]]: Validator.Primitive[C[T]] = MinSize[T, C](1)

  /** Create a validator which accepts collections with exactly `value` elements. */
  def fixedSize[T, C[_] <: Iterable[_]](value: Int): Validator[C[T]] =
    all[C[T]](MinSize[T, C](value), MaxSize[T, C](value))

  // enum
  /** Create an enumeration validator, which accepts only the given possible values.
    *
    * The created validator has no encoding function and no name. To represent the enumerated values in documentation, an encoding function
    * needs to be provided. This can be done:
    *
    *   - by using the overloaded [[enumeration]] method with an `encode` parameter
    *   - by adding an encode function on an [[Validator.Enumeration]] instance using one of the `.encode` functions
    *   - by adding the validator directly to a codec (see [[Mapping.addEncodeToEnumValidator]])
    *   - when the possible values are of a basic type (numbers, strings), the encode function is inferred if not present, when being added
    *     to the schema, see [[Schema.validate]]
    */
  def enumeration[T](possibleValues: List[T]): Validator.Enumeration[T] =
    Enumeration[T](possibleValues, encode = None, name = None)

  /** Create an enumeration validator, which accepts only the given possible values, together with an encoding function (so that the
    * enumerated values can be represented in documentation), and an optional name (to create a reusable documentation component).
    *
    * @param encode
    *   Specify how values of this type can be encoded to a raw value, which will be used for documentation.
    */
  def enumeration[T](possibleValues: List[T], encode: EncodeToRaw[T], name: Option[SName] = None): Validator.Enumeration[T] =
    Enumeration[T](possibleValues, encode = Some(encode), name = name)

  /** Create a validator backed by a user-provided function.
    * @param validationLogic
    *   The function deciding whether a value is valid.
    * @param showMessage
    *   Description of the validator used when invoking [[Validator.show]].
    */
  def custom[T](validationLogic: T => ValidationResult, showMessage: Option[String] = None): Validator[T] =
    Custom[T](validationLogic, showMessage)

  // ---------- PRIMITIVE ----------
  /** A validator which performs a single check, described by [[doValidate]]. */
  sealed trait Primitive[T] extends Validator[T] {

    /** Runs the check on `t`. */
    def doValidate(t: T): ValidationResult

    /** Translates the outcome of [[doValidate]] into validation errors: none when valid, one without a message when invalid with no
      * custom messages, otherwise one per custom message.
      */
    override def apply(t: T): List[ValidationError[T]] = {
      def errorWith(message: Option[String]): ValidationError[T] = ValidationError[T](this, t, Nil, message)

      doValidate(t) match {
        case ValidationResult.Valid                                 => Nil
        case ValidationResult.Invalid(messages) if messages.isEmpty => errorWith(None) :: Nil
        case ValidationResult.Invalid(messages)                     => messages.map(message => errorWith(Some(message)))
      }
    }
  }
  /** Accepts numbers greater than `value`; unless `exclusive`, numbers equal to `value` are accepted as well. */
  case class Min[T](value: T, exclusive: Boolean)(implicit val valueIsNumeric: Numeric[T]) extends Primitive[T] {
    override def doValidate(t: T): ValidationResult = {
      val withinBound =
        if (valueIsNumeric.gt(t, value)) true
        else if (exclusive) false
        else valueIsNumeric.equiv(t, value)
      ValidationResult.validWhen(withinBound)
    }
  }
  /** Accepts numbers less than `value`; unless `exclusive`, numbers equal to `value` are accepted as well. */
  case class Max[T](value: T, exclusive: Boolean)(implicit val valueIsNumeric: Numeric[T]) extends Primitive[T] {
    override def doValidate(t: T): ValidationResult = {
      val withinBound =
        if (valueIsNumeric.lt(t, value)) true
        else if (exclusive) false
        else valueIsNumeric.equiv(t, value)
      ValidationResult.validWhen(withinBound)
    }
  }
  /** Accepts strings for which `String.matches` succeeds with the regular expression `value`. */
  case class Pattern[T <: String](value: String) extends Primitive[T] {
    override def doValidate(t: T): ValidationResult = {
      val matched = t.matches(value)
      ValidationResult.validWhen(matched)
    }
  }

  /** Accepts strings whose length is at least `value`. The length is the number of code points when `countCodePoints` is set, and the
    * number of `Char` values otherwise.
    */
  case class MinLength[T <: String](value: Int, countCodePoints: Boolean) extends Primitive[T] {
    // Binary compatibility: the `countCodePoints` field was added later, so the one-argument constructor, the
    // custom `copy` methods and the custom `apply` in the companion object keep the previous signatures available.
    def this(value: Int) = this(value, countCodePoints = false)
    def copy[TT <: String](value: Int): MinLength[TT] = new MinLength[TT](value, false)
    def copy[TT <: String](value: Int = this.value, countCodePoints: Boolean = this.countCodePoints): MinLength[TT] =
      new MinLength[TT](value, countCodePoints)

    override def doValidate(t: T): ValidationResult = {
      val length = if (countCodePoints) t.codePointCount(0, t.length) else t.length
      ValidationResult.validWhen(length >= value)
    }
  }
  object MinLength {

    /** Creates a validator which measures length in `Char` values. */
    def apply[T <: String](value: Int): MinLength[T] = new MinLength[T](value, false)
  }

  /** Accepts strings whose length is at most `value`. The length is the number of code points when `countCodePoints` is set, and the
    * number of `Char` values otherwise.
    */
  case class MaxLength[T <: String](value: Int, countCodePoints: Boolean) extends Primitive[T] {
    // Binary compatibility: the `countCodePoints` field was added later, so the one-argument constructor, the
    // custom `copy` methods and the custom `apply` in the companion object keep the previous signatures available.
    def this(value: Int) = this(value, countCodePoints = false)
    def copy[TT <: String](value: Int): MaxLength[TT] = new MaxLength[TT](value, false)
    def copy[TT <: String](value: Int = this.value, countCodePoints: Boolean = this.countCodePoints): MaxLength[TT] =
      new MaxLength[TT](value, countCodePoints)

    override def doValidate(t: T): ValidationResult = {
      val length = if (countCodePoints) t.codePointCount(0, t.length) else t.length
      ValidationResult.validWhen(length <= value)
    }
  }
  object MaxLength {

    /** Creates a validator which measures length in `Char` values. */
    def apply[T <: String](value: Int): MaxLength[T] = new MaxLength[T](value, false)
  }

  /** Accepts collections whose size is at least `value`. */
  case class MinSize[T, C[_] <: Iterable[_]](value: Int) extends Primitive[C[T]] {
    override def doValidate(t: C[T]): ValidationResult = {
      val elementCount = t.size
      ValidationResult.validWhen(elementCount >= value)
    }
  }
  /** Accepts collections whose size is at most `value`. */
  case class MaxSize[T, C[_] <: Iterable[_]](value: Int) extends Primitive[C[T]] {
    override def doValidate(t: C[T]): ValidationResult = {
      val elementCount = t.size
      ValidationResult.validWhen(elementCount <= value)
    }
  }
  /** A validator whose outcome is whatever `validationLogic` returns; `showMessage` is the description used by [[Validator.show]]. */
  case class Custom[T](validationLogic: T => ValidationResult, showMessage: Option[String] = None) extends Primitive[T] {
    override def doValidate(t: T): ValidationResult = validationLogic.apply(t)
  }

  /** Accepts only values contained in `possibleValues`. The optional `encode` function and `name` are carried along with the validator. */
  case class Enumeration[T](possibleValues: List[T], encode: Option[EncodeToRaw[T]], name: Option[SName]) extends Primitive[T] {
    override def doValidate(t: T): ValidationResult = {
      val isMember = possibleValues.contains(t)
      ValidationResult.validWhen(isMember)
    }

    /** Specify how values of this type can be encoded to a raw value (typically a [[String]]). This encoding will be used when generating
      * documentation.
      */
    def encode(e: T => scala.Any): Enumeration[T] = {
      val toRaw: EncodeToRaw[T] = v => Some(e(v))
      copy(encode = Some(toRaw))
    }

    /** Specify that values of this type should be encoded to a raw value using an in-scope plain codec. This encoding will be used when
      * generating documentation.
      */
    def encodeWithPlainCodec(implicit c: Codec.PlainCodec[T]): Enumeration[T] = {
      val toRaw: EncodeToRaw[T] = v => Some(c.encode(v))
      copy(encode = Some(toRaw))
    }

    /** Specify that values of this type should be encoded to a raw value using an in-scope codec of the given format. This encoding will be
      * used when generating documentation.
      */
    def encodeWithCodec[CF <: CodecFormat](implicit c: Codec[String, T, CF]): Enumeration[T] = {
      val toRaw: EncodeToRaw[T] = v => Some(c.encode(v))
      copy(encode = Some(toRaw))
    }
  }

  //

  /** Validates values of type `TT` by converting them with `g` and passing the result to the `wrapped` validator. */
  case class Mapped[TT, T](wrapped: Validator[T], g: TT => T) extends Validator[TT] {
    override def apply(t: TT): List[ValidationError[_]] = {
      val converted = g(t)
      wrapped(converted)
    }
  }

  /** A conjunction of validators: the reported errors are the errors of all members, in order. With no members, no errors are ever
    * reported.
    */
  case class All[T](validators: immutable.Seq[Validator[T]]) extends Validator[T] {
    override def apply(t: T): List[ValidationError[_]] =
      validators.flatMap(validator => validator(t)).toList

    // With no members the mapping function would never be used, so the same instance is reused.
    override def contramap[TT](g: TT => T): Validator[TT] =
      if (validators.nonEmpty) super.contramap(g) else this.asInstanceOf[Validator[TT]]

    // Appends to the existing members instead of nesting; with no members, `other` is the result.
    override def and(other: Validator[T]): Validator[T] =
      if (validators.nonEmpty) All(validators :+ other) else other
  }

  /** A disjunction of validators: a value is accepted when at least one member reports no errors; otherwise the errors of all members
    * are reported. With no members, every value is rejected with a single error.
    */
  case class Any[T](validators: immutable.Seq[Validator[T]]) extends Validator[T] {
    override def apply(t: T): List[ValidationError[_]] =
      if (validators.isEmpty) {
        List(ValidationError[T](Any.PrimitiveRejectValidator, t))
      } else {
        val errorsPerValidator = validators.map(_.apply(t))
        if (errorsPerValidator.exists(_.isEmpty)) Nil else errorsPerValidator.flatten.toList
      }

    // With no members the mapping function would never be used, so a fresh empty disjunction is enough.
    override def contramap[TT](g: TT => T): Validator[TT] =
      if (validators.nonEmpty) super.contramap(g) else Any[TT](Nil)

    // Appends to the existing members instead of nesting; with no members, `other` is the result.
    override def or(other: Validator[T]): Validator[T] =
      if (validators.nonEmpty) Any(validators :+ other) else other
  }
  object Any {
    // Single shared instance of the validator reported as the cause of a rejection by an empty disjunction.
    private val _primitiveRejectValidator: Primitive[scala.Any] =
      Custom[scala.Any](_ => ValidationResult.Invalid("Validation rejected"))

    private[tapir] def PrimitiveRejectValidator[T]: Primitive[T] =
      _primitiveRejectValidator.asInstanceOf[Primitive[T]]
  }

  //

  /** Renders a validator as a short textual description.
    *
    * Primitive validators are rendered as a constraint (e.g. `>=1`, `length<=10`, `in(a,b)`). A [[Mapped]] validator is rendered as the
    * validator it wraps. [[All]] and [[Any]] render their members, skipping the ones without a description; a single described member
    * is rendered without the surrounding `all(...)` / `any(...)`.
    *
    * @param v
    *   The validator to describe.
    * @return
    *   The description; `None` only for an [[All]] with no described members. An [[Any]] with no described members is rendered as
    *   `reject`.
    */
  def show[T](v: Validator[T]): Option[String] = {
    v match {
      case Min(value, exclusive)             => Some(s"${if (exclusive) ">" else ">="}$value")
      case Max(value, exclusive)             => Some(s"${if (exclusive) "<" else "<="}$value")
      case Pattern(value)                    => Some(s"~$value")
      case MinLength(value, _)               => Some(s"length>=$value")
      case MaxLength(value, _)               => Some(s"length<=$value")
      case MinSize(value)                    => Some(s"size>=$value")
      case MaxSize(value)                    => Some(s"size<=$value")
      case Custom(_, showMessage)            => showMessage.orElse(Some("custom"))
      // The closing parenthesis is part of the rendered text, so that the output is balanced like `all(...)` and `any(...)`.
      case Enumeration(possibleValues, _, _) => Some(s"in(${possibleValues.mkString(",")})")
      case Mapped(wrapped, _)                => show(wrapped)
      case All(validators) =>
        validators.flatMap(show(_)) match {
          case immutable.Seq()  => None
          case immutable.Seq(s) => Some(s)
          case ss               => Some(s"all(${ss.mkString(",")})")
        }
      case Any(validators) =>
        validators.flatMap(show(_)) match {
          case immutable.Seq()  => Some("reject")
          case immutable.Seq(s) => Some(s)
          case ss               => Some(s"any(${ss.mkString(",")})")
        }
    }
  }
}

/** The outcome of a single validation check: either [[ValidationResult.Valid]] or [[ValidationResult.Invalid]]. */
sealed trait ValidationResult
object ValidationResult {

  /** The checked value is valid. */
  case object Valid extends ValidationResult

  /** The checked value is invalid; `customMessage` holds the messages describing the failure, if any. */
  case class Invalid(customMessage: List[String] = Nil) extends ValidationResult
  object Invalid {

    /** Creates an invalid result with a single message. */
    def apply(customMessage: String): Invalid = new Invalid(customMessage :: Nil)

    /** Creates an invalid result with the given messages, in order. */
    def apply(customMessage: String, customMessages: String*): Invalid =
      new Invalid(customMessage +: customMessages.toList)
  }

  /** Returns [[Valid]] when `condition` holds, and an [[Invalid]] without messages otherwise. */
  def validWhen(condition: Boolean): ValidationResult =
    if (!condition) new Invalid(Nil) else Valid
}

/** A single validation failure.
  *
  * @param validator
  *   The primitive validator which reported the failure.
  * @param invalidValue
  *   The offending value. This isn't typed as `T`, as we might want to report that a value can't be decoded as an enumeration member.
  * @param path
  *   The field names locating the value; extended at the front by [[prependPath]].
  * @param customMessage
  *   An optional message describing the failure.
  */
case class ValidationError[T](
    validator: Validator.Primitive[T],
    invalidValue: Any,
    path: List[FieldName] = Nil,
    customMessage: Option[String] = None
) {

  /** Returns a copy of this error with `f` added at the front of the path. */
  def prependPath(f: FieldName): ValidationError[T] = {
    val extendedPath = f :: path
    copy(path = extendedPath)
  }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy