All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.spotify.scio.testing.SCollectionMatchers.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2019 Spotify AB.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.spotify.scio.testing

import java.lang.{Iterable => JIterable}
import java.util.{Map => JMap}
import com.spotify.scio.coders.Coder
import com.spotify.scio.values.SCollection
import com.twitter.chill.Externalizer
import org.apache.beam.sdk.testing.PAssert
import org.apache.beam.sdk.testing.PAssert.{IterableAssert, SingletonAssert}
import org.apache.beam.sdk.transforms.SerializableFunction
import org.apache.beam.sdk.transforms.windowing.BoundedWindow
import org.apache.beam.sdk.util.{CoderUtils, SerializableSupplier}
import org.scalatest.matchers.{MatchResult, Matcher}
import org.{hamcrest => h}
import org.hamcrest.Matchers
import org.hamcrest.MatcherAssert.assertThat

import scala.jdk.CollectionConverters._
import com.twitter.chill.ClosureCleaner
import cats.kernel.Eq
import org.apache.beam.sdk.testing.SerializableMatchers
import com.spotify.scio.coders.CoderMaterializer
import com.spotify.scio.ScioContext
import org.apache.beam.sdk.testing.SerializableMatcher

/**
 * Test-only wrapper around a value of type `T` whose equality against another wrapper is
 * delegated to the implicit [[cats.kernel.Eq]] instance instead of `T`'s own `equals`.
 *
 * NOTE(review): only `equals` is overridden here, `hashCode` is not; for an `Eq` that is coarser
 * than `==` the two can disagree. Comparison against a non-wrapper falls back to `get == other`,
 * so `equals` is not symmetric.
 */
final private case class TestWrapper[T: Eq](get: T) {

  /** Renders the wrapped value through `Pretty.printer`. */
  override def toString: String = {
    val rendered = Pretty.printer.apply(get)
    rendered.render
  }

  override def equals(that: Any): Boolean = that match {
    // another wrapper: compare the payloads with the Eq instance (T is erased, hence @unchecked)
    case TestWrapper(inner: T @unchecked) => implicitly[Eq[T]].eqv(get, inner)
    // anything else: universal equality against the raw wrapped value
    case raw => get == raw
  }
}

private object TestWrapper {

  /** Wrap every element of an SCollection in a [[TestWrapper]]. */
  def wrap[T: Coder: Eq](coll: SCollection[T]): SCollection[TestWrapper[T]] =
    coll.map(element => TestWrapper(element))

  /** Wrap every element of a Java iterable in a [[TestWrapper]]. */
  def wrap[T: Eq](coll: JIterable[T]): JIterable[TestWrapper[T]] = {
    val wrapped = coll.asScala.map(element => TestWrapper(element))
    wrapped.asJava
  }

  /** Coder for wrappers, derived from `Coder[T]` by wrapping on decode and unwrapping on encode. */
  implicit def testWrapperCoder[T: Coder: Eq]: Coder[TestWrapper[T]] =
    Coder.xmap(Coder[T])(value => TestWrapper(value), wrapper => wrapper.get)
}

private object ScioMatchers {

  /**
   * Create a supplier of `B` (typically a hamcrest matcher) that carries `a` as bytes encoded
   * with its `Coder[A]`: `a` is encoded once here and decoded on each call of the supplier,
   * right before `builder` is applied.
   */
  private def supplierFromCoder[A: Coder, B](@transient a: A, @transient context: ScioContext)(
    builder: A => B
  ): SerializableSupplier[B] = {
    val beamCoder = CoderMaterializer.beam(context, Coder[A])
    val payload = CoderUtils.encodeToByteArray(beamCoder, a)
    () => {
      val decoded = CoderUtils.decodeFromByteArray(beamCoder, payload)
      builder(decoded)
    }
  }

  /**
   * This is equivalent to [[org.apache.beam.sdk.testing.PAssert#containsInAnyOrder()]] but has a
   * nicer message in case of failure.
   */
  def containsInAnyOrder[T: Coder](
    ts: Seq[T],
    context: ScioContext
  ): h.Matcher[JIterable[T]] =
    SerializableMatchers.fromSupplier {
      supplierFromCoder(ts, context) { expected =>
        // one expected element per line
        val listing = expected.mkString("\n\t\t", "\n\t\t", "\n")
        val description = s"Expected: iterable with items [$listing]"
        val delegate = Matchers
          .containsInAnyOrder(expected: _*)
          .asInstanceOf[h.Matcher[JIterable[T]]]

        // matching and mismatch reporting are delegated; only the description is customised
        new h.BaseMatcher[JIterable[T]] {
          override def matches(o: AnyRef): Boolean =
            delegate.matches(o)

          override def describeTo(d: h.Description): Unit =
            d.appendText(description)

          override def describeMismatch(i: AnyRef, d: h.Description): Unit =
            delegate.describeMismatch(i, d)
        }
      }
    }

  /**
   * Adapt a Scala function over a Java iterable into a Beam `SerializableFunction` that returns
   * `Void`.
   */
  def makeFn[T](
    f: JIterable[T] => Unit
  ): SerializableFunction[JIterable[T], Void] =
    new SerializableFunction[JIterable[T], Void] {
      // The function is held through an Externalizer so that serialization is delegated to
      // Kryo; this avoids serialization issues in tests when the closure captures a
      // non-serializable object.
      private[this] val externalized = Externalizer(f)

      override def apply(input: JIterable[T]): Void = {
        val check = externalized.get
        check(input)
        null
      }
    }

  /** Adapt a Scala function over a single value into a Beam `SerializableFunction` returning `Void`. */
  def makeFnSingle[T](f: T => Unit): SerializableFunction[T, Void] =
    new SerializableFunction[T, Void] {
      // The function is held through an Externalizer so that serialization is delegated to
      // Kryo; this avoids serialization issues in tests when the closure captures a
      // non-serializable object.
      private[this] val externalized = Externalizer(f)

      override def apply(input: T): Void = {
        val check = externalized.get
        check(input)
        null
      }
    }

  /** Check that the input, with every element wrapped in a [[TestWrapper]], satisfies `mm`. */
  def assertThatFn[T: Eq](
    mm: h.Matcher[JIterable[TestWrapper[T]]]
  ): SerializableFunction[JIterable[T], Void] =
    makeFn[T] { actual =>
      val wrapped = TestWrapper.wrap(actual)
      assertThat(wrapped, mm)
    }

  /** Check that the input, with every element wrapped in a [[TestWrapper]], does not satisfy `mm`. */
  def assertThatNotFn[T: Eq](
    mm: h.Matcher[JIterable[TestWrapper[T]]]
  ): SerializableFunction[JIterable[T], Void] =
    makeFn[T] { actual =>
      val wrapped = TestWrapper.wrap(actual)
      assertThat(wrapped, Matchers.not(mm))
    }

  /** Check, via `Predef.assert`, that predicate `p` holds for the wrapped input elements. */
  def assert[T: Eq](
    p: Iterable[TestWrapper[T]] => Boolean
  ): SerializableFunction[JIterable[T], Void] =
    makeFn[T] { actual =>
      val wrapped = TestWrapper.wrap(actual).asScala
      Predef.assert(p(wrapped))
    }

  /** Check, via `Predef.assert`, that predicate `p` holds for the wrapped input value. */
  def assertSingle[T: Eq](p: TestWrapper[T] => Boolean): SerializableFunction[T, Void] =
    makeFnSingle[T] { actual =>
      val wrapped = TestWrapper(actual)
      Predef.assert(p(wrapped))
    }

  /** Check, via `Predef.assert`, that predicate `p` does not hold for the wrapped input elements. */
  def assertNot[T: Eq](
    p: Iterable[TestWrapper[T]] => Boolean
  ): SerializableFunction[JIterable[T], Void] =
    makeFn[T] { actual =>
      val wrapped = TestWrapper.wrap(actual).asScala
      Predef.assert(!p(wrapped))
    }

  /** Check, via `Predef.assert`, that predicate `p` does not hold for the wrapped input value. */
  def assertNotSingle[T: Eq](p: TestWrapper[T] => Boolean): SerializableFunction[T, Void] =
    makeFnSingle[T] { actual =>
      val wrapped = TestWrapper(actual)
      Predef.assert(!p(wrapped))
    }

  /**
   * Check that a single value, wrapped in a [[TestWrapper]], equals the wrapped expected value
   * `t`. The expected value is carried in coder-encoded form by the matcher supplier.
   */
  def isEqualTo[T: Eq: Coder](context: ScioContext, t: T): SerializableFunction[T, Void] = {
    val mm: SerializableMatcher[Any] =
      SerializableMatchers.fromSupplier {
        supplierFromCoder(t, context) { expected =>
          Matchers.equalTo[Any](TestWrapper(expected))
        }
      }
    makeFnSingle[T] { actual =>
      assertThat(TestWrapper(actual), mm)
    }
  }

  /**
   * Check that a single value, wrapped in a [[TestWrapper]], does not equal the wrapped expected
   * value `t`. The expected value is carried in coder-encoded form by the matcher supplier.
   */
  def notEqualTo[T: Eq: Coder](context: ScioContext, t: T): SerializableFunction[T, Void] = {
    val mm: SerializableMatcher[Any] =
      SerializableMatchers.fromSupplier {
        supplierFromCoder(t, context) { expected =>
          Matchers.not(Matchers.equalTo[Any](TestWrapper(expected)))
        }
      }

    makeFnSingle[T] { actual =>
      assertThat(TestWrapper(actual), mm)
    }
  }

  /**
   * Serializable matcher that accepts an iterable containing `t`; the expected item is carried
   * in coder-encoded form by the matcher supplier.
   */
  def hasItem[T: Coder](t: T, context: ScioContext): h.Matcher[JIterable[T]] =
    SerializableMatchers.fromSupplier {
      supplierFromCoder(t, context) { item =>
        Matchers
          .hasItem(item)
          .asInstanceOf[h.Matcher[JIterable[T]]]
      }
    }
}

/**
 * Trait with ScalaTest [[org.scalatest.matchers.Matcher Matcher]] s for
 * [[com.spotify.scio.values.SCollection SCollection]] s.
 */
trait SCollectionMatchers extends EqInstances {

  import ScioMatchers.makeFn

  /**
   * A matcher that can be built from a transformation of its underlying assertion:
   * `matcher(builder)` yields the ScalaTest matcher obtained after passing the assertion
   * (`From`) through `builder`.
   */
  sealed trait MatcherBuilder[T] {
    this: Matcher[T] =>

    /** Assertion type handed to the builder. */
    type From

    /** Assertion type returned by the builder. */
    type To >: From

    /** Transformation applied to the assertion. */
    type AssertBuilder = From => To

    /** Matcher whose assertion is transformed by `builder`. */
    def matcher(builder: AssertBuilder): Matcher[T]

    /** Matcher whose assertion is used unchanged. */
    def matcher: Matcher[T] = matcher((assertion: From) => assertion)
  }

  /** [[MatcherBuilder]] whose assertion is a `PAssert` `IterableAssert` over elements of type `B`. */
  sealed trait IterableMatcher[T, B] extends MatcherBuilder[T] with Matcher[T] {
    type From = IterableAssert[B]
    type To = IterableAssert[B]

    // `left` is not an AssertBuilder, so this applies the parameterless `matcher` to it
    override def apply(left: T): MatchResult =
      matcher(left)
  }

  /** [[MatcherBuilder]] whose assertion is a `PAssert` `SingletonAssert` over a value of type `B`. */
  sealed trait SingleMatcher[T, B] extends MatcherBuilder[T] with Matcher[T] {
    type From = SingletonAssert[B]
    type To = SingletonAssert[B]

    // `left` is not an AssertBuilder, so this applies the parameterless `matcher` to it
    override def apply(left: T): MatchResult =
      matcher(left)
  }

  /*
  Wrapper for PAssert statements. PAssert neither performs assertions nor throws exceptions until
  sc.run() is called, so the MatchResult returned here must always be true for "a should $Matcher"
  and always false for "a shouldNot $Matcher". Whether the call is negated is detected from the
  ScalaTest frames on the current stack trace, and a different assertion is registered for the
  positive (shouldFn) and the negative (shouldNotFn) case; only one of the two is invoked.
   */
  private def m(shouldFn: () => Any, shouldNotFn: () => Any): MatchResult = {
    val negated = Thread
      .currentThread()
      .getStackTrace
      .exists { frame =>
        val className = frame.getClassName
        className.startsWith("org.scalatest.") &&
        (className.contains("NotWord") || frame.getMethodName == "shouldNot")
      }
    val register = if (negated) shouldNotFn else shouldFn
    register()
    MatchResult(!negated, "", "")
  }

  // Due to  https://github.com/GoogleCloudPlatform/DataflowJavaSDK/issues/434
  // Encode and decode each element with the collection's coder to keep it consistent with the
  // values on the expected side
  private def serDeCycle[T: Coder](scollection: SCollection[T]): SCollection[T] = {
    val elementCoder = scollection.internal.getCoder
    scollection.map { element =>
      val bytes = CoderUtils.encodeToByteArray(elementCoder, element)
      CoderUtils.decodeFromByteArray(elementCoder, bytes)
    }
  }

  /**
   * Restrict an SCollection assertion to the specified window, running the checker only on the
   * on-time pane for each key. Works for both single-value and iterable matchers.
   */
  def inOnTimePane[T](window: BoundedWindow)(matcher: MatcherBuilder[T]): Matcher[T] =
    matcher match {
      case single: SingleMatcher[T, _] =>
        single.matcher(assertion => assertion.inOnTimePane(window))
      case iterable: IterableMatcher[T, _] =>
        iterable.matcher(assertion => assertion.inOnTimePane(window))
    }

  /** Restrict an iterable SCollection assertion to the specified window. */
  def inWindow[T, B](
    window: BoundedWindow
  )(matcher: IterableMatcher[T, B]): Matcher[T] =
    matcher.matcher((assertion: IterableAssert[B]) => assertion.inWindow(window))

  /**
   * Restrict an iterable SCollection assertion to the specified window, across all panes that
   * were not produced by the arrival of late data.
   */
  def inCombinedNonLatePanes[T, B](
    window: BoundedWindow
  )(matcher: IterableMatcher[T, B]): Matcher[T] =
    matcher.matcher((assertion: IterableAssert[B]) => assertion.inCombinedNonLatePanes(window))

  /**
   * Restrict an SCollection assertion to the specified window, running the checker only on the
   * final pane for each key. Works for both single-value and iterable matchers.
   *
   * Type parameter `B` is not referenced by this method; it is kept so existing call sites that
   * pass it explicitly still compile.
   */
  def inFinalPane[T, B](
    window: BoundedWindow
  )(matcher: MatcherBuilder[T]): Matcher[T] =
    matcher match {
      case single: SingleMatcher[T, _] =>
        single.matcher(assertion => assertion.inFinalPane(window))
      case iterable: IterableMatcher[T, _] =>
        iterable.matcher(assertion => assertion.inFinalPane(window))
    }

  /**
   * Restrict an SCollection assertion to the specified window, running the checker only on the
   * late pane for each key. Works for both single-value and iterable matchers.
   *
   * Type parameter `B` is not referenced by this method; it is kept so existing call sites that
   * pass it explicitly still compile.
   */
  def inLatePane[T, B](
    window: BoundedWindow
  )(matcher: MatcherBuilder[T]): Matcher[T] =
    matcher match {
      case single: SingleMatcher[T, _] =>
        single.matcher(assertion => assertion.inLatePane(window))
      case iterable: IterableMatcher[T, _] =>
        iterable.matcher(assertion => assertion.inLatePane(window))
    }

  /**
   * Restrict an SCollection assertion to the specified window, running the checker only on the
   * early pane for each key. Works for both single-value and iterable matchers.
   */
  def inEarlyPane[T](
    window: BoundedWindow
  )(matcher: MatcherBuilder[T]): Matcher[T] =
    matcher match {
      case single: SingleMatcher[T, _] =>
        single.matcher(assertion => assertion.inEarlyPane(window))
      case iterable: IterableMatcher[T, _] =>
        iterable.matcher(assertion => assertion.inEarlyPane(window))
    }

  /**
   * Restrict a single-value SCollection assertion to the specified window. The assertion expects
   * outputs to be produced to the provided window exactly once.
   */
  def inOnlyPane[T, B](
    window: BoundedWindow
  )(matcher: SingleMatcher[T, B]): Matcher[T] =
    matcher.matcher((assertion: SingletonAssert[B]) => assertion.inOnlyPane(window))

  /** Restrict an iterable SCollection assertion to the early-timing panes of the global window. */
  def inEarlyGlobalWindowPanes[T, B](
    matcher: IterableMatcher[T, B]
  ): Matcher[T] =
    matcher.matcher((assertion: IterableAssert[B]) => assertion.inEarlyGlobalWindowPanes)

  /**
   * Assert that the SCollection in question contains the provided elements, in any order.
   * Elements are compared through [[TestWrapper]], i.e. with the implicit `Eq[T]`.
   */
  def containInAnyOrder[T: Coder: Eq](
    value: Iterable[T]
  ): IterableMatcher[SCollection[T], T] =
    new IterableMatcher[SCollection[T], T] {
      override def matcher(builder: AssertBuilder): Matcher[SCollection[T]] =
        new Matcher[SCollection[T]] {
          override def apply(left: SCollection[T]): MatchResult = {
            // defeat closure
            val expected = Externalizer(value.toSeq.map(x => TestWrapper(x)))
            val inAnyOrder = ScioMatchers.containsInAnyOrder(expected.get, left.context)
            val shouldNotCheck = ScioMatchers.assertThatNotFn[T](inAnyOrder)
            val shouldCheck = ScioMatchers.assertThatFn[T](inAnyOrder)
            val pAssert = builder(PAssert.that(serDeCycle(left).internal))
            m(
              () => pAssert.satisfies(shouldCheck),
              () => pAssert.satisfies(shouldNotCheck)
            )
          }
        }
    }

  /**
   * Assert that the SCollection in question contains a single provided element.
   *
   * The actual element is compared with `value` through [[TestWrapper]], i.e. with the implicit
   * `Eq[T]`. For the negated form (`shouldNot`) the registered assertion is that the single
   * element is not equal to `value`.
   *
   * @param value
   *   the expected single element
   */
  def containSingleValue[T: Coder: Eq](value: T): SingleMatcher[SCollection[T], T] =
    new SingleMatcher[SCollection[T], T] {
      override def matcher(builder: AssertBuilder): Matcher[SCollection[T]] =
        new Matcher[SCollection[T]] {
          override def apply(left: SCollection[T]): MatchResult = {
            val assertion = builder(PAssert.thatSingleton(serDeCycle(left).internal))
            // only one of the two thunks is invoked, depending on should / shouldNot
            m(
              () => assertion.satisfies(ScioMatchers.isEqualTo(left.context, value)),
              () => assertion.satisfies(ScioMatchers.notEqualTo(left.context, value))
            )
          }
        }
    }

  /**
   * Assert that the SCollection in question contains the provided element, without making any
   * assumption about the other elements in the collection. Elements are compared through
   * [[TestWrapper]], i.e. with the implicit `Eq[T]`.
   */
  def containValue[T: Coder: Eq](value: T): IterableMatcher[SCollection[T], T] =
    new IterableMatcher[SCollection[T], T] {
      override def matcher(builder: AssertBuilder): Matcher[SCollection[T]] =
        new Matcher[SCollection[T]] {
          override def apply(left: SCollection[T]): MatchResult = {
            // defeat closure
            val expected = Externalizer(TestWrapper[T](value))
            val containsItem = ScioMatchers.hasItem(expected.get, left.context)
            val shouldCheck = ScioMatchers.assertThatFn[T](containsItem)
            val shouldNotCheck = ScioMatchers.assertThatNotFn[T](containsItem)
            val pAssert = builder(PAssert.that(serDeCycle(left).internal))
            m(
              () => pAssert.satisfies(shouldCheck),
              () => pAssert.satisfies(shouldNotCheck)
            )
          }
        }
    }

  /**
   * Assert that the SCollection in question is empty. In the negated form (`shouldNot`) the
   * registered check fails with "SCollection is empty" when the collection has no element.
   */
  val beEmpty: IterableMatcher[SCollection[_], Any] =
    new IterableMatcher[SCollection[_], Any] {
      override def matcher(builder: AssertBuilder): Matcher[SCollection[_]] =
        new Matcher[SCollection[_]] {
          override def apply(left: SCollection[_]): MatchResult = {
            val pAssert = PAssert.that(left.asInstanceOf[SCollection[Any]].internal)
            m(
              () => builder(pAssert).empty(),
              () =>
                builder(pAssert).satisfies(
                  makeFn(in => assert(in.iterator().hasNext, "SCollection is empty"))
                )
            )
          }
        }
    }

  /**
   * Assert that the SCollection in question has the provided size. In the negated form
   * (`shouldNot`) the registered check requires the size to differ from `size`.
   */
  def haveSize(size: Int): IterableMatcher[SCollection[_], Any] =
    new IterableMatcher[SCollection[_], Any] {
      override def matcher(builder: AssertBuilder): Matcher[SCollection[_]] =
        new Matcher[SCollection[_]] {
          override def apply(left: SCollection[_]): MatchResult = {
            // local copy of the expected size; presumably so the checks below capture only
            // this Int and not the enclosing matcher - confirm before removing
            val s = size
            val sameSize = makeFn[Any] { elements =>
              val inSize = elements.asScala.size
              assert(inSize == s, s"SCollection expected size: $s, actual: $inSize")
            }
            val differentSize = makeFn[Any] { elements =>
              val inSize = elements.asScala.size
              assert(inSize != s, s"SCollection expected size: not $s, actual: $inSize")
            }
            val pAssert = PAssert.that(left.asInstanceOf[SCollection[Any]].internal)
            m(
              () => builder(pAssert).satisfies(sameSize),
              () => builder(pAssert).satisfies(differentSize)
            )
          }
        }
    }

  /**
   * Assert that the key-value SCollection in question, viewed as a map, is equal to the provided
   * map. In the negated form (`shouldNot`) the registered assertion is that it is not equal.
   */
  def equalMapOf[K: Coder, V: Coder](
    value: Map[K, V]
  ): SingleMatcher[SCollection[(K, V)], JMap[K, V]] =
    new SingleMatcher[SCollection[(K, V)], JMap[K, V]] {
      override def matcher(builder: AssertBuilder): Matcher[SCollection[(K, V)]] =
        new Matcher[SCollection[(K, V)]] {
          override def apply(left: SCollection[(K, V)]): MatchResult = {
            val asKvs = serDeCycle(left).toKV
            val mapAssert = builder(PAssert.thatMap(asKvs.internal))
            m(
              () => mapAssert.isEqualTo(value.asJava),
              () => mapAssert.notEqualTo(value.asJava)
            )
          }
        }
    }

  // TODO: investigate why multi-map doesn't work

  /**
   * Assert that the SCollection in question satisfies the provided function, which receives the
   * elements as an `Iterable`. In the negated form (`shouldNot`) the function must return false.
   */
  def satisfy[T: Coder: Eq](
    predicate: Iterable[T] => Boolean
  ): IterableMatcher[SCollection[T], T] =
    new IterableMatcher[SCollection[T], T] {
      override def matcher(builder: AssertBuilder): Matcher[SCollection[T]] =
        new Matcher[SCollection[T]] {
          override def apply(left: SCollection[T]): MatchResult = {
            val cleaned = ClosureCleaner.clean(predicate)
            // the checks operate on wrapped elements; unwrap them before calling the predicate
            val onWrapped: Iterable[TestWrapper[T]] => Boolean =
              wrapped => cleaned(wrapped.map(_.get))
            val shouldCheck = ScioMatchers.assert(onWrapped)
            val shouldNotCheck = ScioMatchers.assertNot(onWrapped)
            val pAssert = builder(PAssert.that(serDeCycle(left).internal))
            m(
              () => pAssert.satisfies(shouldCheck),
              () => pAssert.satisfies(shouldNotCheck)
            )
          }
        }
    }

  /**
   * Assert that the SCollection in question contains a single element which satisfies the
   * provided function. In the negated form (`shouldNot`) the function must return false for it.
   */
  def satisfySingleValue[T: Coder: Eq](
    predicate: T => Boolean
  ): SingleMatcher[SCollection[T], T] =
    new SingleMatcher[SCollection[T], T] {
      override def matcher(builder: AssertBuilder): Matcher[SCollection[T]] =
        new Matcher[SCollection[T]] {
          override def apply(left: SCollection[T]): MatchResult = {
            val cleaned = ClosureCleaner.clean(predicate)
            // the checks operate on the wrapped element; unwrap it before calling the predicate
            val onWrapped: TestWrapper[T] => Boolean =
              wrapped => cleaned(wrapped.get)
            val shouldCheck = ScioMatchers.assertSingle(onWrapped)
            val shouldNotCheck = ScioMatchers.assertNotSingle(onWrapped)
            val pAssert = builder(PAssert.thatSingleton(serDeCycle(left).internal))
            m(
              () => pAssert.satisfies(shouldCheck),
              () => pAssert.satisfies(shouldNotCheck)
            )
          }
        }
    }

  /**
   * Assert that all elements of the SCollection in question satisfy the provided function.
   * Implemented on top of [[satisfy]] with `forall`.
   */
  def forAll[T: Coder: Eq](
    predicate: T => Boolean
  ): IterableMatcher[SCollection[T], T] = {
    val cleaned = ClosureCleaner.clean(predicate)
    satisfy(elements => elements.forall(cleaned))
  }

  /**
   * Assert that at least one element of the SCollection in question satisfies the provided
   * function. Implemented on top of [[satisfy]] with `exists`.
   */
  def exist[T: Coder: Eq](
    predicate: T => Boolean
  ): IterableMatcher[SCollection[T], T] = {
    val cleaned = ClosureCleaner.clean(predicate)
    satisfy(elements => elements.exists(cleaned))
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy