All downloads are free. Search and download functionality uses the official Maven repository.

com.spotify.scio.testing.CoderAssertions.scala Maven / Gradle / Ivy

/*
 * Copyright 2019 Spotify AB.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.spotify.scio.testing

import com.spotify.scio.coders._
import org.apache.beam.sdk.coders.Coder.NonDeterministicException
import org.apache.beam.sdk.coders.{Coder => BCoder}
import org.apache.beam.sdk.options.{PipelineOptions, PipelineOptionsFactory}
import org.apache.beam.sdk.testing.CoderProperties
import org.apache.beam.sdk.util.{CoderUtils, SerializableUtils}
import org.scalactic.Equality
import org.scalatest.Assertion
import org.scalatest.matchers.should.Matchers._
import org.typelevel.scalaccompat.annotation.nowarn

import scala.reflect.ClassTag

/**
 * ScalaTest-based assertion DSL for checking scio [[Coder]]s and the Beam coders they
 * materialize to.
 *
 * Assertions are functions from an assertion context to a ScalaTest `Assertion`. A context is
 * created either from a value (`value coderShould roundtrip()`) or from a coder alone
 * (`coder coderShould beSerializable()`), and further assertions can be chained with `and`.
 */
object CoderAssertions {
  // Pipeline options used to materialize coders when the test does not supply its own.
  private lazy val DefaultPipelineOptions = PipelineOptionsFactory.create()

  /** An assertion over a context whose value type is statically known to be `T`. */
  type CoderAssertion[T] = AssertionContext[T] => Assertion

  /** An assertion that applies to any context, regardless of its value type. */
  type CoderAssertionBase = AssertionContextBase => Assertion

  /** Carries the pipeline options passed to `value.coder(WithOptions(opts)).should(...)`. */
  case class WithOptions(opts: PipelineOptions)

  /** Second half of the `value.coder(WithOptions(opts)).should(assertion)` syntax. */
  trait CustomOptionsSyntax[T] {
    def should(coderAssertion: CoderAssertion[T]): AssertionContext[T]
  }

  /** Starts an assertion chain from a value of type `T`. */
  implicit class ValueShouldSyntax[T](value: T) {

    /** Runs `coderAssertion` against `value` using the implicit coder for `T`. */
    def coderShould(
      coderAssertion: CoderAssertion[T]
    )(implicit c: Coder[T]): AssertionContext[T] =
      AssertionContext[T](Some(value), c).and(coderAssertion)

    /** Runs `coderAssertion` against `value` using `Coder.kryo[T]` instead of an implicit coder. */
    def kryoCoderShould(
      coderAssertion: CoderAssertion[T]
    )(implicit ct: ClassTag[T]): AssertionContext[T] =
      AssertionContext[T](Some(value), Coder.kryo[T]).and(coderAssertion)

    /**
     * Like `coderShould`, but the context carries the pipeline options from `optionsTerm`
     * rather than the default ones.
     */
    def coder(
      optionsTerm: WithOptions
    )(implicit c: Coder[T]): CustomOptionsSyntax[T] = new CustomOptionsSyntax[T] {
      override def should(coderAssertion: CoderAssertion[T]): AssertionContext[T] =
        AssertionContext[T](Some(value), c, opts = optionsTerm.opts).and(coderAssertion)
    }
  }

  /**
   * Starts an assertion chain from a coder alone. The resulting context has no value, so only
   * assertions that do not need one (e.g. `beSerializable`, `beDeterministic`) are meaningful.
   */
  implicit class CoderShouldSyntax[T](c: Coder[T]) {
    def coderShould(
      coderAssertion: CoderAssertion[T]
    ): AssertionContext[T] =
      AssertionContext[T](None, c).and(coderAssertion)
  }

  /**
   * State threaded through a chain of coder assertions.
   *
   * @param actualValue
   *   the value under test, or `None` when the chain was started from a coder only
   * @param coder
   *   the scio coder under test
   * @param lastAssertion
   *   result of the most recently evaluated assertion, if any
   * @param opts
   *   pipeline options used when materializing the Beam coder
   */
  case class AssertionContext[T](
    actualValue: Option[T],
    coder: Coder[T],
    lastAssertion: Option[Assertion] = None,
    opts: PipelineOptions = DefaultPipelineOptions
  ) extends AssertionContextBase {
    override type ValType = T

    /** Evaluates `coderAssertion` against this context and records its result. */
    def and(
      coderAssertion: CoderAssertion[T]
    ): AssertionContext[T] = copy(lastAssertion = Some(coderAssertion(this)))
  }

  /** Type-erased view of an [[AssertionContext]]; the value type is exposed as `ValType`. */
  trait AssertionContextBase {
    type ValType
    val actualValue: Option[ValType]
    val coder: Coder[ValType]
    val lastAssertion: Option[Assertion]
    val opts: PipelineOptions
    // Built on first access from `coder` and `opts` via CoderMaterializer.
    lazy val beamCoder: BCoder[ValType] = CoderMaterializer.beamWithDefault(coder, opts)
  }

  /**
   * Returns the value under test, or fails the test with an explicit message when the context
   * was created from a coder only. Replaces a bare `Option.get`, whose `None.get` error gave no
   * hint about the misuse.
   */
  private def requireValue(ctx: AssertionContextBase): ctx.ValType =
    ctx.actualValue.getOrElse(
      fail(
        "This coder assertion needs a value under test, but the assertion context has none. " +
          "Start the chain from a value (`value coderShould ...`) " +
          "rather than from a coder (`coder coderShould ...`)."
      )
    )

  /** Encodes the value with the Beam coder, decodes it back and expects an equal value. */
  def roundtrip[T: Equality](): CoderAssertion[T] = ctx =>
    checkRoundtripWithCoder[T](ctx.beamCoder, requireValue(ctx), None)

  /**
   * Same as [[roundtrip]], and additionally expects the encoded bytes to equal `expectedBytes`.
   * A `null` `expectedBytes` skips the byte comparison.
   */
  def roundtripToBytes[T: Equality](expectedBytes: Array[Byte]): CoderAssertion[T] = ctx =>
    checkRoundtripWithCoder[T](ctx.beamCoder, requireValue(ctx), Option(expectedBytes))

  /** Expects the scio coder in the context to equal `expectedCoder`. */
  def haveCoderInstance(expectedCoder: Coder[_]): CoderAssertionBase = ctx =>
    ctx.coder should ===(expectedCoder)

  /** Expects the coder to differ from `Coder.kryo[T]`, then checks the value round-trips. */
  def notFallback[T: ClassTag: Equality](): CoderAssertion[T] = ctx => {
    ctx.coder should !==(Coder.kryo[T])
    checkRoundtripWithCoder[T](ctx.beamCoder, requireValue(ctx), None)
  }

  /** Expects the coder to equal `Coder.kryo[T]`, then checks the value round-trips. */
  def fallback[T: ClassTag: Equality](): CoderAssertion[T] = ctx => {
    ctx.coder should ===(Coder.kryo[T])
    checkRoundtripWithCoder[T](ctx.beamCoder, requireValue(ctx), None)
  }

  /** Expects the Beam coder's `consistentWithEquals()` to be `true`. */
  def beConsistentWithEquals(): CoderAssertionBase = ctx =>
    ctx.beamCoder.consistentWithEquals() shouldBe true

  /** Expects the Beam coder's `consistentWithEquals()` to be `false`. */
  def beNotConsistentWithEquals(): CoderAssertionBase = ctx =>
    ctx.beamCoder.consistentWithEquals() shouldBe false

  /** Expects the Beam coder's `verifyDeterministic()` to complete without throwing. */
  def beDeterministic(): CoderAssertionBase = ctx =>
    noException should be thrownBy ctx.beamCoder.verifyDeterministic()

  /** Expects the Beam coder's `verifyDeterministic()` to throw `NonDeterministicException`. */
  def beNonDeterministic(): CoderAssertionBase = ctx =>
    a[NonDeterministicException] should be thrownBy ctx.beamCoder.verifyDeterministic()

  /** Expects `SerializableUtils.ensureSerializable` to accept the Beam coder. */
  def beSerializable(): CoderAssertionBase = ctx =>
    noException should be thrownBy SerializableUtils.ensureSerializable(ctx.beamCoder)

  /** Shorthand for running [[beSerializable]] on the implicit coder for `A`. */
  def coderIsSerializable[A](implicit c: Coder[A]): Assertion =
    beSerializable()(AssertionContext[A](None, c))

  /** Expects the scio coder to be an instance of `ExpectedCoder`. */
  def beOfType[ExpectedCoder: ClassTag]: CoderAssertionBase = ctx =>
    ctx.coder shouldBe a[ExpectedCoder]

  /**
   * Asserts that `beamCoder` is a `MaterializedCoder` and returns the coder it wraps. The cast
   * is guarded by the assertion on the preceding line.
   */
  private def materializedInner(beamCoder: BCoder[_]): BCoder[_] = {
    beamCoder shouldBe a[MaterializedCoder[_]]
    beamCoder.asInstanceOf[MaterializedCoder[_]].bcoder
  }

  /** Expects the Beam coder to be a `MaterializedCoder` wrapping an `ExpectedBeamCoder`. */
  def materializeTo[ExpectedBeamCoder: ClassTag]: CoderAssertionBase =
    ctx => materializedInner(ctx.beamCoder) shouldBe a[ExpectedBeamCoder]

  /**
   * Expects the Beam coder to be a `MaterializedCoder` wrapping a `TransformCoder` whose own
   * underlying coder is an `ExpectedBeamCoder`.
   */
  def materializeToTransformOf[ExpectedBeamCoder: ClassTag]: CoderAssertionBase =
    ctx => {
      val wrapped = materializedInner(ctx.beamCoder)
      wrapped shouldBe a[TransformCoder[_, _]]
      wrapped.asInstanceOf[TransformCoder[_, _]].bcoder shouldBe a[ExpectedBeamCoder]
    }

  /** Checks that Beam's registerByteSizeObserver() and encode() are consistent. */
  def bytesCountTested[T <: Object: ClassTag](): CoderAssertion[T] =
    ctx => {
      val value: T = requireValue(ctx)
      val arr = Array[T](value)
      noException should be thrownBy CoderProperties.testByteCount(
        ctx.beamCoder,
        BCoder.Context.OUTER: @nowarn("cat=deprecation"),
        arr
      )
    }

  /**
   * Verifies that for the given coder and values, the structural values are equal if and only if
   * the encoded bytes are equal. Verifies for Outer and Nested contexts. The value under test is
   * compared against itself.
   */
  def structuralValueConsistentWithEquals(): CoderAssertionBase = ctx => {
    val value = requireValue(ctx)
    noException should be thrownBy CoderProperties.structuralValueConsistentWithEquals(
      ctx.beamCoder,
      value,
      value
    )
  }

  /**
   * Shared body of the `beFullyCompliant*` assertions. Runs, in order: structural-value
   * consistency, serializability, `equalsCheck`, byte-count consistency and `determinismCheck`.
   * The result of the last check is the result of the whole assertion.
   */
  private def compliance[T <: Object: ClassTag](
    equalsCheck: CoderAssertionBase,
    determinismCheck: CoderAssertionBase
  ): CoderAssertion[T] = ctx => {
    structuralValueConsistentWithEquals()(ctx)
    beSerializable()(ctx)
    equalsCheck(ctx)
    bytesCountTested[T]().apply(ctx)
    determinismCheck(ctx)
  }

  /** Passes all checks on Beam coder */
  def beFullyCompliant[T <: Object: ClassTag](): CoderAssertion[T] =
    compliance[T](beConsistentWithEquals(), beDeterministic())

  /** Same checks as [[beFullyCompliant]], but expects the coder to be non-deterministic. */
  def beFullyCompliantNonDeterministic[T <: Object: ClassTag](): CoderAssertion[T] =
    compliance[T](beConsistentWithEquals(), beNonDeterministic())

  /**
   * Same checks as [[beFullyCompliant]], but expects `consistentWithEquals()` to be `false`.
   */
  def beFullyCompliantNotConsistentWithEquals[T <: Object: ClassTag](): CoderAssertion[T] =
    compliance[T](beNotConsistentWithEquals(), beDeterministic())

  /**
   * Encodes `actualValue` with `beamCoder`, optionally compares the encoded bytes with
   * `expectedBytes`, then decodes and expects the result to equal `actualValue`.
   *
   * @param expectedBytes
   *   expected encoding; `None` skips the byte comparison
   */
  private def checkRoundtripWithCoder[T: Equality](
    beamCoder: BCoder[T],
    actualValue: T,
    expectedBytes: Option[Array[Byte]] = None
  ): Assertion = {
    val bytes = CoderUtils.encodeToByteArray(beamCoder, actualValue)
    expectedBytes.foreach(expected => bytes should ===(expected))
    val result = CoderUtils.decodeFromByteArray(beamCoder, bytes)
    result should ===(actualValue)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy