org.scalatest.enablers.Aggregating.scala Maven / Gradle / Ivy
Show all versions of scalatest_2.9.1 Show documentation
/*
* Copyright 2001-2013 Artima, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.scalatest.enablers
import org.scalautils.Equality
import org.scalatest.words.ArrayWrapper
import scala.collection.GenTraversable
import scala.collection.GenTraversableOnce
import org.scalatest.FailureMessages
import scala.annotation.tailrec
import scala.collection.JavaConverters._
/**
* Supertrait for typeclasses that enable contain matcher syntax for aggregations.
*
*
* An Aggregating[A] provides access to the "aggregating nature" of type A in such
* a way that relevant contain matcher syntax can be used with type A. An A
* can be any type of "aggregation," a type that in some way aggregates or brings together other types. ScalaTest provides
* implicit implementations for several types. You can enable the contain matcher syntax on your own
* type U by defining an Aggregating[U] for the type and making it available implicitly.
*
*
* ScalaTest provides implicit Aggregating instances for scala.collection.GenTraversable,
* java.util.Collection, java.util.Map, String, and Array in the
* Aggregating companion object.
*
*
*
* Note, for an explanation of the difference between Containing and Aggregating, both of which
* enable contain matcher syntax, see the Containing
* versus Aggregating section of the main documentation for trait Containing.
*
*/
trait Aggregating[A] {
// TODO: Write tests that a NotAllowedException is thrown when no elements are passed, maybe if only one element is passed, and
// likely if an object is repeated in the list.
/**
* Implements contain atLeastOneOf syntax for aggregations of type A.
*
* @param aggregation an aggregation about which an assertion is being made
* @param eles elements at least one of which should be contained in the passed aggregation
* @return true if the passed aggregation contains at least one of the passed elements
*/
def containsAtLeastOneOf(aggregation: A, eles: Seq[Any]): Boolean
/**
* Implements contain theSameElementsAs syntax for aggregations of type A.
*
* @param leftAggregation an aggregation about which an assertion is being made
* @param rightAggregation an aggregation that should contain the same elements as the passed leftAggregation
* @return true if the passed leftAggregation contains the same elements as the passed rightAggregation
*/
def containsTheSameElementsAs(leftAggregation: A, rightAggregation: GenTraversable[Any]): Boolean
/**
* Implements contain theSameElementsInOrderAs syntax for aggregations of type A.
*
* @param leftAggregation an aggregation about which an assertion is being made
* @param rightAggregation an aggregation that should contain the same elements, in the same (iterated) order, as the passed leftAggregation
* @return true if the passed leftAggregation contains the same elements, in (iterated) order, as the passed rightAggregation
*/
def containsTheSameElementsInOrderAs(leftAggregation: A, rightAggregation: GenTraversable[Any]): Boolean
/**
* Implements contain only syntax for aggregations of type A.
*
* @param aggregation an aggregation about which an assertion is being made
* @param eles the only elements that should be contained in the passed aggregation
* @return true if the passed aggregation contains only the passed elements
*/
def containsOnly(aggregation: A, eles: Seq[Any]): Boolean
/**
* Implements contain inOrderOnly syntax for aggregations of type A.
*
* @param aggregation an aggregation about which an assertion is being made
* @param eles the only elements that should be contained, in order of appearance in eles, in the passed aggregation
* @return true if the passed aggregation contains only the passed elements in (iteration) order
*/
def containsInOrderOnly(aggregation: A, eles: Seq[Any]): Boolean
/**
* Implements contain allOf syntax for aggregations of type A.
*
* @param aggregation an aggregation about which an assertion is being made
* @param eles elements all of which should be contained in the passed aggregation
* @return true if the passed aggregation contains all of the passed elements
*/
def containsAllOf(aggregation: A, eles: Seq[Any]): Boolean
/**
* Implements contain inOrder syntax for aggregations of type A.
*
* @param aggregation an aggregation about which an assertion is being made
* @param eles elements all of which should be contained, in order of appearance in eles, in the passed aggregation
* @return true if the passed aggregation contains all of the passed elements in (iteration) order
*/
def containsInOrder(aggregation: A, eles: Seq[Any]): Boolean
/* Not currently part of this interface; kept for reference only:
   def containsAtMostOneOf(aggregation: A, eles: Seq[Any]): Boolean
*/
}
object Aggregating {

  // TODO: Throwing exceptions is slow. Just do a pattern match and test the type before trying to cast it.
  /**
   * Compares `left` (cast to `T`) with `right` using `equality`.
   *
   * A `ClassCastException` raised while evaluating the comparison is treated as "not equal",
   * for example because `left` is not actually a `T`.
   */
  private def tryEquality[T](left: Any, right: Any, equality: Equality[T]): Boolean =
    try equality.areEqual(left.asInstanceOf[T], right)
    catch {
      case _: ClassCastException => false
    }

  /**
   * Returns true if `left` and `right` contain the same elements the same number of times,
   * regardless of order. Elements are compared with `equality`.
   *
   * Both sides are walked in lock step, so a size mismatch is detected as soon as one side runs
   * out. This needs no non-local `return` out of a closure.
   */
  private def checkTheSameElementsAs[T](left: GenTraversable[T], right: GenTraversable[Any], equality: Equality[T]): Boolean = {
    // Number of occurrences of an element (up to `equality`) seen on each side.
    case class ElementCount(element: Any, leftCount: Int, rightCount: Int)

    // Records one occurrence of `next`. It bumps the entry of the first element `next` equals,
    // or appends a new entry if there is none.
    def addOccurrence(next: Any, counts: IndexedSeq[ElementCount], leftDelta: Int, rightDelta: Int): IndexedSeq[ElementCount] = {
      val idx = counts.indexWhere(ec => tryEquality(next, ec.element, equality))
      if (idx >= 0) {
        val current = counts(idx)
        counts.updated(idx, ElementCount(current.element, current.leftCount + leftDelta, current.rightCount + rightDelta))
      }
      else
        counts :+ ElementCount(next, leftDelta, rightDelta)
    }

    // Tallies both sides pairwise. None means one side ran out before the other (the sizes differ).
    @tailrec
    def tally(leftItr: Iterator[T], rightItr: Iterator[Any], counts: IndexedSeq[ElementCount]): Option[IndexedSeq[ElementCount]] =
      if (leftItr.hasNext && rightItr.hasNext) {
        val withLeft = addOccurrence(leftItr.next, counts, 1, 0)
        tally(leftItr, rightItr, addOccurrence(rightItr.next, withLeft, 0, 1))
      }
      else if (leftItr.hasNext || rightItr.hasNext)
        None
      else
        Some(counts)

    tally(left.toIterator, right.toIterator, IndexedSeq.empty[ElementCount]).exists { counts =>
      counts.forall(ec => ec.leftCount == ec.rightCount)
    }
  }

  /**
   * Returns true if `left` and `right` have the same length and the elements at each iteration
   * position are equal under `equality`.
   */
  private def checkTheSameElementsInOrderAs[T](left: GenTraversable[T], right: GenTraversable[Any], equality: Equality[T]): Boolean = {
    @tailrec
    def checkEqual(leftItr: Iterator[T], rightItr: Iterator[Any]): Boolean =
      if (leftItr.hasNext && rightItr.hasNext) {
        val nextLeft = leftItr.next
        val nextRight = rightItr.next
        if (!equality.areEqual(nextLeft, nextRight))
          false
        else
          checkEqual(leftItr, rightItr)
      }
      else // equal lengths only if both ran out together
        leftItr.isEmpty && rightItr.isEmpty

    checkEqual(left.toIterator, right.toIterator)
  }

  /**
   * Returns true if every element of `left` is equal, under `equality`, to some element of `right`.
   *
   * `right` is consumed lazily. It is scanned only as far as needed to match each new left element,
   * and every scanned right element is remembered so later left elements can match it.
   *
   * @throws IllegalArgumentException if a scanned element of `right` repeats an earlier scanned one
   */
  private def checkOnly[T](left: GenTraversable[T], right: GenTraversable[Any], equality: Equality[T]): Boolean = {
    // Scans rightItr until an element equal to `value` is found, adding every scanned element to the set.
    // If right is exhausted first, the returned set contains no match for `value`.
    @tailrec
    def findNext(value: T, rightItr: Iterator[Any], processedSet: Set[Any]): Set[Any] =
      if (rightItr.hasNext) {
        val nextRight = rightItr.next
        if (processedSet.exists(tryEquality(_, nextRight, equality)))
          throw new IllegalArgumentException(FailureMessages("onlyDuplicate", nextRight))
        if (tryEquality(nextRight, value, equality))
          processedSet + nextRight
        else
          findNext(value, rightItr, processedSet + nextRight)
      }
      else
        processedSet

    @tailrec
    def checkEqual(leftItr: Iterator[T], rightItr: Iterator[Any], processedSet: Set[Any]): Boolean =
      if (leftItr.hasNext) {
        val nextLeft = leftItr.next
        if (processedSet.exists(tryEquality(_, nextLeft, equality))) // already matched by a scanned right element
          checkEqual(leftItr, rightItr, processedSet)
        else {
          val newProcessedSet = findNext(nextLeft, rightItr, processedSet)
          if (newProcessedSet.exists(tryEquality(_, nextLeft, equality)))
            checkEqual(leftItr, rightItr, newProcessedSet)
          else // nextLeft has no counterpart in right: fail early
            false
        }
      }
      else // every left element matched some right element
        true

    checkEqual(left.toIterator, right.toIterator, Set.empty)
  }

  /**
   * Returns true if `left` consists of runs of elements equal to each element of `right`, in
   * `right`'s order. Every right element must be used, and no other left elements may appear.
   * Both sides empty also counts as a match.
   */
  private def checkInOrderOnly[T](left: GenTraversable[T], right: GenTraversable[Any], equality: Equality[T]): Boolean = {
    @tailrec
    def checkEqual(currentLeft: T, currentRight: Any, leftItr: Iterator[T], rightItr: Iterator[Any]): Boolean =
      if (equality.areEqual(currentLeft, currentRight)) { // each run must start with a match
        // Skips the remaining lefts equal to currentRight. Returns the first left that differs,
        // or None if left is exhausted.
        @tailrec
        def nextDifferentLeft(): Option[T] =
          if (leftItr.hasNext) {
            val nextLeft = leftItr.next
            if (equality.areEqual(nextLeft, currentRight))
              nextDifferentLeft()
            else
              Some(nextLeft)
          }
          else None

        nextDifferentLeft() match {
          case Some(nextLeft) =>
            if (rightItr.hasNext)
              checkEqual(nextLeft, rightItr.next, leftItr, rightItr)
            else
              false
          case None =>
            !rightItr.hasNext // no lefts remain; fine only if no rights remain either
        }
      }
      else false

    val leftItr: Iterator[T] = left.toIterator
    val rightItr: Iterator[Any] = right.toIterator
    if (leftItr.hasNext && rightItr.hasNext)
      checkEqual(leftItr.next, rightItr.next, leftItr, rightItr)
    else
      left.isEmpty && right.isEmpty
  }

  /**
   * Returns true if every element of `right` is equal, under `equality`, to some element of `left`.
   *
   * Duplicate detection among the right elements uses `equality`, consistent with `checkOnly` and
   * `checkInOrder`.
   *
   * @throws IllegalArgumentException if `right` repeats an element it already contained
   */
  private def checkAllOf[T](left: GenTraversable[T], right: GenTraversable[Any], equality: Equality[T]): Boolean = {
    @tailrec
    def checkEqual(rightItr: Iterator[Any], processedSet: Set[Any]): Boolean =
      if (rightItr.hasNext) {
        val nextRight = rightItr.next
        if (processedSet.exists(tryEquality(_, nextRight, equality)))
          throw new IllegalArgumentException(FailureMessages("allOfDuplicate", nextRight))
        if (left.exists(t => equality.areEqual(t, nextRight)))
          checkEqual(rightItr, processedSet + nextRight)
        else // element not found: fail early
          false
      }
      else // no more right elements: left contains all of right
        true

    checkEqual(right.toIterator, Set.empty)
  }

  /**
   * Returns true if the elements of `right` occur in `left`, in `right`'s order, as a
   * (not necessarily contiguous) subsequence. Elements are compared with `equality`.
   *
   * Each right element is matched against its *first* occurrence in the rest of `left`. Greedy
   * earliest matching is the correct subsequence test, so `List(1, 2, 1)` contains `1` then `2`
   * in order. A single pass over `left` suffices.
   *
   * @throws IllegalArgumentException if `right` repeats an element it already contained
   */
  private def checkInOrder[T](left: GenTraversable[T], right: GenTraversable[Any], equality: Equality[T]): Boolean = {
    // Advances leftItr just past the first element equal to `element`.
    // Returns false if leftItr runs out first.
    @tailrec
    def advancePast(leftItr: Iterator[T], element: Any): Boolean =
      if (!leftItr.hasNext)
        false
      else if (equality.areEqual(leftItr.next, element))
        true
      else
        advancePast(leftItr, element)

    @tailrec
    def checkEqual(leftItr: Iterator[T], rightItr: Iterator[Any], processedSet: Set[Any]): Boolean =
      if (rightItr.hasNext) {
        val nextRight = rightItr.next
        if (processedSet.exists(tryEquality(_, nextRight, equality)))
          throw new IllegalArgumentException(FailureMessages("inOrderDuplicate", nextRight))
        if (advancePast(leftItr, nextRight))
          checkEqual(leftItr, rightItr, processedSet + nextRight)
        else // nextRight does not occur in the rest of left: fail early
          false
      }
      else // every right element was found, in order
        true

    checkEqual(left.toIterator, right.toIterator, Set.empty)
  }

  /**
   * Aggregating instance for any `scala.collection.GenTraversable` subtype `TRAV[E]`.
   * Elements are compared with the implicit `Equality[E]`.
   */
  implicit def withGenTraversableElementEquality[E, TRAV[_] <: scala.collection.GenTraversable[_]](implicit equality: Equality[E]): Aggregating[TRAV[E]] =
    new Aggregating[TRAV[E]] {
      def containsAtLeastOneOf(trav: TRAV[E], elements: scala.collection.Seq[Any]): Boolean = {
        trav.exists((e: Any) => elements.exists((ele: Any) => equality.areEqual(e.asInstanceOf[E], ele)))
      }
      def containsTheSameElementsAs(trav: TRAV[E], elements: GenTraversable[Any]): Boolean = {
        checkTheSameElementsAs[E](trav.asInstanceOf[GenTraversable[E]], elements, equality)
      }
      def containsTheSameElementsInOrderAs(trav: TRAV[E], elements: GenTraversable[Any]): Boolean = {
        checkTheSameElementsInOrderAs[E](trav.asInstanceOf[GenTraversable[E]], elements, equality)
      }
      def containsOnly(trav: TRAV[E], elements: scala.collection.Seq[Any]): Boolean = {
        checkOnly[E](trav.asInstanceOf[GenTraversable[E]], elements, equality)
      }
      def containsInOrderOnly(trav: TRAV[E], elements: scala.collection.Seq[Any]): Boolean = {
        checkInOrderOnly[E](trav.asInstanceOf[GenTraversable[E]], elements, equality)
      }
      def containsAllOf(trav: TRAV[E], elements: scala.collection.Seq[Any]): Boolean = {
        checkAllOf(trav.asInstanceOf[GenTraversable[E]], elements, equality)
      }
      def containsInOrder(trav: TRAV[E], elements: scala.collection.Seq[Any]): Boolean = {
        checkInOrder(trav.asInstanceOf[GenTraversable[E]], elements, equality)
      }
    }

  // Enables (xs should contain ("HI")) (after being lowerCased)
  /** Converts an explicitly supplied `Equality[E]` into an `Aggregating[TRAV[E]]`. */
  implicit def convertEqualityToGenTraversableAggregating[E, TRAV[_] <: scala.collection.GenTraversable[_]](equality: Equality[E]): Aggregating[TRAV[E]] =
    withGenTraversableElementEquality(equality)

  /**
   * Aggregating instance for `Array[E]`, viewed through `ArrayWrapper`.
   * Elements are compared with the implicit `Equality[E]`.
   */
  implicit def withArrayElementEquality[E](implicit equality: Equality[E]): Aggregating[Array[E]] =
    new Aggregating[Array[E]] {
      def containsAtLeastOneOf(array: Array[E], elements: scala.collection.Seq[Any]): Boolean = {
        new ArrayWrapper(array).exists((e: Any) => elements.exists((ele: Any) => equality.areEqual(e.asInstanceOf[E], ele)))
      }
      def containsTheSameElementsAs(array: Array[E], elements: GenTraversable[Any]): Boolean = {
        checkTheSameElementsAs[E](new ArrayWrapper(array), elements, equality)
      }
      def containsTheSameElementsInOrderAs(array: Array[E], elements: GenTraversable[Any]): Boolean = {
        checkTheSameElementsInOrderAs[E](new ArrayWrapper(array), elements, equality)
      }
      def containsOnly(array: Array[E], elements: scala.collection.Seq[Any]): Boolean = {
        checkOnly(new ArrayWrapper(array), elements, equality)
      }
      def containsInOrderOnly(array: Array[E], elements: scala.collection.Seq[Any]): Boolean = {
        checkInOrderOnly(new ArrayWrapper(array), elements, equality)
      }
      def containsAllOf(array: Array[E], elements: scala.collection.Seq[Any]): Boolean = {
        checkAllOf(new ArrayWrapper(array), elements, equality)
      }
      def containsInOrder(array: Array[E], elements: scala.collection.Seq[Any]): Boolean = {
        checkInOrder(new ArrayWrapper(array), elements, equality)
      }
    }

  // Enables (xs should contain ("HI")) (after being lowerCased)
  /** Converts an explicitly supplied `Equality[E]` into an `Aggregating[Array[E]]`. */
  implicit def convertEqualityToArrayAggregating[E](equality: Equality[E]): Aggregating[Array[E]] =
    withArrayElementEquality(equality)

  /**
   * Aggregating instance for `String`, treating it as a sequence of `Char`s.
   * Characters are compared with the implicit `Equality[Char]`.
   */
  implicit def withStringCharacterEquality(implicit equality: Equality[Char]): Aggregating[String] =
    new Aggregating[String] {
      def containsAtLeastOneOf(s: String, elements: scala.collection.Seq[Any]): Boolean = {
        s.exists((e: Any) => elements.exists((ele: Any) => equality.areEqual(e.asInstanceOf[Char], ele)))
      }
      def containsTheSameElementsAs(s: String, elements: GenTraversable[Any]): Boolean = {
        checkTheSameElementsAs(s, elements, equality)
      }
      def containsTheSameElementsInOrderAs(s: String, elements: GenTraversable[Any]): Boolean = {
        checkTheSameElementsInOrderAs(s, elements, equality)
      }
      def containsOnly(s: String, elements: scala.collection.Seq[Any]): Boolean = {
        checkOnly(s, elements, equality)
      }
      def containsInOrderOnly(s: String, elements: scala.collection.Seq[Any]): Boolean = {
        checkInOrderOnly(s, elements, equality)
      }
      def containsAllOf(s: String, elements: scala.collection.Seq[Any]): Boolean = {
        checkAllOf(s, elements, equality)
      }
      def containsInOrder(s: String, elements: scala.collection.Seq[Any]): Boolean = {
        checkInOrder(s, elements, equality)
      }
    }

  /** Converts an explicitly supplied `Equality[Char]` into an `Aggregating[String]`. */
  implicit def convertEqualityToStringAggregating(equality: Equality[Char]): Aggregating[String] =
    withStringCharacterEquality(equality)

  /**
   * Aggregating instance for any `scala.collection.GenMap` subtype `MAP[K, V]`.
   * The map is treated as a collection of `(K, V)` entries, compared with the implicit `Equality[(K, V)]`.
   */
  implicit def withGenMapElementEquality[K, V, MAP[_, _] <: scala.collection.GenMap[_, _]](implicit equality: Equality[(K, V)]): Aggregating[MAP[K, V]] =
    new Aggregating[MAP[K, V]] {
      def containsAtLeastOneOf(map: MAP[K, V], elements: scala.collection.Seq[Any]): Boolean = {
        map.exists((e: Any) => elements.exists((ele: Any) => equality.areEqual(e.asInstanceOf[(K, V)], ele)))
      }
      def containsTheSameElementsAs(map: MAP[K, V], elements: GenTraversable[Any]): Boolean = {
        checkTheSameElementsAs(map.asInstanceOf[scala.collection.GenMap[K, V]], elements, equality)
      }
      def containsTheSameElementsInOrderAs(map: MAP[K, V], elements: GenTraversable[Any]): Boolean = {
        checkTheSameElementsInOrderAs(map.asInstanceOf[scala.collection.GenMap[K, V]], elements, equality)
      }
      def containsOnly(map: MAP[K, V], elements: scala.collection.Seq[Any]): Boolean = {
        checkOnly(map.asInstanceOf[scala.collection.GenMap[K, V]], elements, equality)
      }
      def containsInOrderOnly(map: MAP[K, V], elements: scala.collection.Seq[Any]): Boolean = {
        checkInOrderOnly(map.asInstanceOf[scala.collection.GenMap[K, V]], elements, equality)
      }
      def containsAllOf(map: MAP[K, V], elements: scala.collection.Seq[Any]): Boolean = {
        checkAllOf(map.asInstanceOf[scala.collection.GenMap[K, V]], elements, equality)
      }
      def containsInOrder(map: MAP[K, V], elements: scala.collection.Seq[Any]): Boolean = {
        checkInOrder(map.asInstanceOf[scala.collection.GenMap[K, V]], elements, equality)
      }
    }

  /** Converts an explicitly supplied `Equality[(K, V)]` into an `Aggregating[MAP[K, V]]`. */
  implicit def convertEqualityToGenMapAggregating[K, V, MAP[_, _] <: scala.collection.GenMap[_, _]](equality: Equality[(K, V)]): Aggregating[MAP[K, V]] =
    withGenMapElementEquality(equality)

  /**
   * Aggregating instance for any `java.util.Collection` subtype `JCOL[E]`, viewed via `asScala`.
   * Elements are compared with the implicit `Equality[E]`.
   */
  implicit def withJavaCollectionElementEquality[E, JCOL[_] <: java.util.Collection[_]](implicit equality: Equality[E]): Aggregating[JCOL[E]] =
    new Aggregating[JCOL[E]] {
      def containsAtLeastOneOf(col: JCOL[E], elements: scala.collection.Seq[Any]): Boolean = {
        col.asInstanceOf[java.util.Collection[E]].asScala.exists((e: Any) => elements.exists((ele: Any) => equality.areEqual(e.asInstanceOf[E], ele)))
      }
      def containsTheSameElementsAs(col: JCOL[E], elements: GenTraversable[Any]): Boolean = {
        checkTheSameElementsAs(col.asInstanceOf[java.util.Collection[E]].asScala, elements, equality)
      }
      def containsTheSameElementsInOrderAs(col: JCOL[E], elements: GenTraversable[Any]): Boolean = {
        checkTheSameElementsInOrderAs(col.asInstanceOf[java.util.Collection[E]].asScala, elements, equality)
      }
      def containsOnly(col: JCOL[E], elements: scala.collection.Seq[Any]): Boolean = {
        checkOnly(col.asInstanceOf[java.util.Collection[E]].asScala, elements, equality)
      }
      def containsInOrderOnly(col: JCOL[E], elements: scala.collection.Seq[Any]): Boolean = {
        checkInOrderOnly(col.asInstanceOf[java.util.Collection[E]].asScala, elements, equality)
      }
      def containsAllOf(col: JCOL[E], elements: scala.collection.Seq[Any]): Boolean = {
        checkAllOf(col.asInstanceOf[java.util.Collection[E]].asScala, elements, equality)
      }
      def containsInOrder(col: JCOL[E], elements: scala.collection.Seq[Any]): Boolean = {
        checkInOrder(col.asInstanceOf[java.util.Collection[E]].asScala, elements, equality)
      }
    }

  /** Converts an explicitly supplied `Equality[E]` into an `Aggregating[JCOL[E]]`. */
  implicit def convertEqualityToJavaCollectionAggregating[E, JCOL[_] <: java.util.Collection[_]](equality: Equality[E]): Aggregating[JCOL[E]] =
    withJavaCollectionElementEquality(equality)

  /**
   * Aggregating instance for any `java.util.Map` subtype `JMAP[K, V]`.
   * The map is treated as a collection of `(K, V)` entries, compared with the implicit `Equality[(K, V)]`.
   */
  implicit def withJavaMapElementEquality[K, V, JMAP[_, _] <: java.util.Map[_, _]](implicit equality: Equality[(K, V)]): Aggregating[JMAP[K, V]] =
    new Aggregating[JMAP[K, V]] {
      // Copies the entries into a LinkedHashMap in the Java map's entrySet iteration order, so the
      // order-sensitive checks (inOrder, inOrderOnly, theSameElementsInOrderAs) see entries in that order.
      private def getScalaMapInOrder(javaMap: JMAP[K, V]): scala.collection.GenMap[K, V] = {
        val map = new collection.mutable.LinkedHashMap[K, V]
        val itr = javaMap.entrySet.iterator
        while (itr.hasNext) {
          val entry = itr.next
          map += ((entry.getKey.asInstanceOf[K], entry.getValue.asInstanceOf[V]))
        }
        map
      }
      def containsAtLeastOneOf(map: JMAP[K, V], elements: scala.collection.Seq[Any]): Boolean = {
        getScalaMapInOrder(map).exists((e: Any) => elements.exists((ele: Any) => equality.areEqual(e.asInstanceOf[(K, V)], ele)))
      }
      def containsTheSameElementsAs(map: JMAP[K, V], elements: GenTraversable[Any]): Boolean = {
        checkTheSameElementsAs(getScalaMapInOrder(map), elements, equality)
      }
      def containsTheSameElementsInOrderAs(map: JMAP[K, V], elements: GenTraversable[Any]): Boolean = {
        checkTheSameElementsInOrderAs(getScalaMapInOrder(map), elements, equality)
      }
      def containsOnly(map: JMAP[K, V], elements: scala.collection.Seq[Any]): Boolean = {
        checkOnly(getScalaMapInOrder(map), elements, equality)
      }
      def containsInOrderOnly(map: JMAP[K, V], elements: scala.collection.Seq[Any]): Boolean = {
        checkInOrderOnly(getScalaMapInOrder(map), elements, equality)
      }
      def containsAllOf(map: JMAP[K, V], elements: scala.collection.Seq[Any]): Boolean = {
        checkAllOf(getScalaMapInOrder(map), elements, equality)
      }
      def containsInOrder(map: JMAP[K, V], elements: scala.collection.Seq[Any]): Boolean = {
        checkInOrder(getScalaMapInOrder(map), elements, equality)
      }
    }

  /** Converts an explicitly supplied `Equality[(K, V)]` into an `Aggregating[JMAP[K, V]]`. */
  implicit def convertEqualityToJavaMapAggregating[K, V, JMAP[_, _] <: java.util.Map[_, _]](equality: Equality[(K, V)]): Aggregating[JMAP[K, V]] =
    withJavaMapElementEquality(equality)
}