All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smithy4s.compliancetests.internals.eq.EqSchemaVisitor.scala Maven / Gradle / Ivy

/*
 *  Copyright 2021-2023 Disney Streaming
 *
 *  Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *     https://disneystreaming.github.io/TOST-1.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package smithy4s.compliancetests.internals.eq

import smithy4s.compliancetests.internals.eq.Smithy4sEqInstances._
import cats.kernel.Eq
import cats.syntax.all._
import smithy4s._
import smithy4s.schema.{Schema, _}

import java.util.UUID
import smithy4s.capability.EncoderK
import cats.kernel.Monoid

/**
  * A [[SchemaVisitor]] that derives a cats `Eq` instance from a smithy4s
  * `Schema`.
  *
  * Equality is structural and follows the shape of the schema: primitives use
  * the instance selected by [[primitiveEq]], aggregates delegate to the `Eq`
  * of their members, and optional collections/maps are compared leniently
  * (see [[LenientOptionalCollectionEquality]]).
  */
object EqSchemaVisitor extends SchemaVisitor[Eq] { self =>

  /** Equality for a primitive shape; only `tag` is used, via [[primitiveEq]]. */
  override def primitive[P](
      shapeId: ShapeId,
      hints: Hints,
      tag: Primitive[P]
  ): Eq[P] = primitiveEq(tag)

  /**
    * Equality for a collection shape.
    *
    * The member schema is compiled first and exposed implicitly, then the
    * `Eq` instance matching the concrete collection type designated by `tag`
    * is resolved from implicit scope.
    */
  override def collection[C[_], A](
      shapeId: ShapeId,
      hints: Hints,
      tag: CollectionTag[C],
      member: Schema[A]
  ): Eq[C[A]] = {
    implicit val memberEq: Eq[A] = self(member)
    tag match {
      case CollectionTag.ListTag       => Eq[List[A]]
      case CollectionTag.SetTag        => Eq[Set[A]]
      case CollectionTag.VectorTag     => Eq[Vector[A]]
      case CollectionTag.IndexedSeqTag => Eq[IndexedSeq[A]]
    }
  }

  /**
    * Equality for a map shape.
    *
    * Only the value schema is compiled here; the `key` schema is not used, so
    * how keys are compared is decided by the implicitly resolved
    * `Eq[Map[K, V]]` instance.
    */
  override def map[K, V](
      shapeId: ShapeId,
      hints: Hints,
      key: Schema[K],
      value: Schema[V]
  ): Eq[Map[K, V]] = {
    implicit val valueEq: Eq[V] = self(value)
    Eq[Map[K, V]]
  }

  /**
    * Equality for an enumeration: two values are equal when the `EnumValue`s
    * obtained through `total` share both their `intValue` and `stringValue`.
    */
  override def enumeration[E](
      shapeId: ShapeId,
      hints: Hints,
      tag: EnumTag[E],
      values: List[EnumValue[E]],
      total: E => EnumValue[E]
  ): Eq[E] =
    Eq.by { x =>
      val enumX = total(x)
      (enumX.intValue, enumX.stringValue)
    }

  /**
    * Equality for a structure.
    *
    * Each field contributes an `Eq[S]` built by compiling the field's schema
    * and contramapping it through the field getter. The per-field instances
    * are then folded with the `Eq.allEqualBoundedSemilattice` monoid, i.e.
    * two structures are equal when all of their fields are equal (and a
    * structure without fields always compares equal).
    */
  override def struct[S](
      shapeId: ShapeId,
      hints: Hints,
      fields: Vector[Field[S, _]],
      make: IndexedSeq[Any] => S
  ): Eq[S] = {
    def forField[A2](field: Field[S, A2]): Eq[S] =
      field.schema.compile(self).contramap(field.get)
    implicit val monoidEqS: Monoid[Eq[S]] = Eq.allEqualBoundedSemilattice
    fields.foldMap(forField(_))
  }

  /**
    * Equality for a union.
    *
    * The left-hand side is dispatched to its alternative; the right-hand side
    * is then projected against that same alternative. Values that belong to
    * different alternatives are never equal; values of the same alternative
    * are compared with the `Eq` compiled from that alternative's schema.
    *
    * @throws java.util.NoSuchElementException if the dispatcher hands the
    *         precompiler a label that is absent from `alternatives`
    */
  override def union[U](
      shapeId: ShapeId,
      hints: Hints,
      alternatives: Vector[Alt[U, _]],
      dispatch: Alt.Dispatcher[U]
  ): Eq[U] = {
    // A version of `Eq` that assumes that the RHS is "up-casted" to U.
    trait AltEq[A] {
      def eqv(a: A, u: U): Boolean
    }

    // The encoded form that Eq works against is a partially-applied curried function.
    implicit val encoderKInstance: EncoderK[AltEq, U => Boolean] =
      new EncoderK[AltEq, U => Boolean] {
        def apply[A](fa: AltEq[A], a: A): U => Boolean = { (u: U) =>
          fa.eqv(a, u)
        }
        def absorb[A](f: A => (U => Boolean)): AltEq[A] = new AltEq[A] {
          def eqv(a: A, u: U): Boolean = f(a)(u)
        }
      }

    val precompiler = new Alt.Precompiler[AltEq] {
      def apply[A](label: String, instance: Schema[A]): AltEq[A] = {
        // Here we "cheat" to recover the `Alt` corresponding to `A`, as this information
        // is lost in the precompiler. The lookup is by label, and the cast is unchecked:
        // it relies on the label identifying the alternative whose member type is `A`.
        val altA: Alt[U, A] = alternatives
          .find(_.label == label)
          .map(_.asInstanceOf[Alt[U, A]])
          .getOrElse(
            // Fail with context rather than an anonymous `None.get`.
            throw new NoSuchElementException(
              s"No alternative labelled '$label' in union $shapeId"
            )
          )
        // We're using it to get a function that lets us project the `U` against `A`.
        // `U` is not necessarily an `A`, so this function returns an `Option`.
        val eqA = instance.compile(self)
        new AltEq[A] {
          def eqv(a: A, u: U): Boolean = altA.project.lift(u) match {
            case None => false // U is not an A.
            case Some(a2) =>
              eqA.eqv(a, a2) // U is an A, we delegate the comparison
          }
        }
      }
    }

    val altEqU = dispatch.compile(precompiler)
    new Eq[U] {
      def eqv(x: U, y: U): Boolean = altEqU.eqv(x, y)
    }
  }

  /**
    * Equality for a bijected schema: both sides are mapped back to `A` with
    * `bijection.from` and compared with the `Eq` of the underlying schema.
    */
  override def biject[A, B](
      schema: Schema[A],
      bijection: Bijection[A, B]
  ): Eq[B] = {
    val aEq: Eq[A] = self(schema)
    (x: B, y: B) => aEq.eqv(bijection.from(x), bijection.from(y))
  }

  /**
    * Equality for a refined schema: both sides are mapped back to `A` with
    * `refinement.from` and compared with the `Eq` of the underlying schema.
    */
  override def refine[A, B](
      schema: Schema[A],
      refinement: Refinement[A, B]
  ): Eq[B] = {
    val eqA: Eq[A] = self(schema)
    (x: B, y: B) => eqA.eqv(refinement.from(x), refinement.from(y))
  }

  /**
    * Equality for a suspended schema.
    *
    * The compiled instance is only accessed (through `.value`) when a
    * comparison is actually performed, rather than when this method is called.
    */
  override def lazily[A](suspend: Lazy[Schema[A]]): Eq[A] = {
    val eq = suspend.map(self(_))
    (x: A, y: A) => eq.value.eqv(x, y)
  }

  /**
    * Equality for an optional value.
    *
    * When [[LenientOptionalCollectionEquality]] yields an instance for the
    * wrapped schema it is used; otherwise the standard `Option` equality
    * built on the `Eq` of the wrapped schema applies.
    */
  override def option[A](schema: Schema[A]): Eq[Option[A]] =
    LenientOptionalCollectionEquality(schema)
      .getOrElse(Eq.catsKernelEqForOption(self(schema)))

  /**
    * Selects the `Eq` instance for a primitive tag.
    *
    * Floats and doubles use the dedicated `floatEq` / `doubleEq` instances
    * (presumably provided by `Smithy4sEqInstances` — confirm); every other
    * primitive uses the instance found in implicit scope.
    */
  def primitiveEq[P](primitive: Primitive[P]): Eq[P] = {
    primitive match {
      case Primitive.PShort      => Eq[Short]
      case Primitive.PInt        => Eq[Int]
      case Primitive.PFloat      => floatEq
      case Primitive.PLong       => Eq[Long]
      case Primitive.PDouble     => doubleEq
      case Primitive.PBigInt     => Eq[BigInt]
      case Primitive.PBigDecimal => Eq[BigDecimal]
      case Primitive.PBoolean    => Eq[Boolean]
      case Primitive.PString     => Eq[String]
      case Primitive.PUUID       => Eq[UUID]
      case Primitive.PByte       => Eq[Byte]
      case Primitive.PBlob       => Eq[Blob]
      case Primitive.PDocument   => Eq[Document]
      case Primitive.PTimestamp  => Eq[Timestamp]
    }
  }

  /** Shorthand for an `Eq` over an `Option`-wrapped value. */
  type EqOpt[A] = Eq[Option[A]]

  /**
    * A sub-visitor that provides lenient equality for Option-wrapped
    * collections and maps, where `None` and `Some(empty)` are considered
    * equivalent.
    *
    * Only `collection` and `map` are overridden here.
    */
  object LenientOptionalCollectionEquality
      extends SchemaVisitor.Optional[EqOpt] {

    /**
      * Lenient equality for an optional collection: two present collections
      * are compared with the regular collection `Eq`; an absent side equals a
      * present side only when the latter is empty according to `tag`.
      */
    override def collection[C[_], A](
        shapeId: ShapeId,
        hints: Hints,
        tag: CollectionTag[C],
        member: Schema[A]
    ): Option[Eq[Option[C[A]]]] = Some {
      val collectionEq = EqSchemaVisitor.collection(shapeId, hints, tag, member)
      new Eq[Option[C[A]]] {
        def eqv(x: Option[C[A]], y: Option[C[A]]): Boolean = (x, y) match {
          case (Some(left), Some(right)) => collectionEq.eqv(left, right)
          case (None, Some(right))       => tag.isEmpty(right)
          case (Some(left), None)        => tag.isEmpty(left)
          case (None, None)              => true
        }
      }
    }

    /**
      * Lenient equality for an optional map: two present maps are compared
      * with the regular map `Eq`; an absent side equals a present side only
      * when the latter is empty.
      */
    override def map[K, V](
        shapeId: ShapeId,
        hints: Hints,
        key: Schema[K],
        value: Schema[V]
    ): Option[Eq[Option[Map[K, V]]]] = Some {
      val mapEq = EqSchemaVisitor.map(shapeId, hints, key, value)
      new Eq[Option[Map[K, V]]] {
        def eqv(x: Option[Map[K, V]], y: Option[Map[K, V]]): Boolean =
          (x, y) match {
            case (Some(left), Some(right)) => mapEq.eqv(left, right)
            case (None, Some(right))       => right.isEmpty
            case (Some(left), None)        => left.isEmpty
            case (None, None)              => true
          }
      }
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy