// org.scalactic.anyvals.CompileTimeAssertions.scala Maven / Gradle / Ivy
/*
* Copyright 2001-2014 Artima, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.scalactic.anyvals
import org.scalactic.Resources
import reflect.macros.Context
/**
* Trait providing assertion methods that can be called at compile time from macros
* to validate literals in source code.
*
*
* The intent of <code>CompileTimeAssertions</code> is to make it easier to create
* <code>AnyVal</code>s that restrict the values of types for which Scala supports
* literals: <code>Int</code>, <code>Long</code>, <code>Float</code>, <code>Double</code>, <code>Char</code>,
* and <code>String</code>. For example, if you are using odd integers in many places
* in your code, you might have validity checks scattered throughout your code. Here's
* an example of a method that both requires an odd <code>Int</code> is passed (as a
* precondition), and ensures an odd <code>Int</code> is returned (as
* a postcondition):
*
* {{{
* def nextOdd(i: Int): Int = {
*   def isOdd(x: Int): Boolean = x.abs % 2 == 1
*   require(isOdd(i))
*   (i + 2) ensuring (isOdd(_))
* }
* }}}
*
* If either the precondition or postcondition check fails, an exception will
* be thrown at runtime. If you have many methods like this you may want to
* create a type to represent an odd <code>Int</code>, so that the checking
* for validity errors is isolated in just one place. By using an <code>AnyVal</code>
* you can avoid boxing the <code>Int</code>, which may be more efficient.
* This might look like:
*
* {{{
* final class OddInt private (val value: Int) extends AnyVal {
*   override def toString: String = s"OddInt($value)"
* }
*
* object OddInt {
*   def apply(value: Int): OddInt = {
*     require(value.abs % 2 == 1)
*     new OddInt(value)
*   }
* }
* }}}
*
*
* An <code>AnyVal</code> cannot have any constructor code, so to ensure that
* any <code>Int</code> passed to the <code>OddInt</code> constructor is actually
* odd, the constructor must be private. That way the only way to construct a
* new <code>OddInt</code> is via the <code>apply</code> factory method in the
* <code>OddInt</code> companion object, which can require that the value be
* odd. This design eliminates the need for placing <code>require</code> and
* <code>ensuring</code> clauses anywhere else that odd <code>Int</code>s are
* needed, because the type promises the constraint. The <code>nextOdd</code>
* method could, therefore, be rewritten as:
*
* {{{
* def nextOdd(oi: OddInt): OddInt = OddInt(oi.value + 2)
* }}}
*
* Using the compile-time assertions provided by this trait, you can construct
* a factory method implemented via a macro that causes a compile failure
* if <code>OddInt.apply</code> is passed anything besides an odd
* <code>Int</code> literal. Class <code>OddInt</code> would look exactly the
* same as before:
*
* {{{
* final class OddInt private (val value: Int) extends AnyVal {
*   override def toString: String = s"OddInt($value)"
* }
* }}}
*
* In the companion object, however, the <code>apply</code> method would
* be implemented in terms of a macro. Because the <code>apply</code> method
* will only work with literals, you'll need a second method that can work
* on any expression of type <code>Int</code>. We recommend a <code>from</code> method
* that returns an <code>Option[OddInt]</code>, which is <code>Some[OddInt]</code> if the passed <code>Int</code> is odd,
* else <code>None</code>, and an <code>ensuringValid</code> method that returns an <code>OddInt</code>
* if the passed <code>Int</code> is valid, else throws <code>AssertionError</code>.
*
* {{{
* object OddInt {
*
*   // The from factory method validates at run time
*   def from(value: Int): Option[OddInt] =
*     if (OddIntMacro.isValid(value)) Some(new OddInt(value)) else None
*
*   // The ensuringValid factory method validates at run time, but throws
*   // an AssertionError if invalid
*   def ensuringValid(value: Int): OddInt =
*     if (OddIntMacro.isValid(value)) new OddInt(value) else {
*       throw new AssertionError(s"$value was not a valid OddInt")
*     }
*
*   // The apply factory method validates at compile time
*   import scala.language.experimental.macros
*   def apply(value: Int): OddInt = macro OddIntMacro.apply
* }
* }}}
*
* The <code>apply</code> method refers to a macro implementation method in object
* <code>OddIntMacro</code>. The macro implementation of any such method can look
* very similar to this one. The only changes you'd need to make are the
* <code>isValid</code> method implementation and the text of the error messages.
*
* {{{
* import org.scalactic.anyvals.CompileTimeAssertions
* import reflect.macros.Context
*
* object OddIntMacro extends CompileTimeAssertions {
*
*   // Validation method used at both compile- and run-time
*   def isValid(i: Int): Boolean = i.abs % 2 == 1
*
*   // Apply macro that performs a compile-time assertion
*   def apply(c: Context)(value: c.Expr[Int]): c.Expr[OddInt] = {
*
*     // Prepare potential compiler error messages
*     val notValidMsg = "OddInt.apply can only be invoked on odd Int literals, like OddInt(3)."
*     val notLiteralMsg = "OddInt.apply can only be invoked on Int literals, like " +
*           "OddInt(3). Please use OddInt.from instead."
*
*     // Validate via a compile-time assertion
*     ensureValidIntLiteral(c)(value, notValidMsg, notLiteralMsg)(isValid)
*
*     // Validated, so rewrite the apply call to an ensuringValid call
*     c.universe.reify { OddInt.ensuringValid(value.splice) }
*   }
* }
* }}}
*
*
* The <code>isValid</code> method just takes the underlying type and returns <code>true</code> if it is valid,
* else <code>false</code>. This method is placed here so the same validation code can be used both in
* the <code>from</code> and <code>ensuringValid</code> methods at run time and in the <code>apply</code> macro at compile time.
* The <code>apply</code> macro actually does just two things. It calls <code>ensureValidIntLiteral</code>, performing a
* compile-time assertion that the value passed to <code>apply</code> is an <code>Int</code> literal that is valid (in this case, odd).
* If the assertion fails, <code>ensureValidIntLiteral</code> will complete abruptly with an exception that will
* contain an appropriate error message (one of the two you passed in) and cause a compiler error with that message.
* If the assertion succeeds, <code>ensureValidIntLiteral</code> will just return normally. The next line of code
* will then execute. This line of code must construct an AST (abstract syntax tree) of code that will replace
* the <code>OddInt.apply</code> invocation. We invoke the <code>ensuringValid</code> factory method, which either returns an
* <code>OddInt</code> or throws an <code>AssertionError</code>, since we've proven at compile time that the call will succeed.
*
* You may wish to use quasiquotes instead of <code>reify</code>. The reason we use <code>reify</code> is that this also works
* on 2.10 without any additional plugin (i.e., you don't need macro paradise), and Scalactic supports 2.10.
*
*/
trait CompileTimeAssertions {

  /**
   * Ensures a given expression of type <code>Int</code> is a literal with a valid value according to a given validation function.
   *
   * If the given <code>Int</code> expression is a literal whose value satisfies the given validation function, this method will
   * return normally. Otherwise, if the given <code>Int</code> expression is not a literal, this method will complete abruptly with
   * an exception whose detail message includes the <code>String</code> passed as <code>notLiteralMsg</code>. Otherwise, the
   * given <code>Int</code> expression is a literal that does not satisfy the given validation function, so this method will
   * complete abruptly with an exception whose detail message includes the <code>String</code> passed as <code>notValidMsg</code>.
   *
   * A constant of a narrower type (<code>Byte</code>, <code>Short</code> or <code>Char</code>) is widened to <code>Int</code>
   * before validation, following Scala's numeric widening rules.
   *
   * This method is intended to be invoked at compile time from macros. When called from a macro, exceptions thrown by this method
   * will result in compiler errors. The detail message of the thrown exception will appear as the compiler error message.
   *
   * @param c the compiler context for this assertion
   * @param value the <code>Int</code> expression to validate
   * @param notValidMsg a <code>String</code> message to include in the exception thrown if the expression is a literal, but not valid
   * @param notLiteralMsg a <code>String</code> message to include in the exception thrown if the expression is not a literal
   * @param isValid a function used to validate the literal value extracted from the given expression
   */
  def ensureValidIntLiteral(c: Context)(value: c.Expr[Int], notValidMsg: String, notLiteralMsg: String)(isValid: Int => Boolean): Unit =
    ensureValidLiteral[Int](c)(value.tree, notValidMsg, notLiteralMsg) {
      case i: Int   => i
      case s: Short => s.toInt
      case b: Byte  => b.toInt
      case ch: Char => ch.toInt
    }(isValid)

  /**
   * Ensures a given expression of type <code>Long</code> is a literal with a valid value according to a given validation function.
   *
   * If the given <code>Long</code> expression is a literal whose value satisfies the given validation function, this method will
   * return normally. Otherwise, if the given <code>Long</code> expression is not a literal, this method will complete abruptly with
   * an exception whose detail message includes the <code>String</code> passed as <code>notLiteralMsg</code>. Otherwise, the
   * given <code>Long</code> expression is a literal that does not satisfy the given validation function, so this method will
   * complete abruptly with an exception whose detail message includes the <code>String</code> passed as <code>notValidMsg</code>.
   *
   * A constant of a narrower type (<code>Int</code>, <code>Short</code>, <code>Byte</code> or <code>Char</code>) is widened to
   * <code>Long</code> before validation, following Scala's numeric widening rules.
   *
   * This method is intended to be invoked at compile time from macros. When called from a macro, exceptions thrown by this method
   * will result in compiler errors. The detail message of the thrown exception will appear as the compiler error message.
   *
   * @param c the compiler context for this assertion
   * @param value the <code>Long</code> expression to validate
   * @param notValidMsg a <code>String</code> message to include in the exception thrown if the expression is a literal, but not valid
   * @param notLiteralMsg a <code>String</code> message to include in the exception thrown if the expression is not a literal
   * @param isValid a function used to validate the literal value extracted from the given expression
   */
  def ensureValidLongLiteral(c: Context)(value: c.Expr[Long], notValidMsg: String, notLiteralMsg: String)(isValid: Long => Boolean): Unit =
    ensureValidLiteral[Long](c)(value.tree, notValidMsg, notLiteralMsg) {
      case l: Long  => l
      case i: Int   => i.toLong
      case s: Short => s.toLong
      case b: Byte  => b.toLong
      case ch: Char => ch.toLong
    }(isValid)

  /**
   * Ensures a given expression of type <code>Float</code> is a literal with a valid value according to a given validation function.
   *
   * If the given <code>Float</code> expression is a literal whose value satisfies the given validation function, this method will
   * return normally. Otherwise, if the given <code>Float</code> expression is not a literal, this method will complete abruptly with
   * an exception whose detail message includes the <code>String</code> passed as <code>notLiteralMsg</code>. Otherwise, the
   * given <code>Float</code> expression is a literal that does not satisfy the given validation function, so this method will
   * complete abruptly with an exception whose detail message includes the <code>String</code> passed as <code>notValidMsg</code>.
   *
   * A constant of a narrower type (<code>Long</code>, <code>Int</code>, <code>Short</code>, <code>Byte</code> or <code>Char</code>)
   * is widened to <code>Float</code> before validation, following Scala's numeric widening rules.
   *
   * This method is intended to be invoked at compile time from macros. When called from a macro, exceptions thrown by this method
   * will result in compiler errors. The detail message of the thrown exception will appear as the compiler error message.
   *
   * @param c the compiler context for this assertion
   * @param value the <code>Float</code> expression to validate
   * @param notValidMsg a <code>String</code> message to include in the exception thrown if the expression is a literal, but not valid
   * @param notLiteralMsg a <code>String</code> message to include in the exception thrown if the expression is not a literal
   * @param isValid a function used to validate the literal value extracted from the given expression
   */
  def ensureValidFloatLiteral(c: Context)(value: c.Expr[Float], notValidMsg: String, notLiteralMsg: String)(isValid: Float => Boolean): Unit =
    ensureValidLiteral[Float](c)(value.tree, notValidMsg, notLiteralMsg) {
      case f: Float => f
      case l: Long  => l.toFloat
      case i: Int   => i.toFloat
      case s: Short => s.toFloat
      case b: Byte  => b.toFloat
      case ch: Char => ch.toFloat
    }(isValid)

  /**
   * Ensures a given expression of type <code>Double</code> is a literal with a valid value according to a given validation function.
   *
   * If the given <code>Double</code> expression is a literal whose value satisfies the given validation function, this method will
   * return normally. Otherwise, if the given <code>Double</code> expression is not a literal, this method will complete abruptly with
   * an exception whose detail message includes the <code>String</code> passed as <code>notLiteralMsg</code>. Otherwise, the
   * given <code>Double</code> expression is a literal that does not satisfy the given validation function, so this method will
   * complete abruptly with an exception whose detail message includes the <code>String</code> passed as <code>notValidMsg</code>.
   *
   * A constant of a narrower numeric type (<code>Float</code>, <code>Long</code>, <code>Int</code>, <code>Short</code>,
   * <code>Byte</code> or <code>Char</code>) is widened to <code>Double</code> before validation, following Scala's numeric
   * widening rules.
   *
   * This method is intended to be invoked at compile time from macros. When called from a macro, exceptions thrown by this method
   * will result in compiler errors. The detail message of the thrown exception will appear as the compiler error message.
   *
   * @param c the compiler context for this assertion
   * @param value the <code>Double</code> expression to validate
   * @param notValidMsg a <code>String</code> message to include in the exception thrown if the expression is a literal, but not valid
   * @param notLiteralMsg a <code>String</code> message to include in the exception thrown if the expression is not a literal
   * @param isValid a function used to validate the literal value extracted from the given expression
   */
  def ensureValidDoubleLiteral(c: Context)(value: c.Expr[Double], notValidMsg: String, notLiteralMsg: String)(isValid: Double => Boolean): Unit =
    ensureValidLiteral[Double](c)(value.tree, notValidMsg, notLiteralMsg) {
      case d: Double => d
      case f: Float  => f.toDouble
      case l: Long   => l.toDouble
      case i: Int    => i.toDouble
      case s: Short  => s.toDouble
      case b: Byte   => b.toDouble
      case ch: Char  => ch.toDouble
    }(isValid)

  /**
   * Ensures a given expression of type <code>String</code> is a literal with a valid value according to a given validation function.
   *
   * If the given <code>String</code> expression is a literal whose value satisfies the given validation function, this method will
   * return normally. Otherwise, if the given <code>String</code> expression is not a literal, this method will complete abruptly with
   * an exception whose detail message includes the <code>String</code> passed as <code>notLiteralMsg</code>. Otherwise, the
   * given <code>String</code> expression is a literal that does not satisfy the given validation function, so this method will
   * complete abruptly with an exception whose detail message includes the <code>String</code> passed as <code>notValidMsg</code>.
   *
   * The <code>null</code> literal is not a <code>String</code> constant, so it is reported with <code>notLiteralMsg</code>
   * and is never passed to <code>isValid</code>.
   *
   * This method is intended to be invoked at compile time from macros. When called from a macro, exceptions thrown by this method
   * will result in compiler errors. The detail message of the thrown exception will appear as the compiler error message.
   *
   * @param c the compiler context for this assertion
   * @param value the <code>String</code> expression to validate
   * @param notValidMsg a <code>String</code> message to include in the exception thrown if the expression is a literal, but not valid
   * @param notLiteralMsg a <code>String</code> message to include in the exception thrown if the expression is not a literal
   * @param isValid a function used to validate the literal value extracted from the given expression
   */
  def ensureValidStringLiteral(c: Context)(value: c.Expr[String], notValidMsg: String, notLiteralMsg: String)(isValid: String => Boolean): Unit =
    ensureValidLiteral[String](c)(value.tree, notValidMsg, notLiteralMsg) {
      case s: String => s
    }(isValid)

  /**
   * Ensures a given expression of type <code>Char</code> is a literal with a valid value according to a given validation function.
   *
   * If the given <code>Char</code> expression is a literal whose value satisfies the given validation function, this method will
   * return normally. Otherwise, if the given <code>Char</code> expression is not a literal, this method will complete abruptly with
   * an exception whose detail message includes the <code>String</code> passed as <code>notLiteralMsg</code>. Otherwise, the
   * given <code>Char</code> expression is a literal that does not satisfy the given validation function, so this method will
   * complete abruptly with an exception whose detail message includes the <code>String</code> passed as <code>notValidMsg</code>.
   *
   * This method is intended to be invoked at compile time from macros. When called from a macro, exceptions thrown by this method
   * will result in compiler errors. The detail message of the thrown exception will appear as the compiler error message.
   *
   * @param c the compiler context for this assertion
   * @param value the <code>Char</code> expression to validate
   * @param notValidMsg a <code>String</code> message to include in the exception thrown if the expression is a literal, but not valid
   * @param notLiteralMsg a <code>String</code> message to include in the exception thrown if the expression is not a literal
   * @param isValid a function used to validate the literal value extracted from the given expression
   */
  def ensureValidCharLiteral(c: Context)(value: c.Expr[Char], notValidMsg: String, notLiteralMsg: String)(isValid: Char => Boolean): Unit =
    ensureValidLiteral[Char](c)(value.tree, notValidMsg, notLiteralMsg) {
      case ch: Char => ch
    }(isValid)

  /**
   * Shared implementation of the public <code>ensureValid*Literal</code> methods.
   *
   * Aborts macro expansion with <code>notLiteralMsg</code> unless <code>tree</code> is a literal whose constant value
   * is accepted by <code>extract</code>. Aborts with <code>notValidMsg</code> if the extracted value fails <code>isValid</code>.
   * Otherwise it returns normally.
   *
   * The constant is matched on its runtime type rather than round-tripped through <code>toString</code>. As a result,
   * a <code>null</code> literal or a constant of an unexpected type produces the caller's compiler error message
   * instead of a <code>NullPointerException</code> or <code>NumberFormatException</code> escaping from the macro.
   *
   * @param c the compiler context for this assertion
   * @param tree the typed argument tree to inspect
   * @param notValidMsg message used when the literal's value fails <code>isValid</code>
   * @param notLiteralMsg message used when <code>tree</code> is not an acceptable literal
   * @param extract converts an acceptable constant value to <code>T</code>; values it is not defined at are rejected
   * @param isValid validation applied to the extracted value
   */
  private def ensureValidLiteral[T](c: Context)(tree: c.Tree, notValidMsg: String, notLiteralMsg: String)(extract: PartialFunction[Any, T])(isValid: T => Boolean): Unit = {
    import c.universe._
    // Type patterns never match null, so a null constant yields None here rather than an exception.
    val literalValue: Option[T] = tree match {
      case Literal(Constant(raw)) => extract.lift(raw)
      case _ => None
    }
    literalValue match {
      case Some(v) =>
        if (!isValid(v))
          c.abort(c.enclosingPosition, notValidMsg)
      case None =>
        c.abort(c.enclosingPosition, notLiteralMsg)
    }
  }
}
/**
 * Companion object that facilitates the importing of <code>CompileTimeAssertions</code> members as
 * an alternative to mixing in the trait.
 */
object CompileTimeAssertions extends CompileTimeAssertions {
  // Intentionally empty: every member is inherited from the trait, so callers may write
  // `import CompileTimeAssertions._` instead of mixing the trait in.
}