Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
org.alephium.ralph.Ast.scala Maven / Gradle / Ivy
// Copyright 2018 The Alephium Authors
// This file is part of the alephium project.
//
// The library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the library. If not, see <http://www.gnu.org/licenses/>.
package org.alephium.ralph
import java.nio.charset.StandardCharsets
import scala.annotation.{nowarn, tailrec}
import scala.collection.mutable
import akka.util.ByteString
import org.alephium.protocol.vm
import org.alephium.protocol.vm.{ALPHTokenId => ALPHTokenIdInstr, Contract => VmContract, _}
import org.alephium.ralph.LogicalOperator.Not
import org.alephium.ralph.Parser.FunctionUsingAnnotation
import org.alephium.util.{AVector, DjbHash, Hex, I256, U256}
// scalastyle:off number.of.methods number.of.types file.size.limit
object Ast {
/** Standard interface ids are represented as raw byte vectors. */
type StdInterfaceId = Val.ByteVec

/** The UTF-8 bytes of "ALPH"; presumably the common prefix of standard interface ids. */
val StdInterfaceIdPrefix: ByteString = ByteString("ALPH", StandardCharsets.UTF_8)

/** Synthetic immutable, unused `ByteVec` parameter named `__stdInterfaceId`. */
private val stdArg: Argument =
  Argument(
    ident = Ident("__stdInterfaceId"),
    tpe = Type.ByteVec,
    isMutable = false,
    isUnused = true
  )
/**
 * Mixin for AST nodes that may carry their location in the source text.
 *
 * `sourceIndex` starts as `None`. The `atSourceIndex` overloads set it once; calling them
 * again fails the `require`. `overwriteSourceIndex` replaces an already recorded position.
 */
trait Positioned {
  var sourceIndex: Option[SourceIndex] = None

  /**
   * Records the span `[fromIndex, endIndex)` as this node's position. The stored
   * `SourceIndex` holds the start offset and the width `endIndex - fromIndex`.
   * Requires that no position has been recorded yet.
   */
  def atSourceIndex(fromIndex: Int, endIndex: Int, fileURI: Option[java.net.URI]): this.type = {
    require(sourceIndex.isEmpty)
    val width = endIndex - fromIndex
    sourceIndex = Some(SourceIndex(fromIndex, width, fileURI))
    this
  }

  /** Records an already built (possibly absent) position; requires none was recorded yet. */
  def atSourceIndex(sourceIndex: Option[SourceIndex]): this.type = {
    require(this.sourceIndex.isEmpty)
    this.sourceIndex = sourceIndex
    this
  }

  /**
   * Replaces the recorded position with the span `[fromIndex, endIndex)`.
   * Requires that a position has already been recorded.
   */
  def overwriteSourceIndex(
      fromIndex: Int,
      endIndex: Int,
      fileURI: Option[java.net.URI]
  ): this.type = {
    require(sourceIndex.isDefined)
    val width = endIndex - fromIndex
    sourceIndex = Some(SourceIndex(fromIndex, width, fileURI))
    this
  }

  /**
   * Evaluates `f`. If it throws a `CompilerError.Default` that has no position, the error is
   * re-thrown carrying this node's position (when this node has one).
   *
   * This covers errors raised by code that has no access to positions, e.g. the
   * `getReturnType` checks of `Operator` or `BuiltIn`. Threading a `SourceIndex` through
   * those APIs would touch the whole `ralph` module for little benefit. Errors that already
   * carry a position, and all other exceptions, propagate unchanged.
   */
  def positionedError[T](f: => T): T =
    try f
    catch {
      case e: error.CompilerError.Default if e.sourceIndex.isEmpty && sourceIndex.isDefined =>
        throw e.copy(sourceIndex = sourceIndex)
    }
}
/** An identifier, e.g. a variable, field, argument or map name. */
final case class Ident(name: String) extends Positioned
/** A type name, e.g. of a contract, struct or enum. */
final case class TypeId(name: String) extends Positioned
/** A function name; `isBuiltIn` flags built-in functions. */
final case class FuncId(name: String, isBuiltIn: Boolean) extends Positioned

/** A declared parameter: name, type, and the mutability / unused flags. */
final case class Argument(ident: Ident, tpe: Type, isMutable: Boolean, isUnused: Boolean)
    extends Positioned {
  /** Renders the parameter as `name:Type`, prefixed with `mut ` when mutable. */
  def signature: String = {
    val mutability = if (!isMutable) "" else "mut "
    s"$mutability${ident.name}:${tpe.signature}"
  }
}

/** A field of an event declaration. */
final case class EventField(ident: Ident, tpe: Type) extends Positioned {
  /** Renders the field as `name:Type`. */
  def signature: String = {
    val fieldName = ident.name
    s"$fieldName:${tpe.signature}"
  }
}

/** A named constant value inside an annotation. */
final case class AnnotationField[Ctx <: StatelessContext](ident: Ident, value: Const[Ctx])
    extends Positioned
/** An annotation: its identifier and its fields. */
final case class Annotation[Ctx <: StatelessContext](id: Ident, fields: Seq[AnnotationField[Ctx]])
    extends Positioned

object FuncId {
  /** A non-built-in function id with an empty name. */
  lazy val empty: FuncId = FuncId(name = "", isBuiltIn = false)
}

/** The quoted `Type.func` name, as used in error messages. */
def funcName(typeId: TypeId, funcId: FuncId): String = {
  val qualified = typeId.name + "." + funcId.name
  quote(qualified)
}
/**
 * One braces-syntax approval entry: an `address` and the `(tokenId, amount)` pairs it approves.
 */
final case class ApproveAsset[Ctx <: StatelessContext](
    address: Expr[Ctx],
    tokenAmounts: Seq[(Expr[Ctx], Expr[Ctx])]
) extends Positioned {

  /**
   * Type-checks the entry: `address` must be an `Address`, and every pair must be typed
   * `(ByteVec, U256)`. The first offending pair is reported at its token expression.
   */
  def check(state: Compiler.State[Ctx]): Unit = {
    if (address.getType(state) != Seq(Type.Address)) {
      throw Compiler.Error(s"Invalid address type: ${address}", address.sourceIndex)
    }
    val expectedPairType = (Seq(Type.ByteVec), Seq(Type.U256))
    tokenAmounts
      .find { case (tokenExpr, amountExpr) =>
        (tokenExpr.getType(state), amountExpr.getType(state)) != expectedPairType
      }
      .foreach { case (tokenExpr, _) =>
        throw Compiler.Error(s"Invalid token amount type: ${tokenAmounts}", tokenExpr.sourceIndex)
      }
  }

  /**
   * Pushes `address` once, then `Dup`s it so there is one copy per approval. Then it emits
   * one approval per pair: `ApproveAlph` (amount only) for the literal `ALPHTokenId()`, and
   * `ApproveToken` (token id, then amount) otherwise. Assumes at least one pair.
   */
  @SuppressWarnings(Array("org.wartremover.warts.AsInstanceOf"))
  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
    val pairCount = tokenAmounts.length
    assume(pairCount >= 1)
    val approvals: Seq[Instr[Ctx]] = tokenAmounts.flatMap { case (tokenExpr, amountExpr) =>
      tokenExpr match {
        case ALPHTokenId() =>
          amountExpr.genCode(state) :+ ApproveAlph.asInstanceOf[Instr[Ctx]]
        case _ =>
          tokenExpr.genCode(state) ++ amountExpr.genCode(state) :+
            ApproveToken.asInstanceOf[Instr[Ctx]]
      }
    }
    address.genCode(state) ++ Seq.fill(pairCount - 1)(Dup) ++ approvals
  }

  /** Clears the cached types of the address and of every token/amount expression. */
  def reset(): Unit = {
    address.reset()
    tokenAmounts.foreach { pair =>
      pair._1.reset()
      pair._2.reset()
    }
  }
}
/** How a function is annotated regarding the use of its contract's assets. */
sealed trait ContractAssetsAnnotation {
  def assetsEnabled: Boolean
}
/** Contract assets are not used; `assetsEnabled` is `false`. */
case object NotUseContractAssets extends ContractAssetsAnnotation {
  val assetsEnabled: Boolean = false
}
/** Contract assets are used; `assetsEnabled` is `true`. */
case object UseContractAssets extends ContractAssetsAnnotation {
  val assetsEnabled: Boolean = true
}
/** Contract assets use is enforced; `assetsEnabled` is `true`. */
case object EnforcedUseContractAssets extends ContractAssetsAnnotation {
  val assetsEnabled: Boolean = true
}
/** A call site that may carry braces-syntax asset approvals (`approveAssets`). */
trait ApproveAssets[Ctx <: StatelessContext] extends Positioned {
  def approveAssets: Seq[ApproveAsset[Ctx]]

  /** Type-checks every approval entry. */
  def checkApproveAssets(state: Compiler.State[Ctx]): Unit =
    approveAssets.foreach(asset => asset.check(state))

  /**
   * Emits the approval code for a call to `func`. The call must supply approvals if and only
   * if `func.usePreapprovedAssets` is set; otherwise a compiler error is thrown.
   */
  def genApproveCode(
      state: Compiler.State[Ctx],
      func: Compiler.FuncInfo[Ctx]
  ): Seq[Instr[Ctx]] = {
    val hasApprovals   = approveAssets.nonEmpty
    val needsApprovals = func.usePreapprovedAssets
    if (hasApprovals && !needsApprovals) {
      throw Compiler.Error(
        s"Function `${func.name}` does not use preapproved assets",
        sourceIndex
      )
    }
    if (!hasApprovals && needsApprovals) {
      throw Compiler.Error(
        s"Function `${func.name}` needs preapproved assets, please use braces syntax",
        sourceIndex
      )
    }
    approveAssets.flatMap(_.genCode(state))
  }
}
/**
 * An AST node whose type `T` is computed by `_getType` and memoised.
 * `reset()` clears the memoised value, so the next `getType` recomputes it.
 */
trait Typed[Ctx <: StatelessContext, T] extends Positioned {
  private var tpe: Option[T] = None

  /** Computes the type from scratch; called by `getType` on a cache miss. */
  protected def _getType(state: Compiler.State[Ctx]): T

  /** The memoised type, if `getType` has computed it since the last `reset()`. */
  def getCachedType(): Option[T] = tpe

  /** Returns the memoised type, computing and storing it on first use. */
  def getType(state: Compiler.State[Ctx]): T =
    tpe.getOrElse {
      val computed = _getType(state)
      tpe = Some(computed)
      computed
    }

  /** Drops the memoised type. */
  def reset(): Unit = {
    tpe = None
  }
}
/**
 * Base type of all expressions. An expression's type is a `Seq[Type]`; a single-valued
 * expression has exactly one element.
 */
sealed trait Expr[Ctx <: StatelessContext]
    extends Typed[Ctx, Seq[Type]]
    with Product
    with Serializable {
  /** Emits the instructions that evaluate this expression. */
  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]]
}

/** The ALPH token id literal; typed as `ByteVec` and compiled to the `ALPHTokenId` instruction. */
final case class ALPHTokenId[Ctx <: StatelessContext]() extends Expr[Ctx] {
  def _getType(state: Compiler.State[Ctx]): Seq[Type] = Seq(Type.ByteVec)

  @SuppressWarnings(Array("org.wartremover.warts.AsInstanceOf"))
  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
    val instr = ALPHTokenIdInstr.asInstanceOf[Instr[Ctx]]
    Seq(instr)
  }
}

/** A literal constant `v`; compiled to the constant instruction provided by `v`. */
final case class Const[Ctx <: StatelessContext](v: Val) extends Expr[Ctx] {
  def toConstInstr: Instr[StatelessContext] = v.toConstInstr

  override def _getType(state: Compiler.State[Ctx]): Seq[Type] = {
    val valType = Type.fromVal(v.tpe)
    Seq(valType)
  }

  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = Seq(v.toConstInstr)
}
/** An array literal; `elementExpr` is an element whose type determines the element type. */
sealed trait CreateArrayExpr[Ctx <: StatelessContext] extends Expr[Ctx] {
  def elementExpr: Expr[Ctx]

  /** The element type; throws if `elementExpr` does not have exactly one type. */
  protected def getElementType(state: Compiler.State[Ctx]): Type =
    elementExpr.getType(state) match {
      case Seq(single) => single
      case baseType =>
        throw Compiler.Error(
          s"Expected single type for array element, got ${quote(baseType)}",
          sourceIndex
        )
    }
}
/** An array literal listing its elements explicitly: `[e0, e1, ...]`. */
final case class CreateArrayExpr1[Ctx <: StatelessContext](elements: Seq[Expr[Ctx]])
    extends CreateArrayExpr[Ctx] {
  /** The first element defines the element type; the literal must not be empty. */
  def elementExpr: Expr[Ctx] = {
    assume(elements.nonEmpty)
    elements(0)
  }

  /** A fixed-size array of the first element's type; all elements must share that type. */
  override def _getType(state: Compiler.State[Ctx]): Seq[Type.FixedSizeArray] = {
    val elementType = getElementType(state)
    // Element 0 defines `elementType`, so only the remaining elements need to be compared.
    if (elements.drop(1).exists(_.getType(state) != Seq(elementType))) {
      throw Compiler.Error(s"Array elements should have same type", sourceIndex)
    }
    Seq(Type.FixedSizeArray(elementType, Left(elements.size)))
  }

  /** Evaluates the elements in order. */
  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
    elements.flatMap(_.genCode(state))
  }

  override def reset(): Unit = {
    elements.foreach(_.reset())
    super.reset()
  }
}
/** An array literal repeating `elementExpr` `sizeExpr` times: `[elementExpr; sizeExpr]`. */
final case class CreateArrayExpr2[Ctx <: StatelessContext](
    elementExpr: Expr[Ctx],
    sizeExpr: Expr[Ctx]
) extends CreateArrayExpr[Ctx] {
  // Array length computed by the last `_getType`; cleared by `reset()`.
  private var resolvedSize: Option[Int] = None

  override def _getType(state: Compiler.State[Ctx]): Seq[Type.FixedSizeArray] = {
    val elementType = getElementType(state)
    val length      = state.calcArraySize(sizeExpr)
    resolvedSize = Some(length)
    Seq(Type.FixedSizeArray(elementType, Left(length)))
  }

  /** Emits `elementExpr`'s code once per array slot. */
  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
    val length = resolvedSize match {
      case Some(known) => known
      case None        => state.calcArraySize(sizeExpr)
    }
    (1 to length).flatMap(_ => elementExpr.genCode(state))
  }

  override def reset(): Unit = {
    resolvedSize = None
    elementExpr.reset()
    sizeExpr.reset()
    super.reset()
  }
}
/**
 * Shared logic for AST nodes that access data through a chain of selectors
 * (index selectors for arrays/maps, identifier selectors for struct fields).
 * The self-type requires the implementing node to also be `Positioned`.
 */
sealed trait AccessDataT[Ctx <: StatelessContext] { self: Positioned =>
// The selector chain applied to the base value, in order.
def selectors: Seq[DataSelector]
/**
 * The index expression of the first selector, used as the key when the base is a map.
 * The cast assumes `selectors(0)` is an `IndexSelector`; it fails otherwise.
 */
@SuppressWarnings(Array("org.wartremover.warts.AsInstanceOf"))
protected def mapKeyIndex: Expr[Ctx] = {
selectors(0).asInstanceOf[IndexSelector[Ctx]].index
}
/**
 * Computes the type obtained by applying `selectors`, in order, to a value of `rootType`.
 *
 * Per step:
 *  - fixed-size array + index selector: checks the index type, yields the element type;
 *  - struct + identifier selector: yields the field's type, resolved via `state.resolveType`;
 *  - map + index selector: checks the key type, yields the map's value type;
 *  - any other combination: throws a compiler error.
 *
 * @param rootType    type of the base value
 * @param sourceIndex position of the base value; the fold replaces it with each processed
 *                    selector's position, and error positions are built from it
 * @return the type of the selected data
 */
protected def _getType(
state: Compiler.State[Ctx],
rootType: Type,
sourceIndex: Option[SourceIndex]
): Type = {
selectors
.foldLeft((rootType, sourceIndex)) { case ((tpe, sourceIndex), selector) =>
(tpe, selector) match {
case (array: Type.FixedSizeArray, selector: IndexSelector[Ctx @unchecked]) =>
state.checkArrayIndexType(selector.index)
(array.baseType, selector.sourceIndex)
case (struct: Type.Struct, IdentSelector(ident)) =>
val field = state.getStruct(struct.id).getField(ident)
(state.resolveType(field.tpe), selector.sourceIndex)
case (map: Type.Map, selector: IndexSelector[Ctx @unchecked]) =>
state.checkMapKeyType(map, selector.index)
(map.value, selector.sourceIndex)
// Error positions combine this node's position with the previous step's position
// (the base's on the first step); how the two-argument `SourceIndex(...)` merges
// them is defined elsewhere -- NOTE(review): confirm.
case (tpe, _: IndexSelector[Ctx @unchecked]) =>
throw Compiler.Error(
s"Expected array or map type, got ${quote(tpe)}",
SourceIndex(this.sourceIndex, sourceIndex)
)
case (tpe, _: IdentSelector) =>
throw Compiler.Error(
s"Expected struct type, got ${quote(tpe)}",
SourceIndex(this.sourceIndex, sourceIndex)
)
}
}
._1
}
}
/**
 * Code generation helpers for map access.
 *
 * As generated below, a map entry lives in a sub-contract. Its path is the map's own value
 * (loaded from `ident` or evaluated from `map`; presumably a path prefix) followed by the
 * key converted to a `ByteVec`. Entry fields are read and written with `CallExternal` on
 * that sub-contract, using the `CreateMapEntry` method indexes.
 */
object MapOps {
// Generates the key expression's code, followed by a conversion of the key to `ByteVec`
// chosen from the key's first type. `ByteVec` keys need no conversion. Other key types are
// presumably rejected by type checking beforehand, hence the "dead branch" fallback.
private def genMapKey[Ctx <: StatelessContext](
state: Compiler.State[Ctx],
expr: Ast.Expr[Ctx]
): Seq[Instr[Ctx]] = {
val codes = expr.genCode(state)
expr.getType(state)(0) match {
case Type.Bool => codes :+ BoolToByteVec
case Type.U256 => codes :+ U256ToByteVec
case Type.I256 => codes :+ I256ToByteVec
case Type.Address => codes :+ AddressToByteVec
case Type.ByteVec => codes
case tpe => // dead branch
throw Compiler.Error(s"Invalid key type $tpe", expr.sourceIndex)
}
}
// Sub-contract path of the entry at key `index` of the map variable `ident`:
// the variable's value concatenated with the encoded key.
@inline def genSubContractPath[Ctx <: StatelessContext](
state: Compiler.State[Ctx],
ident: Ast.Ident,
index: Ast.Expr[Ctx]
): Seq[Instr[Ctx]] = {
(state.genLoadCode(ident) ++ genMapKey(state, index)) :+ ByteVecConcat
}
// Sub-contract path of the entry at key `index`, where the map value comes from
// evaluating the expression `map`.
@inline def genSubContractPath[Ctx <: StatelessContext](
state: Compiler.State[Ctx],
map: Ast.Expr[Ctx],
index: Ast.Expr[Ctx]
): Seq[Instr[Ctx]] = {
(map.genCode(state) ++ genMapKey(state, index)) :+ ByteVecConcat
}
// Walks `selectors` starting from `tpe` (resolved at every step), advancing `dataOffset`:
//  - array + index selector: offset of the selected element;
//  - struct + identifier selector: offset of the selected field. The data stays mutable
//    only while every traversed field is mutable (`isMutable && field.isMutable`).
// Returns the final offset and whether the selected data is mutable.
@tailrec
private def calcDataOffset(
state: Compiler.State[StatefulContext],
tpe: Type,
selectors: Seq[DataSelector],
isMutable: Boolean,
dataOffset: DataRefOffset[StatefulContext]
): (DataRefOffset[StatefulContext], Boolean) = {
(state.resolveType(tpe), selectors.headOption) match {
case (_, None) => (dataOffset, isMutable)
case (tpe: Type.FixedSizeArray, Some(s: IndexSelector[StatefulContext @unchecked])) =>
val newOffset = dataOffset.calcArrayElementOffset(state, tpe, s.index, isMutable)
calcDataOffset(state, tpe.baseType, selectors.drop(1), isMutable, newOffset)
case (tpe: Type.Struct, Some(IdentSelector(ident))) =>
val ast = state.getStruct(tpe.id)
val newOffset = dataOffset.calcStructFieldOffset(state, ast, ident, isMutable)
val field = ast.getField(ident)
val isFieldMutable = isMutable && field.isMutable
calcDataOffset(state, field.tpe, selectors.drop(1), isFieldMutable, newOffset)
case _ => // dead branch
throw Compiler.Error(
s"Invalid type $tpe and selectors $selectors",
selectors.headOption.flatMap(_.sourceIndex)
)
}
}
// Entry point: starts from zero immutable/mutable offsets, with the root treated as mutable.
// Returns (immutable-field offset, mutable-field offset, mutability of the selected data).
private def calcDataOffset(
state: Compiler.State[StatefulContext],
rootType: Type,
selectors: Seq[DataSelector]
): (VarOffset[StatefulContext], VarOffset[StatefulContext], Boolean) = {
val initOffset = DataRefOffset[StatefulContext](ConstantVarOffset(0), ConstantVarOffset(0))
val (offset, isMutable) = calcDataOffset(
state,
rootType,
selectors,
isMutable = true,
initOffset
)
(offset.immDataOffset, offset.mutDataOffset, isMutable)
}
// Returns (initCodes, loadCodes) for obtaining the sub-contract id that `objCodes` produces.
// For a single access the id is computed inline and there are no init codes. For several
// accesses, the id is computed once into a fresh local variable (the init codes) and
// re-loaded for each access.
private def genSubContractId(
state: Compiler.State[StatefulContext],
objCodes: Seq[Instr[StatefulContext]],
size: Int
): (Seq[Instr[StatefulContext]], Seq[Instr[StatefulContext]]) = {
if (size == 1) {
(Seq.empty, objCodes)
} else {
val ident = state.getSubContractIdVar()
(objCodes ++ state.genStoreCode(ident).flatten, state.genLoadCode(ident))
}
}
/**
 * Loads the data selected by `selectors` from a map entry. `pathCodes` are expected to
 * produce the sub-contract path, which `SubContractId` turns into a contract id.
 *
 * One `CallExternal` is emitted per flattened slot of `selectedDataType`. Each calls the
 * mutable-field or immutable-field loader, depending on that slot's mutability, and
 * advances the matching offset. The id init codes are attached to the first slot only.
 * `state` is cast to a stateful state (maps presumably exist only in stateful contexts).
 */
@SuppressWarnings(Array("org.wartremover.warts.AsInstanceOf"))
def genLoad[Ctx <: StatelessContext](
state: Compiler.State[Ctx],
rootType: Type,
selectedDataType: Type,
pathCodes: Seq[Instr[Ctx]],
selectors: Seq[DataSelector]
): Seq[Instr[Ctx]] = {
val statefulState = state.asInstanceOf[Compiler.State[StatefulContext]]
val (immOffset, mutOffset, isMutable) = calcDataOffset(statefulState, rootType, selectors)
val mutability = state.flattenTypeMutability(selectedDataType, isMutable)
val (initCodes, subContractIdCodes) = genSubContractId(
statefulState,
pathCodes.asInstanceOf[Seq[Instr[StatefulContext]]] :+ SubContractId,
mutability.length
)
// Two U256 constants equal to 1; going by the name, presumably the external call's
// argument count (the field offset) and return count.
val funcArgLenAndRetLen =
Seq(ConstInstr.u256(Val.U256(U256.One)), ConstInstr.u256(Val.U256(U256.One)))
val instrs = mutability.indices
.foldLeft((Seq.empty[Instr[StatefulContext]], immOffset, mutOffset)) {
case ((instrs, immOffset, mutOffset), index) =>
val objCodes = if (index == 0) initCodes ++ subContractIdCodes else subContractIdCodes
if (mutability(index)) {
val loadCodes = mutOffset.genCode() ++ funcArgLenAndRetLen ++
objCodes :+ CallExternal(CreateMapEntry.LoadMutFieldMethodIndex)
(instrs ++ loadCodes, immOffset, mutOffset.add(1))
} else {
val loadCodes = immOffset.genCode() ++ funcArgLenAndRetLen ++
objCodes :+ CallExternal(CreateMapEntry.LoadImmFieldMethodIndex)
(instrs ++ loadCodes, immOffset.add(1), mutOffset)
}
}
._1
instrs.asInstanceOf[Seq[Instr[Ctx]]]
}
/**
 * Produces one instruction sequence per flattened slot of `selectedDataType`. Sequence
 * `index` stores into the mutable field at `mutOffset + index` of the map entry's
 * sub-contract. Only the mutable offset is used; the computed mutability is ignored here.
 *
 * The constants 2 and 0 mirror `funcArgLenAndRetLen` in `genLoad` (presumably two arguments
 * -- field offset and value -- and no return value; confirm against `CreateMapEntry`).
 * NOTE(review): the id init codes are attached to the LAST sequence, which suggests callers
 * emit these sequences in reverse order; confirm at the call site.
 */
@SuppressWarnings(Array("org.wartremover.warts.AsInstanceOf"))
def genStore[Ctx <: StatelessContext](
state: Compiler.State[Ctx],
rootType: Type,
selectedDataType: Type,
pathCodes: Seq[Instr[Ctx]],
selectors: Seq[DataSelector]
): Seq[Seq[Instr[Ctx]]] = {
val statefulState = state.asInstanceOf[Compiler.State[StatefulContext]]
val (_, mutOffset, _) = calcDataOffset(statefulState, rootType, selectors)
val length = state.flattenTypeLength(Seq(selectedDataType))
val (initCodes, subContractIdCodes) = genSubContractId(
statefulState,
pathCodes.asInstanceOf[Seq[Instr[StatefulContext]]] :+ SubContractId,
length
)
val instrs = (0 until length).map { index =>
val indexCodes = if (index == 0) mutOffset.genCode() else mutOffset.add(index).genCode()
val objCodes =
if (index == length - 1) initCodes ++ subContractIdCodes else subContractIdCodes
indexCodes ++ Seq(
ConstInstr.u256(Val.U256(U256.Two)),
ConstInstr.u256(Val.U256(U256.Zero))
) ++ objCodes :+ CallExternal(CreateMapEntry.StoreMutFieldMethodIndex)
}
instrs.asInstanceOf[Seq[Seq[Instr[Ctx]]]]
}
}
/**
 * `base` followed by one or more selectors (`base[i]`, `base.field`, ...).
 * A map base is read through the map-entry sub-contract (`MapOps.genLoad`). Other bases are
 * read through a variable reference from `state.getOrCreateVariablesRef`.
 */
final case class LoadDataBySelectors[Ctx <: StatelessContext](
    base: Expr[Ctx],
    selectors: Seq[DataSelector]
) extends Expr[Ctx]
    with AccessDataT[Ctx] {

  def _getType(state: Compiler.State[Ctx]): Seq[Type] = {
    assume(selectors.nonEmpty)
    base.getType(state) match {
      case Seq(root @ (_: Type.FixedSizeArray | _: Type.Struct | _: Type.Map)) =>
        Seq(_getType(state, root, base.sourceIndex))
      case tpe =>
        val tpeStr = quoteTypes(tpe)
        // The message depends on which kind of access was attempted first.
        val message = selectors.headOption match {
          case Some(IndexSelector(_)) => s"Expected array or map type, got $tpeStr"
          case _                      => s"Expected struct type, got $tpeStr"
        }
        throw Compiler.Error(message, base.sourceIndex)
    }
  }

  @SuppressWarnings(Array("org.wartremover.warts.IterableOps"))
  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] =
    base.getType(state) match {
      case Seq(mapType: Type.Map) =>
        // The first selector is the map key; the rest select inside the entry value.
        val entryPath = MapOps.genSubContractPath(state, base, mapKeyIndex)
        MapOps.genLoad(state, mapType.value, getType(state).head, entryPath, selectors.tail)
      case _ =>
        val (rootRef, refCodes) = state.getOrCreateVariablesRef(base)
        val parentRef           = rootRef.subRef(state, selectors.init)
        refCodes ++ parentRef.genLoadCode(state, selectors.last)
    }

  override def reset(): Unit = {
    base.reset()
    selectors.foreach(_.reset())
    super.reset()
  }
}
/** `ident.contains!(index)`-style check: whether the entry at key `index` of map `ident` exists. */
final case class MapContains(ident: Ident, index: Expr[StatefulContext])
    extends Expr[StatefulContext] {

  /** `Bool`; `ident` must be a map and `index` must have exactly the map's key type. */
  def _getType(state: Compiler.State[StatefulContext]): Seq[Type] = {
    val mapType: Type.Map = state.getVariable(ident).tpe match {
      case t: Type.Map => t
      case t           => throw Compiler.Error(s"Expected map type, got $t", ident.sourceIndex)
    }
    val expected = Seq(mapType.key)
    val argTypes = index.getType(state)
    if (argTypes == expected) {
      Seq(Type.Bool)
    } else {
      throw Compiler.Error(s"Invalid args type $argTypes, expected $expected", sourceIndex)
    }
  }

  /** Computes the entry's sub-contract id and checks whether that contract exists. */
  def genCode(state: Compiler.State[StatefulContext]): Seq[Instr[StatefulContext]] =
    MapOps.genSubContractPath(state, ident, index) ++ Seq(SubContractId, ContractExists)

  override def reset(): Unit = {
    index.reset()
    super.reset()
  }
}
/** A reference to the variable named `id`. */
final case class Variable[Ctx <: StatelessContext](id: Ident) extends Expr[Ctx] {
  override def _getType(state: Compiler.State[Ctx]): Seq[Type] = {
    val resolved = state.resolveType(id)
    Seq(resolved)
  }
  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = state.genLoadCode(id)
}

/** `Enum.Field`: reads the variable identified by `EnumDef.fieldIdent(enumId, field)`. */
final case class EnumFieldSelector[Ctx <: StatelessContext](enumId: TypeId, field: Ident)
    extends Expr[Ctx] {
  // Identifier of the variable holding the enum field's value; computed once.
  lazy val fieldIdent = EnumDef.fieldIdent(enumId, field)

  override def _getType(state: Compiler.State[Ctx]): Seq[Type] = {
    val variable = state.getVariable(fieldIdent)
    Seq(variable.tpe)
  }
  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] =
    state.genLoadCode(fieldIdent)
}
/** A unary operator `op` applied to `expr`. */
final case class UnaryOp[Ctx <: StatelessContext](op: Operator, expr: Expr[Ctx])
    extends Expr[Ctx] {
  /** The operator's result type for the operand type; errors get this node's position. */
  override def _getType(state: Compiler.State[Ctx]): Seq[Type] = {
    positionedError(op.getReturnType(expr.getType(state)))
  }

  /**
   * The operand's code followed by the operator's code. Wrapped in `positionedError`, like
   * `Binop.genCode`, so an operator codegen error without a position is reported at this
   * expression instead of losing its location.
   */
  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
    positionedError(expr.genCode(state) ++ op.genCode(expr.getType(state)))
  }

  override def reset(): Unit = {
    expr.reset()
    super.reset()
  }
}
/** A binary operator `op` applied to `left` and `right`. */
final case class Binop[Ctx <: StatelessContext](op: Operator, left: Expr[Ctx], right: Expr[Ctx])
    extends Expr[Ctx] {
  // Both operands' types concatenated, left first.
  private def operandTypes(state: Compiler.State[Ctx]): Seq[Type] =
    left.getType(state) ++ right.getType(state)

  /** The operator's result type; errors without a position get this node's position. */
  override def _getType(state: Compiler.State[Ctx]): Seq[Type] =
    positionedError(op.getReturnType(operandTypes(state)))

  /** Left operand, right operand, then the operator's own code. */
  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] =
    positionedError {
      val operandCode = left.genCode(state) ++ right.genCode(state)
      operandCode ++ op.genCode(operandTypes(state))
    }

  override def reset(): Unit = {
    left.reset()
    right.reset()
    super.reset()
  }
}
/** `ContractType(address)`: views a `ByteVec` contract id as an instance of `contractType`. */
final case class ContractConv[Ctx <: StatelessContext](contractType: TypeId, address: Expr[Ctx])
    extends Expr[Ctx] {

  /** Requires a known, instantiable contract type and a `ByteVec` address. */
  override protected def _getType(state: Compiler.State[Ctx]): Seq[Type] = {
    state.checkContractType(contractType)
    val addressIsByteVec = address.getType(state) == Seq(Type.ByteVec)
    if (!addressIsByteVec) {
      throw Compiler.Error(s"Invalid expr $address for contract address", address.sourceIndex)
    }
    val kind = state.getContractInfo(contractType).kind
    if (!kind.instantiable) {
      throw Compiler.Error(s"${contractType.name} is not instantiable", sourceIndex)
    }
    Seq(Type.Contract(contractType))
  }

  /** The conversion itself emits nothing: only the address is evaluated. */
  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = address.genCode(state)

  override def reset(): Unit = {
    address.reset()
    super.reset()
  }
}
/**
 * Common code generation for function calls (user-defined, built-in, or static contract
 * functions) that may carry braces-syntax asset approvals.
 */
sealed trait CallAst[Ctx <: StatelessContext] extends ApproveAssets[Ctx] {
// The called function's id.
def id: FuncId
// Argument expressions, in call order.
def args: Seq[Expr[Ctx]]
// When true, the call's return values are popped after the call.
def ignoreReturn: Boolean
// Looks up the called function.
def getFunc(state: Compiler.State[Ctx]): Compiler.FuncInfo[Ctx]
/**
 * Generates the call's code.
 *
 * Token built-ins whose token argument is the literal `ALPHTokenId()` compile to the
 * dedicated ALPH instructions (`ApproveAlph`, `AlphRemaining`, `TransferAlph`,
 * `TransferAlphFromSelf`, `TransferAlphToSelf`) without the token id argument. These fast
 * paths emit no approve-assets code, and only the `tokenRemaining` one honours
 * `ignoreReturn`.
 *
 * Every other call emits, in order: approve-assets code, argument code, the argument count
 * (for variadic functions only), and the function's own code. When `ignoreReturn` is set,
 * one `Pop` per flattened return slot follows.
 */
@SuppressWarnings(Array("org.wartremover.warts.AsInstanceOf"))
def _genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
(id, args) match {
case (BuiltIn.approveToken.funcId, Seq(from, ALPHTokenId(), amount)) =>
Seq(from, amount).flatMap(_.genCode(state)) :+ ApproveAlph.asInstanceOf[Instr[Ctx]]
case (BuiltIn.tokenRemaining.funcId, Seq(from, ALPHTokenId())) =>
val instrs = from.genCode(state) :+ AlphRemaining.asInstanceOf[Instr[Ctx]]
if (ignoreReturn) instrs :+ Pop.asInstanceOf[Instr[Ctx]] else instrs
case (BuiltIn.transferToken.funcId, Seq(from, to, ALPHTokenId(), amount)) =>
Seq(from, to, amount).flatMap(_.genCode(state)) :+ TransferAlph.asInstanceOf[Instr[Ctx]]
case (BuiltIn.transferTokenFromSelf.funcId, Seq(to, ALPHTokenId(), amount)) =>
Seq(to, amount).flatMap(_.genCode(state)) :+ TransferAlphFromSelf.asInstanceOf[Instr[Ctx]]
case (BuiltIn.transferTokenToSelf.funcId, Seq(from, ALPHTokenId(), amount)) =>
Seq(from, amount).flatMap(_.genCode(state)) :+ TransferAlphToSelf.asInstanceOf[Instr[Ctx]]
// Generic call.
case _ =>
val func = getFunc(state)
val argsType = args.flatMap(_.getType(state))
// Variadic functions receive the number of arguments as an extra U256 constant.
val variadicInstrs = if (func.isVariadic) {
Seq(U256Const(Val.U256.unsafe(args.length)))
} else {
Seq.empty
}
val instrs = genApproveCode(state, func) ++
func.genCodeForArgs(args, state) ++
variadicInstrs ++
func.genCode(argsType)
if (ignoreReturn) {
val returnType = positionedError(func.getReturnType(argsType, state))
instrs ++ Seq.fill(state.flattenTypeLength(returnType))(Pop)
} else {
instrs
}
}
}
// Throws a compiler error, reported at `funcId`, unless `func` is a static function.
@inline final def checkStaticContractFunction(
typeId: TypeId,
funcId: FuncId,
func: Compiler.ContractFunc[Ctx]
): Unit = {
if (!func.isStatic) {
throw Compiler.Error(
s"Expected static function, got ${funcName(typeId, funcId)}",
funcId.sourceIndex
)
}
}
}
/** A call to a function by name, used as an expression (its return values are kept). */
final case class CallExpr[Ctx <: StatelessContext](
    id: FuncId,
    approveAssets: Seq[ApproveAsset[Ctx]],
    args: Seq[Expr[Ctx]]
) extends Expr[Ctx]
    with CallAst[Ctx] {
  def ignoreReturn: Boolean = false

  def getFunc(state: Compiler.State[Ctx]): Compiler.FuncInfo[Ctx] = state.getFunc(id)

  /** Checks the approvals, then returns the function's result type for the argument types. */
  override def _getType(state: Compiler.State[Ctx]): Seq[Type] = {
    checkApproveAssets(state)
    val callee = getFunc(state)
    positionedError {
      val argTypes = args.flatMap(_.getType(state))
      callee.getReturnType(argTypes, state)
    }
  }

  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
    // Registered here rather than in `_getType`, otherwise the statement might get skipped.
    state.addInternalCall(id)
    _genCode(state)
  }

  override def reset(): Unit = {
    approveAssets.foreach(_.reset())
    args.foreach(_.reset())
    super.reset()
  }
}
/** `ContractType.func(args)`: a call to a static contract function, used as an expression. */
final case class ContractStaticCallExpr[Ctx <: StatelessContext](
    contractId: TypeId,
    id: FuncId,
    approveAssets: Seq[ApproveAsset[Ctx]],
    args: Seq[Expr[Ctx]]
) extends Expr[Ctx]
    with CallAst[Ctx] {
  def ignoreReturn: Boolean = false

  def getFunc(state: Compiler.State[Ctx]): Compiler.ContractFunc[Ctx] =
    state.getFunc(contractId, id)

  /** Checks the approvals and that the target is static, then returns its result type. */
  override def _getType(state: Compiler.State[Ctx]): Seq[Type] = {
    checkApproveAssets(state)
    val target = getFunc(state)
    checkStaticContractFunction(contractId, id, target)
    positionedError {
      val argTypes = args.flatMap(_.getType(state))
      target.getReturnType(argTypes, state)
    }
  }

  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = _genCode(state)

  override def reset(): Unit = {
    approveAssets.foreach(_.reset())
    args.foreach(_.reset())
    super.reset()
  }
}
/**
 * Shared logic for calls on a contract object (`obj.callId(args)`), which may carry asset
 * approvals for the called method.
 */
trait ContractCallBase extends ApproveAssets[StatefulContext] {
// Expression evaluating to the contract object.
def obj: Expr[StatefulContext]
// The called method.
def callId: FuncId
// Argument expressions, in call order.
def args: Seq[Expr[StatefulContext]]
/**
 * The contract type of `obj`. Throws if `obj` does not have exactly one type, or if that
 * type does not resolve to a contract.
 */
@inline def getContractType(state: Compiler.State[StatefulContext]): Type.Contract = {
val objType = obj.getType(state)
if (objType.length != 1) {
throw Compiler.Error(
s"Expected a single parameter for contract object, got ${quote(obj)}",
obj.sourceIndex
)
} else {
state.resolveType(objType(0)) match {
case contract: Type.Contract => contract
case _ =>
throw Compiler.Error(
s"Expected a contract for ${quote(callId)}, got ${quote(obj)}",
obj.sourceIndex
)
}
}
}
/**
 * Type-checks the call and returns (contract type id, method return types).
 * Side effects on `state`: when the contract is an interface, records an interface function
 * call for the current scope; it always records the external call `callId` on the contract.
 * Throws if the method is static.
 */
def _getTypeBase(state: Compiler.State[StatefulContext]): (TypeId, Seq[Type]) = {
val contractType = getContractType(state)
val contractInfo = state.getContractInfo(contractType.id)
if (contractInfo.kind == Compiler.ContractKind.Interface) {
state.addInterfaceFuncCall(state.currentScope)
}
val funcInfo = state.getFunc(contractType.id, callId)
checkNonStaticContractFunction(contractType.id, callId, funcInfo)
state.addExternalCall(contractType.id, callId)
val retTypes = positionedError(funcInfo.getReturnType(args.flatMap(_.getType(state)), state))
(contractType.id, retTypes)
}
/**
 * Emits approve-assets code, argument code, and the external call on `obj`. When
 * `popReturnValues` is set, appends one `Pop` per flattened return slot.
 * The cast presumes `obj` was already validated as a contract by `_getTypeBase`.
 */
@SuppressWarnings(Array("org.wartremover.warts.AsInstanceOf"))
def genContractCall(
state: Compiler.State[StatefulContext],
popReturnValues: Boolean
): Seq[Instr[StatefulContext]] = {
val contract = obj.getType(state)(0).asInstanceOf[Type.Contract]
val func = state.getFunc(contract.id, callId)
val argTypes = args.flatMap(_.getType(state))
val retTypes = func.getReturnType(argTypes, state)
val retLength = state.flattenTypeLength(retTypes)
genApproveCode(state, func) ++
args.flatMap(_.genCode(state)) ++
func.genExternalCallCode(state, obj.genCode(state), contract.id) ++
(if (popReturnValues) Seq.fill[Instr[StatefulContext]](retLength)(Pop) else Seq.empty)
}
// Throws a compiler error, reported at `funcId`, if `func` is static.
@inline final def checkNonStaticContractFunction(
typeId: TypeId,
funcId: FuncId,
func: Compiler.ContractFunc[StatefulContext]
): Unit = {
if (func.isStatic) {
// TODO: use `obj.funcId` instead of `typeId.funcId`
throw Compiler.Error(
s"Expected non-static function, got ${funcName(typeId, funcId)}",
funcId.sourceIndex
)
}
}
}
/** `obj.method(args)` used as an expression: the call's return values are kept. */
final case class ContractCallExpr(
    obj: Expr[StatefulContext],
    callId: FuncId,
    approveAssets: Seq[ApproveAsset[StatefulContext]],
    args: Seq[Expr[StatefulContext]]
) extends Expr[StatefulContext]
    with ContractCallBase {
  override def _getType(state: Compiler.State[StatefulContext]): Seq[Type] = {
    checkApproveAssets(state)
    val (_, returnTypes) = _getTypeBase(state)
    returnTypes
  }

  override def genCode(state: Compiler.State[StatefulContext]): Seq[Instr[StatefulContext]] =
    genContractCall(state, popReturnValues = false)

  override def reset(): Unit = {
    obj.reset()
    approveAssets.foreach(_.reset())
    args.foreach(_.reset())
    super.reset()
  }
}
/** A parenthesised expression; it types and compiles exactly like the wrapped `expr`. */
final case class ParenExpr[Ctx <: StatelessContext](expr: Expr[Ctx]) extends Expr[Ctx] {
  override def _getType(state: Compiler.State[Ctx]): Seq[Type] = expr.getType(state)

  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = expr.genCode(state)

  override def reset(): Unit = {
    expr.reset()
    super.reset()
  }
}
/** An `if (condition)` branch of an if/else chain. */
trait IfBranch[Ctx <: StatelessContext] extends Positioned {
  def condition: Expr[Ctx]

  /** Throws, at the condition's position, unless the condition is typed `Bool`. */
  def checkCondition(state: Compiler.State[Ctx]): Unit =
    condition.getType(state) match {
      case Seq(Type.Bool) => ()
      case conditionType =>
        throw Compiler.Error(
          s"Invalid type of condition expr: $conditionType",
          condition.sourceIndex
        )
    }

  /** Code of the branch body. */
  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]]
}

/** The trailing `else` branch of an if/else chain. */
trait ElseBranch[Ctx <: StatelessContext] extends Positioned {
  /** Code of the branch body. */
  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]]
}
/** Code generation for an if / else-if / else chain. */
trait IfElse[Ctx <: StatelessContext] extends Positioned {
// Condition/body branches, in source order.
def ifBranches: Seq[IfBranch[Ctx]]
// The optional trailing `else`.
def elseBranchOpt: Option[ElseBranch[Ctx]]
// Code of the `else` body, generated in its own scope; empty when there is no `else`.
private def genElseBodyIRs(state: Compiler.State[Ctx]) = {
elseBranchOpt
.map { branch =>
state.withScope(branch)(branch.genCode(state))
}
.getOrElse(Seq.empty)
}
/**
 * Generates the chain back to front, because each branch's jump offsets depend on the size
 * of the code that follows it.
 *
 * `elseOffsets(i)` is the number of instructions from the start of branch `i` to the end of
 * the chain; slot `ifBranches.length` stands for the `else` body. Each branch is laid out as:
 *   condition code + conditional jump over the body,
 *   body code + `Jump` over the rest of the chain (omitted for the very last branch).
 * The jump offsets appear to be relative instruction counts.
 */
@SuppressWarnings(Array("org.wartremover.warts.IterableOps"))
def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
val ifBranchesIRs = Array.ofDim[Seq[Instr[Ctx]]](ifBranches.length + 1)
val elseOffsets = Array.ofDim[Int](ifBranches.length + 1)
val elseBodyIRs = genElseBodyIRs(state)
ifBranchesIRs(ifBranches.length) = elseBodyIRs
elseOffsets(ifBranches.length) = elseBodyIRs.length
ifBranches.zipWithIndex.view.reverse.foreach { case (ifBranch, index) =>
// Size of everything after this branch (later branches plus the else body).
val initialOffset = elseOffsets(index + 1)
val notTheLastBranch = index < ifBranches.length - 1 || elseBranchOpt.nonEmpty
val bodyIRsWithoutOffset = state.withScope(ifBranch) { ifBranch.genCode(state) }
val bodyOffsetIR = if (notTheLastBranch) {
Seq(Jump(initialOffset))
} else {
Seq.empty
}
val bodyIRs = bodyIRsWithoutOffset ++ bodyOffsetIR
// For the last branch with no else, initialOffset is 0 (empty else body), so both forms
// jump just past this branch's body.
val conditionOffset =
if (notTheLastBranch) bodyIRs.length else bodyIRs.length + initialOffset
val conditionIRs = Statement.getCondIR(ifBranch.condition, state, conditionOffset)
ifBranchesIRs(index) = conditionIRs ++ bodyIRs
elseOffsets(index) = initialOffset + bodyIRs.length + conditionIRs.length
}
// The array always has ifBranches.length + 1 >= 1 slots, so `reduce` is safe.
ifBranchesIRs.reduce(_ ++ _)
}
}
/** An `if (condition) expr` branch of an if/else expression. */
final case class IfBranchExpr[Ctx <: StatelessContext](
    condition: Expr[Ctx],
    expr: Expr[Ctx]
) extends IfBranch[Ctx] {
  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = expr.genCode(state)

  /** Clears the cached types of the condition, then of the branch value. */
  def reset(): Unit = Seq(condition, expr).foreach(_.reset())
}

/** The `else expr` branch of an if/else expression. */
final case class ElseBranchExpr[Ctx <: StatelessContext](
    expr: Expr[Ctx]
) extends ElseBranch[Ctx] {
  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = expr.genCode(state)

  def reset(): Unit = {
    expr.reset()
  }
}
/** An if/else used as an expression; the `else` branch is mandatory. */
final case class IfElseExpr[Ctx <: StatelessContext](
    ifBranches: Seq[IfBranchExpr[Ctx]],
    elseBranch: ElseBranchExpr[Ctx]
) extends IfElse[Ctx]
    with Expr[Ctx] {
  def elseBranchOpt: Option[ElseBranch[Ctx]] = Some(elseBranch)

  /**
   * The `else` branch's type. Every `if` branch must have a `Bool` condition and a value of
   * that same type; branches are checked in source order.
   */
  def _getType(state: Compiler.State[Ctx]): Seq[Type] = {
    val expectedType = elseBranch.expr.getType(state)
    for (branch <- ifBranches) {
      branch.checkCondition(state)
      val branchType = branch.expr.getType(state)
      if (branchType != expectedType) {
        throw Compiler.Error(
          s"Invalid types of if-else expression branches, expected ${quote(expectedType)}, got ${quote(branchType)}",
          sourceIndex
        )
      }
    }
    expectedType
  }

  override def reset(): Unit = {
    ifBranches.foreach(_.reset())
    elseBranch.reset()
    super.reset()
  }
}
/** A struct field declaration: its name, `mut` flag and type. */
final case class StructField(ident: Ident, isMutable: Boolean, tpe: Type) extends UniqueDef {
  def name: String = ident.name

  /** Renders the field as `name:Type`; the `mut` flag is not included. */
  def signature: String = s"$name:${tpe.signature}"
}

/** Marker for global (top-level) definitions such as `Struct`. */
sealed trait GlobalDefinition extends UniqueDef
/** A struct declaration: its type id and its fields in declaration order. */
final case class Struct(id: TypeId, fields: Seq[StructField]) extends GlobalDefinition {
  lazy val tpe: Type.Struct = Type.Struct(id)
  def name: String = id.name
  def getFieldNames(): AVector[String] = AVector.from(fields.view.map(_.ident.name))
  def getFieldTypeSignatures(): AVector[String] = AVector.from(fields.view.map(_.tpe.signature))
  def getFieldsMutability(): AVector[Boolean] = AVector.from(fields.view.map(_.isMutable))

  /** Returns the field named `selector`; throws a compiler error at `selector` if absent. */
  def getField(selector: Ident): StructField =
    fields.find(_.ident == selector).getOrElse(fieldNotFound(selector))

  /**
   * Counts the flattened (immutable, mutable) slots occupied by the fields declared before
   * `selector`. A slot counts as mutable only if both `isMutable` (the enclosing access)
   * and the field allow mutation. Throws a compiler error if `selector` is not a field.
   */
  def calcFieldOffset[Ctx <: StatelessContext](
      state: Compiler.State[Ctx],
      selector: Ast.Ident,
      isMutable: Boolean
  ): (Int, Int) = {
    val slotMutability = fieldsBefore(selector).flatMap { field =>
      state.flattenTypeMutability(field.tpe, isMutable && field.isMutable)
    }
    val mutFieldSize = slotMutability.count(identity)
    (slotMutability.length - mutFieldSize, mutFieldSize)
  }

  /**
   * Number of flattened slots occupied by the fields declared before `selector`.
   * Throws a compiler error if `selector` is not a field.
   */
  def calcLocalOffset[Ctx <: StatelessContext](
      state: Compiler.State[Ctx],
      selector: Ast.Ident
  ): Int =
    state.flattenTypeLength(fieldsBefore(selector).map(_.tpe))

  // Fields declared before `selector`. An unknown selector must be an error: with
  // `slice(0, indexWhere(...))`, -1 silently yields no fields, i.e. an offset of 0.
  private def fieldsBefore(selector: Ident): Seq[StructField] =
    fields.indexWhere(_.ident == selector) match {
      case -1    => fieldNotFound(selector)
      case index => fields.take(index)
    }

  // Raises the error reported for an unknown field name.
  private def fieldNotFound(selector: Ident): Nothing =
    throw Compiler.Error(
      s"Field ${selector.name} does not exist in struct ${id.name}",
      selector.sourceIndex
    )
}
/**
 * Struct construction `Id { field: expr, ... }`. A field given without an expression takes
 * the value of the variable with the same name.
 */
final case class StructCtor[Ctx <: StatelessContext](
    id: TypeId,
    fields: Seq[(Ident, Option[Expr[Ctx]])]
) extends Expr[Ctx] {

  /**
   * Checks that the given fields are exactly the struct's fields (any order, each exactly
   * once) with matching types, and returns the struct type. Any missing, extra, duplicated
   * or mistyped field raises an "Invalid struct fields" compiler error.
   */
  def _getType(state: Compiler.State[Ctx]): Seq[Type] = {
    val struct = state.getStruct(id)
    val expected = struct.fields.map(field => (field.ident, Seq(state.resolveType(field.tpe))))
    val have = fields.map { case (ident, expr) =>
      val tpe = expr.map(_.getType(state)).getOrElse(Seq(state.getVariable(ident).tpe))
      (ident, tpe)
    }
    // The length and containment checks alone would accept `{ a: x, a: y }` for a struct with
    // fields `a` and `b`; that would only fail later in `genCode`, with a misleading
    // "does not exist" error. So duplicated field names are rejected explicitly here.
    val hasDuplicateField = have.map(_._1).distinct.length != have.length
    if (
      expected.length != have.length || hasDuplicateField ||
      have.exists(f => !expected.contains(f))
    ) {
      throw Compiler.Error(
        s"Invalid struct fields, expect ${struct.fields.map(_.signature)}",
        sourceIndex
      )
    }
    Seq(struct.tpe)
  }

  /** Emits the field values in the struct's declaration order. */
  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
    val struct = state.getStruct(id)
    val sortedFields = struct.fields.map { field =>
      fields
        .find(_._1 == field.ident)
        .getOrElse(
          throw Compiler.Error(s"Struct field ${field.ident} does not exist", id.sourceIndex)
        )
    }
    sortedFields.flatMap {
      case (_, Some(expr)) => expr.genCode(state)
      case (field, None)   => state.genLoadCode(field)
    }
  }

  override def reset(): Unit = {
    fields.foreach { case (_, expr) => expr.foreach(_.reset()) }
    super.reset()
  }
}
/** Declaration of a contract-level map variable `ident` whose key/value types are `tpe`. */
final case class MapDef(ident: Ident, tpe: Type.Map) extends UniqueDef with Positioned {
  override def name: String = ident.name
}

/** A Ralph statement.
  *
  * `check` type-checks the statement against the compiler state, `genCode` lowers it to VM
  * instructions, and `reset` clears whatever state the node cached while being compiled.
  */
sealed trait Statement[Ctx <: StatelessContext] extends Positioned with Product with Serializable {
  def check(state: Compiler.State[Ctx]): Unit
  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]]
  def reset(): Unit
}

object Statement {

  /** Emits the code of a branch condition followed by a conditional jump of `offset`.
    *
    * Normally this is the condition followed by `IfFalse(offset)`. When the condition is a
    * logical `!expr`, the negation is folded into the jump: `expr` is emitted followed by
    * `IfTrue(offset)`.
    */
  @inline def getCondIR[Ctx <: StatelessContext](
      condition: Expr[Ctx],
      state: Compiler.State[Ctx],
      offset: Int
  ): Seq[Instr[Ctx]] = condition match {
    case UnaryOp(Not, negated) => negated.genCode(state) :+ IfTrue(offset)
    case other                 => other.genCode(state) :+ IfFalse(offset)
  }
}
/** Common part of the built-in map mutation statements (`insert!` / `remove!` style calls).
  *
  * Implementors supply the map variable `ident` and the call arguments `args`.
  */
sealed trait MapFuncCall extends Statement[StatefulContext] {
  def ident: Ident
  def args: Seq[Expr[StatefulContext]]

  // Memoized result of `getMapType`; cleared by `reset`.
  private var mapType: Option[Type.Map] = None

  /** Type of the map variable `ident` (looked up for write access).
    *
    * @throws Compiler.Error if `ident` is not a map
    */
  def getMapType(state: Compiler.State[StatefulContext]): Type.Map =
    mapType.getOrElse {
      val resolved = state.getVariable(ident, isWrite = true).tpe match {
        case tpe: Type.Map => tpe
        case t             => throw Compiler.Error(s"Expected map type, got $t", ident.sourceIndex)
      }
      mapType = Some(resolved)
      resolved
    }

  /** Checks that `args` has exactly the `expected` count and (flattened) types. */
  def checkArgTypes(state: Compiler.State[StatefulContext], expected: Seq[Type]): Unit = {
    val argCount = args.length
    if (argCount != expected.length) {
      throw Compiler.Error(
        s"Invalid args length, expected ${expected.length}, got ${argCount}",
        sourceIndex
      )
    }
    val argTypes = args.flatMap(_.getType(state))
    if (argTypes != expected) {
      throw Compiler.Error(s"Invalid args type $argTypes, expected $expected", sourceIndex)
    }
  }

  override def reset(): Unit = mapType = None
}
/** Appends debug-logging code after `pathCodes` when debugging is enabled.
  *
  * The sub-contract path left on the stack is duplicated and printed through `DEBUG` with the
  * prefix "insert at map path: " or "remove at map path: ". When debugging is disabled,
  * `pathCodes` is returned unchanged.
  */
private def genMapDebug(
    state: Compiler.State[StatefulContext],
    pathCodes: Seq[Instr[StatefulContext]],
    isInsert: Boolean
): Seq[Instr[StatefulContext]] = {
  if (!state.allowDebug) {
    pathCodes
  } else {
    val operation = if (isInsert) "insert" else "remove"
    // One literal part before the interpolated path and an empty one after it.
    val debugParts =
      AVector(ByteString.fromString(s"$operation at map path: "), ByteString.empty)
        .map(Val.ByteVec.apply)
    pathCodes ++ Seq[Instr[StatefulContext]](Dup, DEBUG(debugParts))
  }
}
/** Inserts a value into a map: `args` are `(address, key, value)`.
  *
  * Each entry is stored in a sub-contract. The generated code approves
  * `MinimalContractDeposit` ALPH using `args(0)` and then creates the entry with
  * `CreateMapEntry`.
  */
final case class InsertToMap(
    ident: Ident,
    args: Seq[Expr[StatefulContext]]
) extends MapFuncCall {
  def check(state: Compiler.State[StatefulContext]): Unit = {
    val tpe = getMapType(state)
    checkArgTypes(state, Seq(Type.Address, tpe.key, tpe.value))
  }

  // The field counts are passed to `CreateMapEntry` as bytes (`toByte`), so cap them at 0xff.
  private def checkFieldLength(length: Int): Unit = {
    if (length > 0xff) {
      throw Compiler.Error(
        s"The number of struct fields exceeds the maximum limit",
        args(2).sourceIndex
      )
    }
  }

  private def genCreateContract(
      state: Compiler.State[StatefulContext]
  ): Seq[Instr[StatefulContext]] = {
    val valueType = getMapType(state).value
    val fieldsMutability = state.flattenTypeMutability(valueType, isMutable = true)
    val mutFieldLength = fieldsMutability.count(identity)
    // One extra immutable field holds the parent contract id.
    val immFieldLength = fieldsMutability.length - mutFieldLength + 1
    Seq(mutFieldLength, immFieldLength).foreach(checkFieldLength)
    val pathCodes = MapOps.genSubContractPath(state, ident, args(1))
    val (immFields, mutFields) = state.genFieldsInitCodes(fieldsMutability, Seq(args(2)))
    val pathWithDebug = genMapDebug(state, pathCodes, isInsert = true)
    val immFieldsWithParent = immFields :+ SelfContractId
    val createEntry = CreateMapEntry(immFieldLength.toByte, mutFieldLength.toByte)
    (pathWithDebug ++ immFieldsWithParent ++ mutFields) :+ createEntry
  }

  def genCode(state: Compiler.State[StatefulContext]): Seq[Instr[StatefulContext]] = {
    val depositCodes = args(0).genCode(state) ++ Seq(MinimalContractDeposit, ApproveAlph)
    depositCodes ++ genCreateContract(state)
  }

  override def reset(): Unit = {
    args.foreach(_.reset())
    super.reset()
  }
}
/** Removes a map entry: `args` are `(address, key)`.
  *
  * The entry's sub-contract is addressed by its path and destroyed by an external call to
  * `CreateMapEntry.DestroyMethodIndex`.
  */
final case class RemoveFromMap(ident: Ident, args: Seq[Expr[StatefulContext]])
    extends MapFuncCall {
  def check(state: Compiler.State[StatefulContext]): Unit =
    checkArgTypes(state, Seq(Type.Address, getMapType(state).key))

  def genCode(state: Compiler.State[StatefulContext]): Seq[Instr[StatefulContext]] = {
    // The path is generated first, before the address argument.
    val pathCodes = MapOps.genSubContractPath(state, ident, args(1))
    val contractIdCodes = genMapDebug(state, pathCodes, isInsert = false) :+ SubContractId
    val addressCodes = args(0).genCode(state)
    val callArgs = Seq(
      ConstInstr.u256(Val.U256(U256.One)), // the `address` parameter
      ConstInstr.u256(Val.U256(U256.Zero))
    )
    (addressCodes ++ callArgs ++ contractIdCodes) :+ CallExternal(
      CreateMapEntry.DestroyMethodIndex
    )
  }

  override def reset(): Unit = {
    args.foreach(_.reset())
    super.reset()
  }
}
/** One element on the left-hand side of a variable declaration. */
sealed trait VarDeclaration extends Positioned

/** A named local variable; `mutable` marks it as assignable. */
final case class NamedVar(
    mutable: Boolean,
    ident: Ident
) extends VarDeclaration

/** The `_` placeholder: the corresponding value is popped and discarded. */
case object AnonymousVar extends VarDeclaration

/** Declaration of one or more local variables bound to the (possibly multi-valued) `value`. */
final case class VarDef[Ctx <: StatelessContext](
    vars: Seq[VarDeclaration],
    value: Expr[Ctx]
) extends Statement[Ctx] {
  override def check(state: Compiler.State[Ctx]): Unit = {
    val types = value.getType(state)
    if (vars.length != types.length) {
      throw Compiler.Error(
        s"Invalid variable declaration, expected ${types.length} variables, got ${vars.length} variables",
        sourceIndex
      )
    }
    vars.lazyZip(types).foreach { (declaration, tpe) =>
      declaration match {
        case NamedVar(isMutable, ident) =>
          // Maps can only be declared at contract level.
          if (tpe.isMapType) {
            throw Compiler
              .Error(s"Cannot define local map variable ${ident.name}", ident.sourceIndex)
          }
          state.addLocalVariable(ident, tpe, isMutable, isUnused = false, isGenerated = false)
        case AnonymousVar => ()
      }
    }
  }

  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
    val types = value.getType(state)
    // Store groups are built before the value code and applied in reverse, because the last
    // value is on top of the stack.
    val storeCodes = vars.zip(types).flatMap {
      case (NamedVar(_, ident), _) => state.genStoreCode(ident)
      case (AnonymousVar, tpe)     => Seq(Seq.fill(state.flattenTypeLength(Seq(tpe)))(Pop))
    }
    value.genCode(state) ++ storeCodes.reverse.flatten
  }

  def reset(): Unit = value.reset()
}
/** A definition whose `name` must not repeat among its siblings. */
trait UniqueDef extends Positioned {
  def name: String
}

object UniqueDef {

  /** Throws a `Compiler.Error` when two or more of `defs` share a name.
    *
    * @param name plural noun describing `defs`, used in the error message
    */
  def checkDuplicates(defs: Seq[UniqueDef], name: String): Unit = {
    val hasDuplicates = defs.distinctBy(_.name).size != defs.size
    if (hasDuplicates) {
      val (duplicatedNames, reportIndex) = duplicates(defs)
      throw Compiler.Error(
        s"These $name are defined multiple times: ${duplicatedNames}",
        reportIndex
      )
    }
  }

  /** Returns the comma-joined duplicated names and the position of the second definition of
    * one of them (the first group produced by `groupBy`).
    */
  def duplicates(defs: Seq[UniqueDef]): (String, Option[SourceIndex]) = {
    val grouped = defs.groupBy(_.name).filter { case (_, group) => group.size > 1 }
    val sourceIndex = for {
      group  <- grouped.values.headOption
      second <- group.drop(1).headOption
      index  <- second.sourceIndex
    } yield index
    (grouped.keys.mkString(", "), sourceIndex)
  }
}

/** Type-level summary of a function: visibility, asset usage, argument types with mutability
  * flags, and return types.
  */
final case class FuncSignature(
    id: FuncId,
    isPublic: Boolean,
    usePreapprovedAssets: Boolean,
    args: Seq[(Type, Boolean)],
    rtypes: Seq[Type]
)
/** A function definition. With `bodyOpt = None` it is a declaration without a body.
  *
  * Two pieces of compilation state are cached on the node, and both are cleared by `reset`:
  * the variables accessed by the body (`funcAccessedVarsCache`) and the method selector
  * (`methodSelector`).
  */
final case class FuncDef[Ctx <: StatelessContext](
    annotations: Seq[Annotation[Ctx]],
    id: FuncId,
    isPublic: Boolean,
    usePreapprovedAssets: Boolean,
    useAssetsInContract: Ast.ContractAssetsAnnotation,
    usePayToContractOnly: Boolean,
    useCheckExternalCaller: Boolean,
    useUpdateFields: Boolean,
    useMethodIndex: Option[Int],
    args: Seq[Argument],
    rtypes: Seq[Type],
    bodyOpt: Option[Seq[Statement[Ctx]]]
) extends UniqueDef
    with OriginContractInfo {
  def name: String = id.name
  def isPrivate: Boolean = !isPublic
  // A missing body is treated as an empty one.
  val body: Seq[Statement[Ctx]] = bodyOpt.getOrElse(Seq.empty)

  // Variables accessed by the body: recorded on the first `check`, replayed on later checks.
  private var funcAccessedVarsCache: Option[Set[Compiler.AccessVariable]] = None
  // Memoized result of `getMethodSelector`.
  private[ralph] var methodSelector: Option[Method.Selector] = None

  /** Method selector: `DjbHash.intHash` of `name(argTypes)->(retTypes)`.
    *
    * Both type lists are the comma-joined signatures of the flattened argument and return
    * types. The result is memoized.
    */
  def getMethodSelector(globalState: GlobalState[_]): Method.Selector = {
    methodSelector match {
      case Some(selector) => selector
      case None =>
        val argTypes = args.view
          .flatMap(arg => globalState.flattenType(arg.tpe))
          .map(_.signature)
          .mkString(",")
        val retTypes = rtypes.view.flatMap(globalState.flattenType).map(_.signature).mkString(",")
        val bytes = ByteString.fromString(s"$name($argTypes)->($retTypes)")
        val selector = Method.Selector(DjbHash.intHash(bytes))
        methodSelector = Some(selector)
        selector
    }
  }

  /** True if the first `using` annotation explicitly sets the check-external-caller key. */
  def hasCheckExternalCallerAnnotation: Boolean = {
    annotations
      .find(_.id.name == FunctionUsingAnnotation.id)
      .exists(_.fields.exists(_.ident.name == FunctionUsingAnnotation.useCheckExternalCallerKey))
  }

  // True if a top-level body statement inserts into, removes from, or assigns into a map.
  @inline private def isUpdateMap(state: Compiler.State[Ctx]): Boolean = {
    body.exists {
      case _: InsertToMap | _: RemoveFromMap => true
      case Assign(targets, _) =>
        targets.exists {
          case AssignmentSelectedTarget(ident, _) => state.hasMapVar(ident)
          case _                                  => false
        }
      case _ => false
    }
  }

  /** A "simple view" function uses none of the following:
    *   - field updates;
    *   - preapproved or contract assets;
    *   - interface function calls;
    *   - a top-level built-in `migrate` call;
    *   - map updates.
    */
  def isSimpleViewFunc(state: Compiler.State[Ctx]): Boolean = {
    val hasInterfaceFuncCall = state.hasInterfaceFuncCallSet.contains(id)
    val hasMigrateSimple = body.exists {
      case FuncCall(funcId, _, _) => funcId.isBuiltIn && funcId.name == "migrate"
      case _                      => false
    }
    !(useUpdateFields
      || usePreapprovedAssets
      || useAssetsInContract != Ast.NotUseContractAssets
      || hasInterfaceFuncCall
      || hasMigrateSimple
      || isUpdateMap(state))
  }

  lazy val signature: FuncSignature = FuncSignature(
    id,
    isPublic,
    usePreapprovedAssets,
    args.map(arg => (arg.tpe, arg.isMutable)),
    rtypes
  )

  def getArgNames(): AVector[String] = AVector.from(args.view.map(_.ident.name))
  def getArgTypeSignatures(): AVector[String] = AVector.from(args.view.map(_.tpe.signature))
  def getArgMutability(): AVector[Boolean] = AVector.from(args.view.map(_.isMutable))
  def getReturnSignatures(): AVector[String] = AVector.from(rtypes.view.map(_.signature))

  /** True if the external-caller check is disabled, or if a top-level statement of the body is
    * a built-in `checkCaller` call.
    */
  def hasDirectCheckExternalCaller(): Boolean = {
    !useCheckExternalCaller || // check external caller manually disabled
    body.exists {
      case FuncCall(funcId, _, _) => funcId.isBuiltIn && funcId.name == "checkCaller"
      case _                      => false
    }
  }

  // Every control-flow path of a function with return values must end in a `return`, a
  // built-in `panic`, or an if/else whose branches all satisfy this rule.
  @SuppressWarnings(Array("org.wartremover.warts.Recursion"))
  private def checkRetTypes(stmt: Option[Statement[Ctx]]): Unit = {
    stmt match {
      case Some(_: ReturnStmt[Ctx]) => () // we checked the `rtypes` in `ReturnStmt`
      case Some(IfElseStatement(ifBranches, elseBranchOpt)) =>
        ifBranches.foreach(branch => checkRetTypes(branch.body.lastOption))
        checkRetTypes(elseBranchOpt.flatMap(_.body.lastOption))
      case Some(call: FuncCall[_]) if call.id.name == "panic" && call.id.isBuiltIn => ()
      case _ =>
        throw Compiler.Error(
          s"Expected return statement for function ${quote(id.name)}",
          id.sourceIndex
        )
    }
  }

  /** Type-checks the function.
    *
    * Registers the arguments as locals, checks the body, reports unused and unassigned locals,
    * and verifies that the body ends with a return when `rtypes` is non-empty.
    */
  def check(state: Compiler.State[Ctx]): Unit = {
    state.setFuncScope(id)
    state.checkArguments(args)
    args.foreach { arg =>
      val argTpe = state.resolveType(arg.tpe)
      state.addLocalVariable(arg.ident, argTpe, arg.isMutable, arg.isUnused, isGenerated = false)
    }
    funcAccessedVarsCache match {
      case Some(vars) => // the function has been compiled before
        state.addAccessedVars(vars)
        body.foreach(_.check(state))
      case None =>
        body.foreach(_.check(state))
        val currentScopeUsedVars = Set.from(state.currentScopeAccessedVars)
        funcAccessedVarsCache = Some(currentScopeUsedVars)
        state.addAccessedVars(currentScopeUsedVars)
    }
    state.checkUnusedLocalVars(id)
    state.checkUnassignedLocalMutableVars(id)
    if (rtypes.nonEmpty) checkRetTypes(body.lastOption)
  }

  /** Lowers the function to a VM `Method`; argument and return lengths are flattened sizes. */
  def genMethod(state: Compiler.State[Ctx]): Method[Ctx] = {
    state.setFuncScope(id)
    val instrs = body.flatMap(_.genCode(state))
    val localVars = state.getLocalVars(id)
    Method[Ctx](
      isPublic,
      usePreapprovedAssets,
      useAssetsInContract != Ast.NotUseContractAssets,
      usePayToContractOnly = usePayToContractOnly,
      argsLength = state.flattenTypeLength(args.map(_.tpe)),
      localsLength = localVars.length,
      returnLength = state.flattenTypeLength(rtypes),
      AVector.from(instrs)
    )
  }

  def reset(): Unit = {
    funcAccessedVarsCache = None
    methodSelector = None
    body.foreach(_.reset())
  }
}
object FuncDef {

  /** Builds the synthetic public `main` function that wraps the given top-level statements.
    *
    * The function has no arguments, no return values, no method index, and keeps the
    * external-caller check enabled.
    */
  def main(
      stmts: Seq[Ast.Statement[StatefulContext]],
      usePreapprovedAssets: Boolean,
      useAssetsInContract: Ast.ContractAssetsAnnotation,
      useUpdateFields: Boolean
  ): FuncDef[StatefulContext] = {
    val mainId = FuncId("main", false)
    FuncDef[StatefulContext](
      annotations = Seq.empty,
      id = mainId,
      isPublic = true,
      usePreapprovedAssets = usePreapprovedAssets,
      useAssetsInContract = useAssetsInContract,
      usePayToContractOnly = false,
      useCheckExternalCaller = true,
      useUpdateFields = useUpdateFields,
      useMethodIndex = None,
      args = Seq.empty,
      rtypes = Seq.empty,
      bodyOpt = Some(stmts)
    )
  }
}
/** A field bound by struct destructuring, optionally renamed to `alias`. */
final case class StructFieldAlias(isMutable: Boolean, ident: Ident, alias: Option[Ident])

/** Destructuring statement: binds selected fields of the struct value `expr` to new locals. */
final case class StructDestruction[Ctx <: StatelessContext](
    id: TypeId,
    vars: Seq[StructFieldAlias],
    expr: Expr[Ctx]
) extends Statement[Ctx] {
  def check(state: Compiler.State[Ctx]): Unit = {
    val struct = expr.getType(state) match {
      case Seq(structType: Type.Struct) if structType.id == id => state.getStruct(id)
      case types =>
        throw Compiler.Error(
          s"Expected struct type ${quote(id.name)}, got ${quoteTypes(types)}",
          expr.sourceIndex
        )
    }
    for (fieldAlias <- vars) {
      val fieldType = state.resolveType(struct.getField(fieldAlias.ident).tpe)
      state.addLocalVariable(
        fieldAlias.alias.getOrElse(fieldAlias.ident),
        fieldType,
        fieldAlias.isMutable,
        isUnused = false,
        isGenerated = false
      )
    }
  }

  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
    val (structRef, refInitCodes) = state.getOrCreateVariablesRef(expr)
    val destructCodes = vars.flatMap { fieldAlias =>
      val loadCodes = structRef.genLoadCode(state, IdentSelector(fieldAlias.ident))
      val target = fieldAlias.alias.getOrElse(fieldAlias.ident)
      loadCodes ++ state.genStoreCode(target).reverse.flatten
    }
    refInitCodes ++ destructCodes
  }

  def reset(): Unit = expr.reset()
}
/** Left-hand side of an assignment. */
sealed trait AssignmentTarget[Ctx <: StatelessContext] extends Typed[Ctx, Type] {
  def ident: Ident

  /** Throws a `Compiler.Error`, reported at `sourceIndex`, if the target cannot be assigned. */
  def checkMutable(state: Compiler.State[Ctx], sourceIndex: Option[SourceIndex]): Unit

  /** Groups of store instructions for this target. */
  def genStore(state: Compiler.State[Ctx]): Seq[Seq[Instr[Ctx]]]
}

/** Assignment to a whole variable. */
final case class AssignmentSimpleTarget[Ctx <: StatelessContext](ident: Ident)
    extends AssignmentTarget[Ctx] {
  def _getType(state: Compiler.State[Ctx]): Type =
    state.resolveType(state.getVariable(ident, isWrite = true).tpe)

  // Map variables, immutable variables, and types with any immutable (nested) field are
  // rejected, in that order.
  def checkMutable(state: Compiler.State[Ctx], sourceIndex: Option[SourceIndex]): Unit = {
    val variable = state.getVariable(ident)
    val isMapVariable = variable match {
      case _: Compiler.VarInfo.MapVar => true
      case _                          => false
    }
    if (isMapVariable) {
      throw Compiler.Error(s"Cannot assign to map variable ${ident.name}.", sourceIndex)
    }
    if (!variable.isMutable) {
      throw Compiler.Error(s"Cannot assign to immutable variable ${ident.name}.", sourceIndex)
    }
    val isFullyMutable = state.isTypeMutable(getType(state))
    if (!isFullyMutable) {
      throw Compiler.Error(
        s"Cannot assign to variable ${ident.name}. Assignment only works when all of the (nested) fields are mutable.",
        sourceIndex
      )
    }
  }

  def genStore(state: Compiler.State[Ctx]): Seq[Seq[Instr[Ctx]]] = state.genStoreCode(ident)
}

/** One step of an access path: an index (`IndexSelector`) or a struct field (`IdentSelector`). */
sealed trait DataSelector extends Positioned {
  def reset(): Unit = this match {
    case IndexSelector(index) => index.reset()
    case IdentSelector(_)     => ()
  }
}
final case class IndexSelector[Ctx <: StatelessContext](index: Expr[Ctx]) extends DataSelector
final case class IdentSelector(ident: Ident) extends DataSelector
/** Assignment target that writes into part of variable `ident`, addressed by a chain of
  * index/field `selectors` (an array element, a struct field, or part of a map value).
  *
  * NOTE(review): `genStore` calls `tail`/`init`/`last` on `selectors`, so the chain is
  * assumed to be non-empty; confirm that the parser guarantees this.
  */
final case class AssignmentSelectedTarget[Ctx <: StatelessContext](
ident: Ident,
selectors: Seq[DataSelector]
) extends AssignmentTarget[Ctx]
with AccessDataT[Ctx] {
// scalastyle:off method.length
// Validates assignment into a map value. `selectors` are the selectors that follow the
// map-key index. With no further selectors, the whole value is replaced, so every (nested)
// field of the value type must be mutable. Otherwise the remaining path is walked by
// `checkMutable`, using `name[KeyType]` as the display name in error messages.
private def checkMap(
state: Compiler.State[Ctx],
mapType: Type.Map,
selectors: Seq[DataSelector],
sourceIndex: Option[SourceIndex]
): Unit = {
if (selectors.isEmpty) {
if (!state.isTypeMutable(mapType.value)) {
throw Compiler.Error(
s"Cannot assign to value in map ${quote(ident.name)}. Assignment only works when all of the (nested) fields are mutable.",
sourceIndex
)
}
} else {
checkMutable(
state,
mapType.value,
selectors,
Ident(s"${ident.name}[${mapType.key.signature}]"),
None,
sourceIndex
)
}
}
// Walks `selectors` down from `rootType`. Every struct field crossed must be mutable, and
// the final target must be fully mutable. `lastField` and `structId` are used only to
// build error messages.
@scala.annotation.tailrec
private def checkMutable(
state: Compiler.State[Ctx],
rootType: Type,
selectors: Seq[DataSelector],
lastField: Ident,
structId: Option[TypeId],
sourceIndex: Option[SourceIndex]
): Unit = {
(rootType, selectors) match {
// Last step is an array element: the element type must be fully mutable.
case (array: Type.FixedSizeArray, Seq(IndexSelector(_))) =>
if (!state.isTypeMutable(array.baseType)) {
val arraySelector =
structId.map(id => s"${id.name}.${lastField.name}").getOrElse(lastField.name)
throw Compiler.Error(
s"Cannot assign to immutable element in array $arraySelector. Assignment only works when all of the (nested) fields are mutable.",
sourceIndex
)
}
case (array: Type.FixedSizeArray, IndexSelector(_) +: tail) =>
checkMutable(state, array.baseType, tail, lastField, structId, sourceIndex)
// Indexing into a map hands the remaining selectors over to `checkMap`.
case (map: Type.Map, (_: IndexSelector[Ctx @unchecked]) +: tail) =>
checkMap(state, map, tail, sourceIndex)
// Last step is a struct field: the field must be mutable and its type fully mutable.
case (struct: Type.Struct, Seq(IdentSelector(ident))) =>
val field = state.getStruct(struct.id).getField(ident)
if (!field.isMutable) {
throw Compiler.Error(
s"Cannot assign to immutable field ${field.name} in struct ${struct.id.name}.",
sourceIndex
)
}
if (!state.isTypeMutable(field.tpe)) {
throw Compiler.Error(
s"Cannot assign to field ${field.name} in struct ${struct.id.name}. Assignment only works when all of the (nested) fields are mutable.",
sourceIndex
)
}
// An intermediate struct field must itself be mutable before descending into it.
case (struct: Type.Struct, IdentSelector(ident) +: tail) =>
val field = state.getStruct(struct.id).getField(ident)
if (!field.isMutable) {
throw Compiler.Error(
s"Cannot assign to immutable field ${field.name} in struct ${struct.id.name}.",
sourceIndex
)
}
val fieldType = state.resolveType(field.tpe)
checkMutable(state, fieldType, tail, field.ident, Some(struct.id), sourceIndex)
case _ => // dead branch
throw Compiler.Error(s"Invalid selectors ${selectors} for type $rootType", sourceIndex)
}
}
// scalastyle:on method.length
// Type of the selected element, derived from the variable's resolved type (write access).
def _getType(state: Compiler.State[Ctx]): Type = {
val variable = state.getVariable(ident, isWrite = true)
_getType(state, state.resolveType(variable.tpe), ident.sourceIndex)
}
// The variable must be mutable; then each step of the selector path is validated.
def checkMutable(state: Compiler.State[Ctx], sourceIndex: Option[SourceIndex]): Unit = {
val variable = state.getVariable(ident)
if (!variable.isMutable) {
throw Compiler.Error(s"Cannot assign to immutable variable ${ident.name}.", sourceIndex)
}
checkMutable(state, state.resolveType(variable.tpe), selectors, ident, None, sourceIndex)
}
// For a map variable, the code addresses the entry sub-contract at the key given by
// `mapKeyIndex` and stores through the remaining selectors. For any other variable, it
// stores through the reference reached by all but the last selector.
// NOTE(review): `mapKeyIndex` comes from `AccessDataT`; presumably it is the first
// selector's index expression. Confirm.
@SuppressWarnings(Array("org.wartremover.warts.IterableOps"))
def genStore(state: Compiler.State[Ctx]): Seq[Seq[Instr[Ctx]]] = {
val variable = state.getVariable(ident)
variable.tpe match {
case map: Type.Map =>
val pathCodes = MapOps.genSubContractPath(state, ident, mapKeyIndex)
MapOps.genStore(state, map.value, getType(state), pathCodes, selectors.tail)
case _ =>
val ref = state.getVariablesRef(ident)
val subRef = ref.subRef(state, selectors.init)
subRef.genStoreCode(state, selectors.last)
}
}
override def reset(): Unit = {
selectors.foreach(_.reset())
super.reset()
}
}
/** Mixin that records the contract (by type id) a definition was taken from. */
trait OriginContractInfo {
  private var originContractId: Option[TypeId] = None

  /** Sets the origin contract and returns this same instance, for chaining. */
  def withOrigin(typeId: TypeId): this.type = {
    originContractId = Some(typeId)
    this
  }

  def origin: Option[TypeId] = originContractId

  def definedIn(typeId: TypeId): Boolean = origin.contains(typeId)
}

/** A definition that introduces a compile-time constant. */
sealed trait ConstantDefinition extends OriginContractInfo

/** A named constant `ident` whose value is given by `expr`. */
final case class ConstantVarDef[Ctx <: StatelessContext](
    ident: Ident,
    expr: Expr[Ctx]
) extends GlobalDefinition
    with ConstantDefinition {
  def name: String = ident.name
}

/** A single `enum` member with a constant value. */
final case class EnumField[Ctx <: StatelessContext](ident: Ident, value: Const[Ctx])
    extends UniqueDef
    with ConstantDefinition {
  def name: String = ident.name
}

/** An `enum` definition: a named group of constant fields. */
final case class EnumDef[Ctx <: StatelessContext](id: TypeId, fields: Seq[EnumField[Ctx]])
    extends GlobalDefinition {
  def name: String = id.name
}

object EnumDef {

  /** Qualified identifier `Enum.field`, positioned at the field's source index. */
  def fieldIdent(enumId: TypeId, field: Ident): Ident = {
    val qualifiedName = s"${enumId.name}.${field.name}"
    Ident(qualifiedName).atSourceIndex(field.sourceIndex)
  }
}

/** An event declaration; `signature` renders it as `event Name(f1,f2,...)`. */
final case class EventDef(
    id: TypeId,
    fields: Seq[EventField]
) extends UniqueDef {
  def name: String = id.name

  def signature: String = {
    val fieldSignatures = fields.map(_.signature).mkString(",")
    s"event ${id.name}($fieldSignatures)"
  }

  def getFieldNames(): AVector[String] = AVector.from(fields.view.map(_.ident.name))
  def getFieldTypeSignatures(): AVector[String] = AVector.from(fields.view.map(_.tpe.signature))
}
/** `emit Event(args...)`: logs the event's index followed by its arguments. */
final case class EmitEvent[Ctx <: StatefulContext](id: TypeId, args: Seq[Expr[Ctx]])
    extends Statement[Ctx] {
  override def check(state: Compiler.State[Ctx]): Unit = {
    val eventInfo = state.getEvent(id)
    val argTypes = args.flatMap(_.getType(state))
    val hasCompositeArg = argTypes.exists(tpe => tpe.isArrayType || tpe.isStructType)
    if (hasCompositeArg) {
      throw Compiler.Error(
        s"Array and struct types are not supported for event ${quote(s"${state.typeId.name}.${id.name}")}",
        sourceIndex
      )
    }
    eventInfo.checkFieldTypes(state, argTypes, args.headOption.flatMap(_.sourceIndex))
  }

  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
    val index = state.eventsInfo.map(_.typeId).indexOf(id)
    // `check` method ensures that this event is defined
    assume(index >= 0)
    val eventIndexCodes = Const[Ctx](Val.I256(I256.from(index))).genCode(state)
    val logOpCode = Compiler.genLogs(args.length, id.sourceIndex)
    val argCodes = args.flatMap(_.genCode(state))
    (eventIndexCodes ++ argCodes) :+ logOpCode
  }

  def reset(): Unit = args.foreach(_.reset())
}
/** Assignment of the (possibly multi-valued) `rhs` to one or more `targets`. */
final case class Assign[Ctx <: StatelessContext](
    targets: Seq[AssignmentTarget[Ctx]],
    rhs: Expr[Ctx]
) extends Statement[Ctx] {
  override def check(state: Compiler.State[Ctx]): Unit = {
    val leftTypes = targets.map(_.getType(state))
    val rightTypes = rhs.getType(state)
    if (leftTypes != rightTypes) {
      throw Compiler.Error(
        s"Cannot assign ${quoteTypes(rightTypes)} to ${quoteTypes(leftTypes)}",
        sourceIndex
      )
    }
    targets.foreach(_.checkMutable(state, sourceIndex))
  }

  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
    val rhsCodes = rhs.genCode(state)
    // The stores run in reverse: the last value is on top of the stack.
    val storeCodes = targets.flatMap(_.genStore(state))
    rhsCodes ++ storeCodes.reverse.flatten
  }

  def reset(): Unit = {
    targets.foreach(_.reset())
    rhs.reset()
  }
}
/** A function call used as a statement, so its return values are discarded. */
sealed trait CallStatement[Ctx <: StatelessContext] extends Statement[Ctx] {

  /** Warns when a call that returns something other than `panic`'s type is ignored. */
  def checkReturnValueUsed(
      state: Compiler.State[Ctx],
      typeId: TypeId,
      funcId: FuncId,
      retTypes: Seq[Type]
  ): Unit = {
    val returnsValues = retTypes.nonEmpty && retTypes != Seq(Type.Panic)
    if (returnsValues) state.warningUnusedCallReturn(typeId, funcId, retTypes.length)
  }
}

/** Statement-level call of a local or built-in function. */
final case class FuncCall[Ctx <: StatelessContext](
    id: FuncId,
    approveAssets: Seq[ApproveAsset[Ctx]],
    args: Seq[Expr[Ctx]]
) extends CallStatement[Ctx]
    with CallAst[Ctx] {
  def ignoreReturn: Boolean = true

  def getFunc(state: Compiler.State[Ctx]): Compiler.FuncInfo[Ctx] = state.getFunc(id)

  override def check(state: Compiler.State[Ctx]): Unit = {
    checkApproveAssets(state)
    val func = getFunc(state)
    val retTypes = positionedError(func.getReturnType(args.flatMap(_.getType(state)), state))
    checkReturnValueUsed(state, state.typeId, id, retTypes)
  }

  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
    // don't put this in _getType, otherwise the statement might get skipped
    state.addInternalCall(id)
    _genCode(state)
  }

  def reset(): Unit = {
    approveAssets.foreach(_.reset())
    args.foreach(_.reset())
  }
}
/** Statement-level call of a static function on contract type `contractId`. */
final case class StaticContractFuncCall[Ctx <: StatelessContext](
    contractId: TypeId,
    id: FuncId,
    approveAssets: Seq[ApproveAsset[Ctx]],
    args: Seq[Expr[Ctx]]
) extends CallStatement[Ctx]
    with CallAst[Ctx] {
  def ignoreReturn: Boolean = true

  def getFunc(state: Compiler.State[Ctx]): Compiler.ContractFunc[Ctx] =
    state.getFunc(contractId, id)

  override def check(state: Compiler.State[Ctx]): Unit = {
    checkApproveAssets(state)
    val func = getFunc(state)
    checkStaticContractFunction(contractId, id, func)
    val retTypes = positionedError(func.getReturnType(args.flatMap(_.getType(state)), state))
    checkReturnValueUsed(state, contractId, id, retTypes)
  }

  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = _genCode(state)

  def reset(): Unit = {
    approveAssets.foreach(_.reset())
    args.foreach(_.reset())
  }
}

/** Statement-level call of method `callId` on the contract object `obj`. */
final case class ContractCall(
    obj: Expr[StatefulContext],
    callId: FuncId,
    approveAssets: Seq[ApproveAsset[StatefulContext]],
    args: Seq[Expr[StatefulContext]]
) extends CallStatement[StatefulContext]
    with ContractCallBase {
  override def check(state: Compiler.State[StatefulContext]): Unit = {
    checkApproveAssets(state)
    val (contractId, retTypes) = _getTypeBase(state)
    checkReturnValueUsed(state, contractId, callId, retTypes)
  }

  override def genCode(state: Compiler.State[StatefulContext]): Seq[Instr[StatefulContext]] =
    genContractCall(state, true)

  def reset(): Unit = {
    obj.reset()
    approveAssets.foreach(_.reset())
    args.foreach(_.reset())
  }
}
/** An `if`/`else if` branch of an if-else statement. */
final case class IfBranchStatement[Ctx <: StatelessContext](
    condition: Expr[Ctx],
    body: Seq[Statement[Ctx]]
) extends IfBranch[Ctx] {
  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] =
    body.flatMap(statement => statement.genCode(state))

  def reset(): Unit = {
    condition.reset()
    body.foreach(_.reset())
  }
}

/** The trailing `else` branch of an if-else statement. */
final case class ElseBranchStatement[Ctx <: StatelessContext](
    body: Seq[Statement[Ctx]]
) extends ElseBranch[Ctx] {
  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] =
    body.flatMap(statement => statement.genCode(state))

  def reset(): Unit = body.foreach(_.reset())
}

/** An if-else statement.
  *
  * All conditions are checked first; then each branch body is checked in its own scope.
  */
final case class IfElseStatement[Ctx <: StatelessContext](
    ifBranches: Seq[IfBranchStatement[Ctx]],
    elseBranchOpt: Option[ElseBranchStatement[Ctx]]
) extends IfElse[Ctx]
    with Statement[Ctx] {
  override def check(state: Compiler.State[Ctx]): Unit = {
    ifBranches.foreach(_.checkCondition(state))
    for (branch <- ifBranches) {
      state.withScope(branch)(branch.body.foreach(_.check(state)))
    }
    for (branch <- elseBranchOpt) {
      state.withScope(branch)(branch.body.foreach(_.check(state)))
    }
  }

  def reset(): Unit = {
    ifBranches.foreach(_.reset())
    elseBranchOpt.foreach(_.reset())
  }
}
/** A `while` loop.
  *
  * Code layout: `cond; IfFalse/IfTrue(skip); body; Jump(back to cond)`. The whole loop,
  * including the condition and the backward jump, is limited to 0xff instructions.
  */
final case class While[Ctx <: StatelessContext](
    condition: Expr[Ctx],
    body: Seq[Statement[Ctx]]
) extends Statement[Ctx] {
  override def check(state: Compiler.State[Ctx]): Unit = state.withScope(this) {
    if (condition.getType(state) != Seq(Type.Bool)) {
      throw Compiler.Error(s"Invalid type of conditional expr ${quote(condition)}", sourceIndex)
    }
    body.foreach(_.check(state))
  }

  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = state.withScope(this) {
    val bodyIR = body.flatMap(_.genCode(state))
    // When the condition fails, skip the body and the trailing backward jump.
    val condIR = Statement.getCondIR(condition, state, bodyIR.length + 1)
    val whileLen = condIR.length + bodyIR.length + 1
    if (whileLen > 0xff) {
      // TODO: support long branches
      throw Compiler.Error("Too many instructions for while loop", sourceIndex)
    }
    condIR ++ bodyIR :+ Jump(-whileLen)
  }

  def reset(): Unit = {
    condition.reset()
    body.foreach(_.reset())
  }
}
/** A C-style `for (initialize; condition; update) { body }` loop, checked in its own scope.
  *
  * Code layout: `initialize; cond; IfFalse/IfTrue(skip); body; update; Jump(back to cond)`.
  */
final case class ForLoop[Ctx <: StatelessContext](
    initialize: Statement[Ctx],
    condition: Expr[Ctx],
    update: Statement[Ctx],
    body: Seq[Statement[Ctx]]
) extends Statement[Ctx] {
  override def check(state: Compiler.State[Ctx]): Unit = state.withScope(this) {
    initialize.check(state)
    val conditionIsBool = condition.getType(state) == Seq(Type.Bool)
    if (!conditionIsBool) {
      throw Compiler.Error(s"Invalid condition type: $condition", sourceIndex)
    }
    update.check(state)
    body.foreach(_.check(state))
  }

  override def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = state.withScope(this) {
    val initCodes = initialize.genCode(state)
    val bodyCodes = body.flatMap(_.genCode(state))
    val updateCodes = update.genCode(state)
    // Everything after the condition: body, update, and the backward jump itself.
    val loopTailLength = bodyCodes.length + updateCodes.length + 1
    val condCodes = Statement.getCondIR(condition, state, loopTailLength)
    val backJump = Jump(-(condCodes.length + loopTailLength))
    initCodes ++ condCodes ++ bodyCodes ++ updateCodes :+ backJump
  }

  def reset(): Unit = {
    initialize.reset()
    condition.reset()
    update.reset()
    body.foreach(_.reset())
  }
}
/** `return exprs...`: the returned types are validated by `state.checkReturn`. */
final case class ReturnStmt[Ctx <: StatelessContext](exprs: Seq[Expr[Ctx]])
    extends Statement[Ctx] {
  override def check(state: Compiler.State[Ctx]): Unit = {
    state.checkReturn(exprs.flatMap(_.getType(state)), sourceIndex)
  }

  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
    val valueCodes = exprs.flatMap(_.genCode(state))
    valueCodes :+ Return
  }

  def reset(): Unit = exprs.foreach(_.reset())
}

/** Debug print statement. Code is emitted only when `state.allowDebug` is set. */
final case class Debug[Ctx <: StatelessContext](
    stringParts: AVector[Val.ByteVec],
    interpolationParts: Seq[Expr[Ctx]]
) extends Statement[Ctx] {
  // Only type-checks the interpolated expressions; their types are not constrained here.
  def check(state: Compiler.State[Ctx]): Unit =
    interpolationParts.foreach(_.getType(state))

  def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] =
    if (!state.allowDebug) Seq.empty
    else interpolationParts.flatMap(_.genCode(state)) :+ vm.DEBUG(stringParts)

  def reset(): Unit = interpolationParts.foreach(_.reset())
}

object TemplateVar {
  private val arraySuffix = "-template-array"
  private val structSuffix = "-template-struct"

  /** Renames array- and struct-typed template variables to `_<name>-template-array` or
    * `_<name>-template-struct`; identifiers of any other type are returned unchanged.
    */
  @inline private[ralph] def rename(ident: Ident, tpe: Type): Ident = {
    val suffixOpt = tpe match {
      case _: Type.FixedSizeArray => Some(arraySuffix)
      case _: Type.Struct         => Some(structSuffix)
      case _                      => None
    }
    suffixOpt.fold(ident)(suffix => Ident(s"_${ident.name}$suffix"))
  }
}
final case class GlobalState[Ctx <: StatelessContext](
structs: Seq[Struct],
constantVars: Seq[Ast.ConstantVarDef[Ctx]],
enums: Seq[Ast.EnumDef[Ctx]]
) extends Constants[Ctx] {
private[ralph] val constants = mutable.Map.empty[Ast.Ident, Compiler.VarInfo.Constant[Ctx]]
private val usedConstants: mutable.Set[Ast.Ident] = mutable.Set.empty
def getCalculatedConstants(): Seq[(Ident, Val)] = {
constantVars.map(c => (c.ident, constants(c.ident).value))
}
@inline def getConstantOpt(ident: Ident): Option[Compiler.VarInfo.Constant[Ctx]] = {
usedConstants.addOne(ident)
constants.get(ident)
}
def getUnusedGlobalConstantsWarning(): Option[String] = {
val unused = mutable.ArrayBuffer.empty[String]
constantVars.foreach { c =>
if (!usedConstants.contains(c.ident)) unused.addOne(c.name)
}
enums.foreach(e =>
e.fields.foreach { field =>
val fieldIdent = EnumDef.fieldIdent(e.id, field.ident)
if (!usedConstants.contains(fieldIdent)) unused.addOne(fieldIdent.name)
}
)
if (unused.isEmpty) None else Some(Warnings.unusedGlobalConstants(unused.toSeq))
}
def getConstant(ident: Ident): Compiler.VarInfo.Constant[Ctx] = {
getConstantOpt(ident) match {
case Some(v: Compiler.VarInfo.Constant[Ctx @unchecked]) => v
case _ =>
throw Compiler.Error(
s"Constant variable ${ident.name} does not exist or is used before declaration",
ident.sourceIndex
)
}
}
def addConstant(ident: Ident, value: Val, constantDef: Ast.ConstantDefinition): Unit = {
val tpe = Type.fromVal(value.tpe)
constants(ident) = Compiler.VarInfo.Constant(tpe, value, Seq(value.toConstInstr), constantDef)
}
private val flattenSizeCache = mutable.Map.empty[Type, Int]
@SuppressWarnings(Array("org.wartremover.warts.Recursion"))
private def flattenSize(tpe: Type, accessedTypes: Seq[TypeId]): Int = {
tpe match {
case Type.NamedType(id) =>
if (accessedTypes.contains(id)) {
throw Compiler.Error(
s"These structs ${quote(accessedTypes.map(_.name))} have circular references",
id.sourceIndex
)
}
structs.find(_.id == id) match {
case Some(struct) =>
struct.fields.map(f => getFlattenSize(f.tpe, accessedTypes :+ id)).sum
case None => 1
}
case t: Type.FixedSizeArray =>
calcArraySize(t) * flattenSize(t.baseType, accessedTypes)
case Type.Struct(id) => flattenSize(Type.NamedType(id), accessedTypes)
case _ => 1
}
}
private def getFlattenSize(tpe: Type, accessedTypes: Seq[TypeId]): Int = {
flattenSizeCache.get(tpe) match {
case Some(size) => size
case None =>
val size = flattenSize(tpe, accessedTypes)
flattenSizeCache(tpe) = size
size
}
}
private val typeCache: mutable.Map[Type, Type] = mutable.Map.empty
@SuppressWarnings(Array("org.wartremover.warts.Recursion"))
private def _resolveType(tpe: Type): Type = {
tpe match {
case t: Type.NamedType =>
structs.find(_.id == t.id) match {
case Some(struct) => struct.tpe
case None => Type.Contract(t.id)
}
case t: Type.FixedSizeArray =>
Type.FixedSizeArray(resolveType(t.baseType), Left(calcArraySize(t)))
case Type.Map(key, value) =>
Type.Map(resolveType(key), resolveType(value))
case _ => tpe
}
}
@inline def resolveType(tpe: Type): Type = {
tpe match {
case _: Type.NamedType | _: Type.FixedSizeArray | _: Type.Map =>
typeCache.get(tpe) match {
case Some(tpe) => tpe
case None =>
val resolvedType = _resolveType(tpe)
typeCache.update(tpe, resolvedType)
resolvedType
}
case _ => tpe
}
}
@inline def resolveTypes(types: Seq[Type]): Seq[Type] = types.map(resolveType)
def flattenTypeLength(types: Seq[Type]): Int = {
types.foldLeft(0) { case (acc, tpe) =>
tpe match {
case _: Type.FixedSizeArray | _: Type.NamedType | _: Type.Struct =>
acc + getFlattenSize(tpe, Seq.empty)
case _ => acc + 1
}
}
}
/** Expands `tpe` into the sequence of leaf types it occupies, in field/element order. */
@SuppressWarnings(Array("org.wartremover.warts.Recursion"))
def flattenType(tpe: Type): Seq[Type] =
  resolveType(tpe) match {
    case Type.Struct(structId) =>
      getStruct(structId).fields.flatMap(field => flattenType(field.tpe))
    case arrayType: Type.FixedSizeArray =>
      val elementTypes = flattenType(arrayType.baseType)
      (0 until calcArraySize(arrayType)).flatMap(_ => elementTypes)
    case leaf => Seq(leaf)
  }
/** Looks up a struct definition by id.
  *
  * @throws Compiler.Error
  *   if no struct with this id exists
  */
def getStruct(typeId: Ast.TypeId): Ast.Struct =
  structs
    .find(_.id == typeId)
    .getOrElse(
      throw Compiler.Error(s"Struct ${quote(typeId.name)} does not exist", typeId.sourceIndex)
    )
/** Mutability flag of every flat slot of `tpe`.
  *
  * An immutable value yields `false` for every slot. For a mutable value, each struct
  * field contributes its own `isMutable` flag (recursively), and every element of an array
  * inherits the flags of the element type.
  */
@SuppressWarnings(Array("org.wartremover.warts.Recursion"))
def flattenTypeMutability(tpe: Type, isMutable: Boolean): Seq[Boolean] = {
  val resolvedType = resolveType(tpe)
  if (!isMutable) {
    Seq.fill(flattenTypeLength(Seq(resolvedType)))(false)
  } else {
    resolvedType match {
      case Type.Struct(structId) =>
        // Inside this branch the outer value is mutable, so each field's own flag decides.
        getStruct(structId).fields.flatMap { field =>
          flattenTypeMutability(resolveType(field.tpe), field.isMutable)
        }
      case arrayType: Type.FixedSizeArray =>
        val elementFlags = flattenTypeMutability(arrayType.baseType, isMutable = true)
        Seq.fill(calcArraySize(arrayType))(elementFlags).flatten
      case _ => Seq(true)
    }
  }
}
}
object GlobalState {

  /** Builds a global state from top-level definitions.
    *
    * Only structs, constants and enums are accepted. Constants and enums are registered
    * immediately, and the sizes of all fixed-size-array struct fields are computed eagerly.
    *
    * @throws Compiler.Error
    *   on the first definition of any other kind
    */
  def from[Ctx <: StatelessContext](definitions: Seq[GlobalDefinition]): GlobalState[Ctx] = {
    // Reject unsupported definitions before anything is built.
    definitions.foreach {
      case _: Ast.Struct | _: Ast.ConstantVarDef[_] | _: Ast.EnumDef[_] => ()
      case d => throw Compiler.Error(s"Invalid global definition: ${d.name}", d.sourceIndex)
    }
    val structs      = definitions.collect { case s: Ast.Struct => s }
    val constantVars = definitions.collect { case c: Ast.ConstantVarDef[Ctx @unchecked] => c }
    val enums        = definitions.collect { case e: Ast.EnumDef[Ctx @unchecked] => e }

    val globalState = GlobalState[Ctx](structs, constantVars, enums)
    globalState.addConstants(globalState.constantVars)
    globalState.addEnums(globalState.enums)
    for {
      struct <- globalState.structs
      field  <- struct.fields
    } {
      field.tpe match {
        case arrayType: Type.FixedSizeArray => globalState.calcArraySize(arrayType); ()
        case _                              => ()
      }
    }
    globalState
  }
}
sealed trait ContractT[Ctx <: StatelessContext] extends GlobalDefinition {
  def ident: TypeId
  def templateVars: Seq[Argument]
  def fields: Seq[Argument]
  def funcs: Seq[FuncDef[Ctx]]

  /** The declared name of this contract/script. */
  def name: String = ident.name

  /** Built-in functions added to this contract's function table. */
  def builtInContractFuncs(globalState: GlobalState[Ctx]): Seq[Compiler.ContractFunc[Ctx]]

  // Filled by the first successful `funcTable` call and reused afterwards.
  private var functionTable: Option[Map[FuncId, Compiler.ContractFunc[Ctx]]] = None

  /** User-defined plus built-in functions, keyed by id (computed once).
    *
    * @throws Compiler.Error
    *   if two definitions end up with the same id
    */
  def funcTable(globalState: GlobalState[Ctx]): Map[FuncId, Compiler.ContractFunc[Ctx]] =
    functionTable.getOrElse {
      val builtInFuncs = builtInContractFuncs(globalState)
      val isInterface = this match {
        case _: ContractInterface => true
        case _                    => false
      }
      val userDefined = Compiler.SimpleFunc
        .from(funcs, isInterface)
        .map(f => f.funcDef.id -> f)
        .toMap[FuncId, Compiler.ContractFunc[Ctx]]
      val table = builtInFuncs.foldLeft(userDefined) { (acc, builtIn) =>
        acc + (FuncId(builtIn.name, isBuiltIn = true) -> builtIn)
      }
      // Fewer entries than definitions means at least two ids collided.
      val expectedSize = funcs.size + builtInFuncs.length
      if (table.size != expectedSize) {
        val (duplicates, sourceIndex) = UniqueDef.duplicates(funcs)
        throw Compiler.Error(s"These functions are defined multiple times: $duplicates", sourceIndex)
      }
      functionTable = Some(table)
      table
    }

  // Registers every template variable under its renamed ident, then enforces the index limit.
  private def addTemplateVars(state: Compiler.State[Ctx]): Unit = {
    for (templateVar <- templateVars) {
      val resolvedType = state.resolveType(templateVar.tpe)
      state.addTemplateVariable(TemplateVar.rename(templateVar.ident, resolvedType), resolvedType)
    }
    val maxVarIndex = Compiler.State.maxVarIndex
    if (state.templateVarIndex >= maxVarIndex) {
      throw Compiler.Error(s"Number of template variables more than $maxVarIndex", ident.sourceIndex)
    }
  }

  /** Hook for subclasses that declare constants; no-op by default. */
  protected def checkConstants(@nowarn state: Compiler.State[Ctx]): Unit = {}

  // Registers each contract field with its resolved type as a non-generated field variable.
  private def checkAndAddFields(state: Compiler.State[Ctx]): Unit =
    for (field <- fields) {
      state.addFieldVariable(
        field.ident,
        state.resolveType(field.tpe),
        field.isMutable,
        field.isUnused,
        isGenerated = false
      )
    }

  /** Runs the check phase: arguments, template vars, fields, constants, functions, then
    * the unused/unassigned diagnostics (in this exact order).
    */
  def check(state: Compiler.State[Ctx]): Unit = {
    state.setCheckPhase()
    state.checkArguments(fields)
    addTemplateVars(state)
    checkAndAddFields(state)
    checkConstants(state)
    funcs.foreach(func => func.check(state))
    state.checkUnusedMaps()
    state.checkUnusedFieldsAndConstants()
    state.checkUnassignedMutableFields()
  }

  /** Generates one VM method per function, in declaration order. */
  def genMethods(state: Compiler.State[Ctx]): AVector[Method[Ctx]] =
    AVector.from(funcs.view.map(func => func.genMethod(state)))

  def genCode(state: Compiler.State[Ctx]): VmContract[Ctx]

  /** Resets per-function compilation state. */
  def reset(): Unit = funcs.foreach(func => func.reset())
}
/** A stateless script: template variables plus functions, no fields. */
final case class AssetScript(
    ident: TypeId,
    templateVars: Seq[Argument],
    funcs: Seq[FuncDef[StatelessContext]]
) extends ContractT[StatelessContext] {
  val fields: Seq[Argument] = Seq.empty

  // Asset scripts expose no built-in functions.
  def builtInContractFuncs(
      globalState: GlobalState[StatelessContext]
  ): Seq[Compiler.ContractFunc[StatelessContext]] = Seq.empty

  /** Generates the stateless script; fails when no methods could be assembled. */
  def genCode(state: Compiler.State[StatelessContext]): StatelessScript = {
    state.setGenCodePhase()
    val methods = genMethods(state)
    StatelessScript.from(methods).getOrElse {
      throw Compiler.Error(s"No methods found in ${quote(ident.name)}", ident.sourceIndex)
    }
  }

  /** Check, generate, then statically analyse the generated methods. */
  def genCodeFull(state: Compiler.State[StatelessContext]): StatelessScript = {
    check(state)
    val generated = genCode(state)
    StaticAnalysis.checkMethodsStateless(this, generated.methods, state)
    generated
  }
}
/** Common shape of stateful definitions: tx scripts, contracts and interfaces. */
sealed trait ContractWithState extends ContractT[StatefulContext] {
  def inheritances: Seq[Inheritance]
  def templateVars: Seq[Argument]
  def fields: Seq[Argument]
  def maps: Seq[MapDef]
  def events: Seq[EventDef]
  def constantVars: Seq[ConstantVarDef[StatefulContext]]
  def enums: Seq[EnumDef[StatefulContext]]

  // No built-in functions unless a subclass overrides this.
  def builtInContractFuncs(
      globalState: GlobalState[StatefulContext]
  ): Seq[Compiler.ContractFunc[StatefulContext]] = Seq.empty

  /** Event id and field types for every event, after rejecting duplicate events. */
  def eventsInfo(): Seq[Compiler.EventInfo] = {
    UniqueDef.checkDuplicates(events, "events")
    for (event <- events) yield Compiler.EventInfo(event.id, event.fields.map(field => field.tpe))
  }
}
/** A transaction script: template variables plus functions; no fields, events or parents. */
final case class TxScript(
    ident: TypeId,
    templateVars: Seq[Argument],
    funcs: Seq[FuncDef[StatefulContext]]
) extends ContractWithState {
  val fields: Seq[Argument]                   = Seq.empty
  val events: Seq[EventDef]                   = Seq.empty
  val inheritances: Seq[ContractInheritance] = Seq.empty

  /** Error raised when a script is asked for a kind of definition it cannot hold. */
  def error(tpe: String): Compiler.Error =
    Compiler.Error(s"TxScript ${ident.name} should not contain any $tpe", sourceIndex)

  def constantVars: Seq[ConstantVarDef[StatefulContext]] = throw error("constant variable")
  def enums: Seq[EnumDef[StatefulContext]]               = throw error("enum")
  def maps: Seq[MapDef]                                  = throw error("map")

  def getTemplateVarsSignature(): String = {
    val signatures = templateVars.map(_.signature).mkString(",")
    s"TxScript ${name}($signatures)"
  }
  def getTemplateVarsNames(): AVector[String] =
    AVector.from(templateVars.view.map(arg => arg.ident.name))
  def getTemplateVarsTypes(): AVector[String] =
    AVector.from(templateVars.view.map(arg => arg.tpe.signature))
  def getTemplateVarsMutability(): AVector[Boolean] =
    AVector.from(templateVars.view.map(arg => arg.isMutable))

  /** Prepends, to every function body, an immutable local bound to each array/struct
    * template variable (read through its renamed template-variable ident), and tags each
    * function with this script as its origin.
    */
  def withTemplateVarDefs(globalState: GlobalState[StatefulContext]): TxScript = {
    val templateVarDefs = templateVars.foldLeft(Seq.empty[Statement[StatefulContext]]) {
      (defs, arg) =>
        val argType = globalState.resolveType(arg.tpe)
        val isCompound = argType match {
          case _: Type.FixedSizeArray | _: Type.Struct => true
          case _                                       => false
        }
        if (isCompound) {
          defs :+ VarDef(
            Seq(NamedVar(mutable = false, arg.ident)),
            Variable(TemplateVar.rename(arg.ident, argType))
          )
        } else {
          defs
        }
    }
    val newFuncs = funcs.map { func =>
      func.copy(bodyOpt = Some(templateVarDefs ++ func.body)).withOrigin(ident)
    }
    copy(funcs = newFuncs)
  }

  /** Generates the stateful script; fails if the public/private layout is not accepted. */
  def genCode(state: Compiler.State[StatefulContext]): StatefulScript = {
    state.setGenCodePhase()
    StatefulScript.from(genMethods(state)).getOrElse {
      throw Compiler.Error(
        "Expected the 1st function to be public and the other functions to be private for tx script",
        sourceIndex
      )
    }
  }

  /** Check, generate, then statically analyse the generated methods. */
  def genCodeFull(state: Compiler.State[StatefulContext]): StatefulScript = {
    check(state)
    val generated = genCode(state)
    StaticAnalysis.checkMethodsStateful(this, generated.methods, state)
    generated
  }
}
// A parent reference in a contract or interface declaration.
sealed trait Inheritance extends Positioned {
def parentId: TypeId
}
// Inheriting a contract; `idents` names the child's fields that are passed to the parent.
// `MultiContract.checkInheritanceFields` checks they match the parent's fields.
final case class ContractInheritance(parentId: TypeId, idents: Seq[Ident]) extends Inheritance
// Implementing an interface; no fields are involved.
final case class InterfaceInheritance(parentId: TypeId) extends Inheritance
// A (possibly abstract) contract: fields, functions, maps, events, constants, enums and parents.
// `stdIdEnabled`/`stdInterfaceId` decide whether a hidden std-interface-id field is appended.
final case class Contract(
stdIdEnabled: Option[Boolean],
stdInterfaceId: Option[StdInterfaceId],
isAbstract: Boolean,
ident: TypeId,
templateVars: Seq[Argument],
fields: Seq[Argument],
funcs: Seq[FuncDef[StatefulContext]],
maps: Seq[MapDef],
events: Seq[EventDef],
constantVars: Seq[ConstantVarDef[StatefulContext]],
enums: Seq[EnumDef[StatefulContext]],
inheritances: Seq[Inheritance]
) extends ContractWithState {
// True only when std id is explicitly enabled AND an interface id is known.
lazy val hasStdIdField: Boolean = stdIdEnabled.exists(identity) && stdInterfaceId.nonEmpty
// Declared fields, plus the immutable, unused `__stdInterfaceId` argument when hasStdIdField.
lazy val contractFields: Seq[Argument] = if (hasStdIdField) fields :+ Ast.stdArg else fields
// Constants and enum fields whose `definedIn(ident)` holds, i.e. not inherited ones.
lazy val selfDefinedConstants: Seq[Ident] = {
val constants = mutable.ArrayBuffer.empty[Ident]
constantVars.foreach { c =>
if (c.definedIn(ident)) constants.addOne(c.ident)
}
enums.foreach(e =>
e.fields.foreach { field =>
if (field.definedIn(ident)) {
constants.addOne(EnumDef.fieldIdent(e.id, field.ident))
}
}
)
constants.toSeq
}
// Field metadata below is based on `contractFields`, so it includes the std id field when present.
def getFieldsSignature(): String =
s"Contract ${name}(${contractFields.map(_.signature).mkString(",")})"
def getFieldNames(): AVector[String] = AVector.from(contractFields.view.map(_.ident.name))
def getFieldTypes(): AVector[String] = AVector.from(contractFields.view.map(_.tpe.signature))
def getFieldMutability(): AVector[Boolean] = AVector.from(contractFields.view.map(_.isMutable))
// The only built-in is `encodeFields`; the std id is passed only when hasStdIdField.
override def builtInContractFuncs(
globalState: GlobalState[StatefulContext]
): Seq[Compiler.ContractFunc[StatefulContext]] = {
val stdInterfaceIdOpt = if (hasStdIdField) stdInterfaceId else None
Seq(BuiltIn.encodeFields(stdInterfaceIdOpt, fields, globalState))
}
// Rejects a contract that declares no functions.
private def checkFuncs(): Unit = {
if (funcs.length < 1) {
throw Compiler.Error(
s"No function found in Contract ${quote(ident.name)}",
ident.sourceIndex
)
}
}
// Throws NoSuchElementException (via Option.get) if no function has `funcId`.
@SuppressWarnings(Array("org.wartremover.warts.OptionPartial"))
def getFuncUnsafe(funcId: FuncId): FuncDef[StatefulContext] = funcs.find(_.id == funcId).get
// Set by `checkConstants` when the state returned at least one computed constant.
private var calculatedConstants: Option[Seq[(Ident, Val)]] = None
def getCalculatedConstants(): Seq[(Ident, Val)] = calculatedConstants.getOrElse(Seq.empty)
// Rejects local constants/enums that reuse a global name, then registers them in `state`.
override def checkConstants(state: Compiler.State[StatefulContext]): Unit = {
constantVars.foreach { c =>
if (state.globalState.constantVars.exists(_.ident == c.ident)) {
throw Compiler.Error(
s"Local constant ${c.name} conflicts with an existing global constant, please use a fresh name",
c.sourceIndex
)
}
}
val constants = state.addConstants(constantVars)
if (constants.nonEmpty) calculatedConstants = Some(constants)
enums.foreach { e =>
if (state.globalState.enums.exists(_.id == e.id)) {
throw Compiler.Error(
s"Local enum ${e.name} conflicts with an existing global enum, please use a fresh name",
e.sourceIndex
)
}
}
state.addEnums(enums)
}
// Every direct parent must be of a kind whose `inheritable` flag is set.
private def checkInheritances(state: Compiler.State[StatefulContext]): Unit = {
inheritances.foreach { inheritance =>
val id = inheritance.parentId
val kind = state.getContractInfo(id).kind
if (!kind.inheritable) {
throw Compiler.Error(s"$kind ${id.name} can not be inherited", id.sourceIndex)
}
}
}
// A `mut` field of struct type is rejected when every flattened slot of the struct is immutable.
private def checkFields(state: Compiler.State[StatefulContext]): Unit = {
fields.foreach { case field @ Argument(fieldId, tpe, isFieldMutable, _) =>
state.resolveType(tpe) match {
case Type.Struct(structId) =>
val isStructImmutable = state.flattenTypeMutability(tpe, isMutable = true).forall(!_)
if (isFieldMutable && isStructImmutable) {
throw Compiler.Error(
s"The struct ${structId.name} is immutable, please remove the `mut` from ${ident.name}.${fieldId.name}",
field.sourceIndex
)
}
case _ => ()
}
}
}
// Rejects duplicate maps and map/field name clashes, then registers each map with its position.
private def checkMaps(state: Compiler.State[StatefulContext]): Unit = {
UniqueDef.checkDuplicates(maps, "maps")
maps.find(mapDef => fields.exists(_.ident == mapDef.ident)) match {
case Some(mapDef) =>
throw Compiler.Error(
s"The map ${mapDef.ident.name} cannot have the same name as the contract field",
mapDef.ident.sourceIndex
)
case _ => ()
}
maps.view.zipWithIndex.foreach { case (m, index) =>
val mapType = Type.Map(m.tpe.key, state.resolveType(m.tpe.value))
state.addMapVar(m.ident, mapType, index)
}
}
// Contract-specific checks run before the generic `ContractT.check`.
override def check(state: Compiler.State[StatefulContext]): Unit = {
state.setCheckPhase()
checkFields(state)
checkMaps(state)
checkFuncs()
checkInheritances(state)
super.check(state)
}
// Like the default, but public functions that use method selectors get a leading
// `MethodSelector` instruction; two such functions with equal selectors are rejected.
override def genMethods(
state: Compiler.State[StatefulContext]
): AVector[Method[StatefulContext]] = {
val selectors = mutable.Map.empty[Method.Selector, FuncId]
AVector.from(funcs.view.map { func =>
val method = func.genMethod(state)
if (func.isPublic && state.isUseMethodSelector(ident, func.id)) {
val methodSelector = func.getMethodSelector(state.globalState)
selectors.get(methodSelector) match {
case Some(funcId) =>
throw Compiler.Error(
s"Function ${func.name}'s method selector conflicts with function ${funcId.name}'s method selector. Please use a new function name.",
func.id.sourceIndex
)
case _ => ()
}
selectors(methodSelector) = func.id
val methodSelectorInstr = MethodSelector(methodSelector)
method.copy(instrs = methodSelectorInstr +: method.instrs)
} else {
method
}
})
}
// Field count is the flattened size of all declared fields, plus one for the std id field.
def genCode(state: Compiler.State[StatefulContext]): StatefulContract = {
assume(!isAbstract)
state.setGenCodePhase()
val methods = genMethods(state)
val fieldsLength =
state.flattenTypeLength(fields.map(_.tpe)) + (if (hasStdIdField) 1 else 0)
StatefulContract(fieldsLength, methods)
}
// the state must have been updated in the check pass
// Marks each function that checks the external caller. A directly checking function is marked;
// when a marked function is private, its callers (from `internalCallsReversed`) are marked too.
// Also warns about non-public functions that carry the check-external-caller annotation.
def buildCheckExternalCallerTable(
state: Compiler.State[StatefulContext]
): mutable.Map[FuncId, Boolean] = {
val checkExternalCallerTable = mutable.Map.empty[FuncId, Boolean]
funcs.foreach(func => checkExternalCallerTable(func.id) = false)
// TODO: optimize these two functions
def updateCheckedRecursivelyForPrivateMethod(checkedPrivateCalleeId: FuncId): Unit = {
state.internalCallsReversed.get(checkedPrivateCalleeId) match {
case Some(callers) =>
callers.foreach { caller =>
updateCheckedRecursively(getFuncUnsafe(caller))
}
case None => ()
}
}
// The "already marked" guard stops the mutual recursion on call cycles.
def updateCheckedRecursively(func: FuncDef[StatefulContext]): Unit = {
if (!checkExternalCallerTable(func.id)) {
checkExternalCallerTable(func.id) = true
if (func.isPrivate) { // indirect check external caller should be in private methods
updateCheckedRecursivelyForPrivateMethod(func.id)
}
}
}
funcs.foreach { func =>
if (!func.isPublic && func.hasCheckExternalCallerAnnotation) {
state.warnPrivateFuncHasCheckExternalCaller(ident, func.id)
}
if (func.hasDirectCheckExternalCaller()) {
updateCheckedRecursively(func)
}
}
checkExternalCallerTable
}
}
/** An interface: function signatures and events only, optionally with a std id. */
final case class ContractInterface(
    stdId: Option[StdInterfaceId],
    useMethodSelector: Boolean,
    ident: TypeId,
    funcs: Seq[FuncDef[StatefulContext]],
    events: Seq[EventDef],
    inheritances: Seq[InterfaceInheritance]
) extends ContractWithState {

  /** Error raised when an interface is asked for a kind of definition it cannot hold. */
  def error(tpe: String): Compiler.Error = {
    val message = s"Interface ${quote(ident.name)} should not contain any ${quote(tpe)}"
    Compiler.Error(message, sourceIndex)
  }

  def templateVars: Seq[Argument]                        = throw error("template variable")
  def fields: Seq[Argument]                              = throw error("field")
  def maps: Seq[MapDef]                                  = throw error("map")
  def getFieldsSignature(): String                       = throw error("field")
  def getFieldTypes(): Seq[String]                       = throw error("field")
  def constantVars: Seq[ConstantVarDef[StatefulContext]] = throw error("constant variable")
  def enums: Seq[EnumDef[StatefulContext]]               = throw error("enum")

  // Interfaces never produce bytecode.
  def genCode(state: Compiler.State[StatefulContext]): StatefulContract =
    throw Compiler.Error(s"Interface ${quote(ident.name)} should not generate code", sourceIndex)
}
// All stateful definitions of one compilation unit plus the global state. After
// `extendedContracts`, `dependencies` holds each definition's transitive parents and
// `methodSelectorTable` holds the method-selector flag per (contract, inherited interface function).
final case class MultiContract(
contracts: Seq[ContractWithState],
globalState: GlobalState[StatefulContext],
dependencies: Option[Map[TypeId, Seq[TypeId]]],
methodSelectorTable: Option[Map[(TypeId, FuncId), Boolean]]
) extends Positioned {
// ident -> (kind, function table) for every definition.
lazy val contractsTable = contracts.map { contract =>
val kind = contract match {
case _: Ast.ContractInterface =>
Compiler.ContractKind.Interface
case _: Ast.TxScript =>
Compiler.ContractKind.TxScript
case txContract: Ast.Contract =>
Compiler.ContractKind.Contract(txContract.isAbstract)
}
contract.ident -> Compiler.ContractInfo(kind, contract.funcTable(globalState))
}.toMap
def structs: Seq[Struct] = globalState.structs
def enums: Seq[EnumDef[StatefulContext]] = globalState.enums
// Bounds-checked positional access.
def get(contractIndex: Int): ContractWithState = {
if (contractIndex >= 0 && contractIndex < contracts.size) {
contracts(contractIndex)
} else {
throw Compiler.Error(s"Invalid contract index $contractIndex", None)
}
}
// Lookup by name; tx scripts are rejected because they cannot be referenced as contracts.
private def getContract(typeId: TypeId): ContractWithState = {
contracts.find(_.ident.name == typeId.name) match {
case None =>
throw Compiler.Error(s"Contract ${quote(typeId.name)} does not exist", typeId.sourceIndex)
case Some(ts: TxScript) =>
throw Compiler.Error(
s"Expected contract ${quote(typeId.name)}, but was script",
ts.sourceIndex
)
case Some(contract: ContractWithState) => contract
}
}
// True only for a non-abstract Contract; throws when no definition has this name.
def isContract(typeId: TypeId): Boolean = {
contracts.find(_.ident.name == typeId.name) match {
case None =>
throw Compiler.Error(s"Contract ${quote(typeId.name)} does not exist", typeId.sourceIndex)
case Some(contract: Contract) if !contract.isAbstract => true
case _ => false
}
}
def getInterface(typeId: TypeId): ContractInterface = {
getContract(typeId) match {
case interface: ContractInterface => interface
case _ =>
throw Compiler.Error(s"Interface ${typeId.name} does not exist", typeId.sourceIndex)
}
}
// Depth-first collection of all transitive parents of `contract` into `parentsCache`.
// `visited` holds every ident entered so far; re-entering one that is not yet cached means
// it is still on the stack, i.e. an inheritance cycle. Parents are kept in first-seen order
// without duplicates (LinkedHashMap).
@SuppressWarnings(Array("org.wartremover.warts.Recursion"))
private def buildDependencies(
contract: ContractWithState,
parentsCache: mutable.Map[TypeId, Seq[ContractWithState]],
visited: mutable.Set[TypeId]
): Unit = {
if (!visited.add(contract.ident)) {
throw Compiler.Error(
s"Cyclic inheritance detected for contract ${contract.ident.name}",
contract.sourceIndex
)
}
val allParents = mutable.LinkedHashMap.empty[TypeId, ContractWithState]
contract.inheritances.foreach { inheritance =>
val parentId = inheritance.parentId
val parentContract = getContract(parentId)
MultiContract.checkInheritanceFields(contract, inheritance, parentContract)
allParents += parentId -> parentContract
if (!parentsCache.contains(parentId)) {
buildDependencies(parentContract, parentsCache, visited)
}
parentsCache(parentId).foreach { grandParent =>
allParents += grandParent.ident -> grandParent
}
}
parentsCache += contract.ident -> allParents.values.toSeq
}
// Parent table for every non-script definition.
private def buildDependencies(): mutable.Map[TypeId, Seq[ContractWithState]] = {
val parentsCache = mutable.Map.empty[TypeId, Seq[ContractWithState]]
val visited = mutable.Set.empty[TypeId]
contracts.foreach {
case _: TxScript => ()
case contract =>
if (!parentsCache.contains(contract.ident)) {
buildDependencies(contract, parentsCache, visited)
}
}
parentsCache
}
// Returns a copy in which every contract/interface carries everything it inherits and
// scripts get their template-variable locals; also fills `dependencies` and `methodSelectorTable`.
@SuppressWarnings(Array("org.wartremover.warts.IsInstanceOf"))
def extendedContracts(): MultiContract = {
UniqueDef.checkDuplicates(
contracts ++ structs ++ enums,
"TxScript/Contract/Interface/Struct/Enum"
)
val methodSelectorTable = mutable.Map.empty[(TypeId, Ast.FuncId), Boolean]
val parentsCache = buildDependencies()
val newContracts: Seq[ContractWithState] = contracts.map {
case script: TxScript =>
script.withTemplateVarDefs(globalState).atSourceIndex(script.sourceIndex)
case c: Contract =>
val (stdIdEnabled, stdId, funcs, maps, events, constantVars, enums) =
MultiContract.extractContract(parentsCache, methodSelectorTable, c)
Contract(
Some(stdIdEnabled),
stdId,
c.isAbstract,
c.ident,
c.templateVars,
c.fields,
funcs,
maps,
events,
constantVars,
enums,
c.inheritances
).atSourceIndex(c.sourceIndex)
case i: ContractInterface =>
val (stdId, funcs, events) =
MultiContract.extractInterface(parentsCache, methodSelectorTable, i)
ContractInterface(stdId, i.useMethodSelector, i.ident, funcs, events, i.inheritances)
.atSourceIndex(i.sourceIndex)
}
val dependencies = Map.from(parentsCache.map(p => (p._1, p._2.map(_.ident))))
MultiContract(newContracts, globalState, Some(dependencies), Some(methodSelectorTable.toMap))
}
def genStatefulScripts()(implicit compilerOptions: CompilerOptions): AVector[CompiledScript] = {
AVector.from(contracts.view.zipWithIndex.collect { case (_: TxScript, index) =>
genStatefulScript(index)
})
}
// Compiles the script at `contractIndex` (with full checks), then regenerates it with
// debug instructions allowed to obtain the debug variant.
def genStatefulScript(contractIndex: Int)(implicit
compilerOptions: CompilerOptions
): CompiledScript = {
val state = Compiler.State.buildFor(this, contractIndex)
get(contractIndex) match {
case script: TxScript =>
val statefulScript = script.genCodeFull(state)
val warnings = state.getWarnings
state.allowDebug = true
val statefulDebugScript = script.genCode(state)
CompiledScript(statefulScript, script, warnings, statefulDebugScript)
case c: Contract =>
throw Compiler.Error(s"The code is for Contract, not for TxScript", c.sourceIndex)
case ci: ContractInterface =>
throw Compiler.Error(s"The code is for Interface, not for TxScript", ci.sourceIndex)
}
}
// Collects definitions from abstract contracts, removes those used by any concrete contract,
// and emits one warning per parent that still has unused definitions.
private def checkUnusedDefsInParentContract(
states: Map[TypeId, (Contract, Compiler.State[StatefulContext])],
defsInParentContract: Contract => Iterable[(TypeId, String)],
usedDefsInContract: (
Contract,
Compiler.State[StatefulContext]
) => Iterable[(TypeId, String)],
genWarning: (TypeId, collection.Seq[String]) => String
): AVector[String] = {
val allDefs = mutable.Set.empty[(TypeId, String)]
states.foreach { case (_, (contract, _)) =>
if (contract.isAbstract) allDefs.addAll(defsInParentContract(contract))
}
states.foreach { case (_, (contract, state)) =>
if (!contract.isAbstract) allDefs.subtractAll(usedDefsInContract(contract, state))
}
if (allDefs.nonEmpty) {
AVector.from(allDefs.groupBy(_._1).map { case (parentId, defs) =>
genWarning(parentId, defs.map(_._2).toSeq)
})
} else {
AVector.empty[String]
}
}
// Unused constants/enum fields declared in abstract parents.
private def checkUnusedLocalConstants(
states: Map[TypeId, (Contract, Compiler.State[StatefulContext])]
): AVector[String] = {
val defsInParentContract = (contract: Contract) => {
contract.selfDefinedConstants.map(ident => (contract.ident, ident.name))
}
checkUnusedDefsInParentContract(
states,
defsInParentContract,
(_, state) => state.getUsedParentConstants(),
Warnings.unusedLocalConstants
)
}
// Unused private functions declared in abstract parents; a function counts as used when it
// appears as a callee in a concrete contract whose copy originates from another contract.
private def checkUnusedPrivateFunctions(
states: Map[TypeId, (Contract, Compiler.State[StatefulContext])]
): AVector[String] = {
val defsInParentContract = (contract: Contract) => {
contract.funcs.collect {
case func if func.isPrivate && func.definedIn(contract.ident) =>
(contract.ident, func.name)
}
}
val usedDefsInContract = (contract: Contract, state: Compiler.State[StatefulContext]) => {
val usedPrivateFuncs = mutable.ArrayBuffer.empty[(TypeId, String)]
state.internalCallsReversed.keys.foreach { funcId =>
contract.funcs.find(f => f.id == funcId && f.isPrivate).foreach { func =>
func.origin.foreach { originId =>
if (originId != contract.ident) usedPrivateFuncs.addOne((originId, func.name))
}
}
}
usedPrivateFuncs
}
checkUnusedDefsInParentContract(
states,
defsInParentContract,
usedDefsInContract,
Warnings.unusedPrivateFunctions
)
}
// Runs whichever unused-definition checks are not disabled by `compilerOptions`.
private def checkUnusedDefsInParentContract(
states: AVector[Compiler.State[StatefulContext]]
)(implicit compilerOptions: CompilerOptions): AVector[String] = {
if (
compilerOptions.ignoreUnusedConstantsWarnings && compilerOptions.ignoreUnusedPrivateFunctionsWarnings
) {
AVector.empty[String]
} else {
val contractAndStates = contracts.view.zipWithIndex.collect {
case (contract: Contract, index) => (contract.ident, (contract, states(index)))
}.toMap
val unusedConstantsWarnings = if (!compilerOptions.ignoreUnusedConstantsWarnings) {
checkUnusedLocalConstants(contractAndStates)
} else {
AVector.empty[String]
}
val unusedPrivateFuncsWarnings =
if (!compilerOptions.ignoreUnusedPrivateFunctionsWarnings) {
checkUnusedPrivateFunctions(contractAndStates)
} else {
AVector.empty[String]
}
unusedConstantsWarnings ++ unusedPrivateFuncsWarnings
}
}
// Checks and generates debug code for every concrete contract, runs cross-contract analyses,
// then derives release code. Returns the cross-contract warnings and (compiled, original index).
def genStatefulContracts()(implicit
compilerOptions: CompilerOptions
): (AVector[String], AVector[(CompiledContract, Int)]) = {
val states = AVector.tabulate(contracts.length)(Compiler.State.buildFor(this, _))
val statefulContracts = AVector.from(contracts.view.zipWithIndex.collect {
case (contract: Contract, index) if !contract.isAbstract =>
val state = states(index)
contract.check(state)
state.allowDebug = true
val statefulDebugContract = contract.genCode(state)
(statefulDebugContract, contract, state, index)
})
StaticAnalysis.checkExternalCalls(this, states)
val warnings = checkUnusedDefsInParentContract(states)
val compiled = statefulContracts.map { case (statefulDebugContract, contract, state, index) =>
val statefulContract = genReleaseCode(contract, statefulDebugContract, state)
StaticAnalysis.checkMethods(contract, statefulDebugContract, state)
CompiledContract(
statefulContract,
contract,
state.getWarnings,
statefulDebugContract
) -> index
}
(warnings, compiled)
}
// Regenerates with debug disabled only when the debug code actually contains DEBUG instrs;
// otherwise the debug code is reused as the release code.
def genReleaseCode(
contract: Contract,
debugCode: StatefulContract,
state: Compiler.State[StatefulContext]
): StatefulContract = {
if (debugCode.methods.exists(_.instrs.exists(_.isInstanceOf[DEBUG]))) {
state.allowDebug = false
contract.genCode(state)
} else {
debugCode
}
}
// Compiles all contracts and returns the one at `contractIndex`; rejects abstract contracts,
// scripts and interfaces.
def genStatefulContract(contractIndex: Int)(implicit
compilerOptions: CompilerOptions
): CompiledContract = {
get(contractIndex) match {
case contract: Contract =>
if (contract.isAbstract) {
throw Compiler.Error(
s"Code generation is not supported for abstract contract ${quote(contract.ident.name)}",
contract.sourceIndex
)
}
val statefulContracts = genStatefulContracts()._2
statefulContracts.find(_._2 == contractIndex) match {
case Some(v) => v._1
case None => // should never happen
throw Compiler.Error(
s"Failed to compile contract ${contract.ident.name}",
contract.sourceIndex
)
}
case ts: TxScript =>
throw Compiler.Error(s"The code is for TxScript, not for Contract", ts.sourceIndex)
case ci: ContractInterface =>
throw Compiler.Error(s"The code is for Interface, not for Contract", ci.sourceIndex)
}
}
}
object MultiContract {
/** Validates the field list of a contract inheritance; interface inheritance is ignored. */
def checkInheritanceFields(
    contract: ContractWithState,
    inheritance: Inheritance,
    parentContract: ContractWithState
): Unit =
  inheritance match {
    case contractInheritance: ContractInheritance =>
      _checkInheritanceFields(contract, contractInheritance, parentContract)
    case _ => // nothing to validate for other inheritance kinds
  }
/** Resolves the field names passed to the parent against the child's fields and requires
  * them to equal the parent's declared fields.
  *
  * @throws Compiler.Error
  *   if a named field is missing in the child or the resolved list differs from the parent's
  */
private def _checkInheritanceFields(
    contract: ContractWithState,
    inheritance: ContractInheritance,
    parentContract: ContractWithState
): Unit = {
  val inheritedFields = inheritance.idents.map { fieldIdent =>
    contract.fields.find(_.ident.name == fieldIdent.name) match {
      case Some(field) => field
      case None =>
        throw Compiler.Error(
          s"Inherited field ${quote(fieldIdent.name)} does not exist in contract ${quote(contract.name)}",
          fieldIdent.sourceIndex
        )
    }
  }
  val parentFields = parentContract.fields
  if (inheritedFields != parentFields) {
    throw Compiler.Error(
      s"Invalid contract inheritance fields, expected ${quote(parentFields)}, got ${quote(inheritedFields)}",
      inheritedFields.headOption.flatMap(_.sourceIndex)
    )
  }
}
/** The std id of the most derived interface that declares one.
  *
  * Walking `interfaces` in order, each declared id must strictly extend the id inherited so
  * far (same bytes are rejected; a non-prefix extension is rejected).
  */
@inline private[ralph] def getStdId(
    interfaces: Seq[ContractInterface]
): Option[StdInterfaceId] = {
  interfaces.foldLeft(Option.empty[StdInterfaceId]) { (inherited, interface) =>
    interface.stdId match {
      case None => inherited
      case Some(stdId) =>
        inherited.foreach { parentStdId =>
          if (stdId.bytes == parentStdId.bytes) {
            throw Compiler.Error(
              s"The std id of interface ${interface.ident.name} is the same as parent interface",
              interface.sourceIndex
            )
          }
          if (!stdId.bytes.startsWith(parentStdId.bytes)) {
            val expectedPrefix =
              Hex.toHexString(parentStdId.bytes.drop(Ast.StdInterfaceIdPrefix.length))
            throw Compiler.Error(
              s"The std id of interface ${interface.ident.name} should start with $expectedPrefix",
              interface.sourceIndex
            )
          }
        }
        Some(stdId)
    }
  }
}
/** The std-id-enabled option shared by the inheritance chain, defaulting to `true`.
  *
  * @throws Compiler.Error
  *   if two contracts on the chain set different explicit values
  */
@inline private[ralph] def getStdIdEnabled(
    contracts: Seq[Contract],
    typeId: Ast.TypeId
): Boolean = {
  val explicitValues = contracts.flatMap(_.stdIdEnabled)
  explicitValues.headOption match {
    case None => true
    case Some(first) =>
      if (explicitValues.exists(_ != first)) {
        throw Compiler.Error(
          s"There are different std id enabled options on the inheritance chain of contract ${typeId.name}",
          typeId.sourceIndex
        )
      }
      first
  }
}
/** Flattens an interface together with its parent interfaces.
  *
  * Records in `methodSelectorTable`, for every inherited function, whether its declaring
  * interface uses method selectors.
  *
  * @return
  *   (std id, functions, events) of the flattened interface
  */
def extractInterface(
    parentsCache: mutable.Map[TypeId, Seq[ContractWithState]],
    methodSelectorTable: mutable.Map[(TypeId, FuncId), Boolean],
    interface: ContractInterface
): (Option[StdInterfaceId], Seq[FuncDef[StatefulContext]], Seq[EventDef]) = {
  val parents = parentsCache(interface.ident).map {
    case parent: ContractInterface => parent
    case other =>
      throw Compiler.Error(s"${other.ident.name} is not an interface", other.ident.sourceIndex)
  }
  // A selector-less interface may not extend one that uses selectors.
  if (!interface.useMethodSelector) {
    parents.find(_.useMethodSelector).foreach { parent =>
      throw Compiler.Error(
        s"Interface ${interface.name} does not use method selector, but its parent ${parent.name} use method selector",
        interface.ident.sourceIndex
      )
    }
  }
  val sortedInterfaces = sortInterfaces(parentsCache, parents :+ interface)
  val stdId            = getStdId(sortedInterfaces)
  for {
    parentOrSelf <- sortedInterfaces
    func         <- parentOrSelf.funcs
  } methodSelectorTable.update((interface.ident, func.id), parentOrSelf.useMethodSelector)
  val allFuncs                = sortedInterfaces.flatMap(_.funcs)
  val (unimplementedFuncs, _) = checkFuncs(allFuncs)
  // call the `checkFuncs` first to avoid duplicate function definition
  checkInterfaceMethodIndex(sortedInterfaces)
  (stdId, unimplementedFuncs, sortedInterfaces.flatMap(_.events))
}
// scalastyle:off method.length
/** Flattens `contract` together with every contract and interface it inherits.
  *
  * @param parentsCache
  *   transitive parents of every contract/interface, keyed by ident
  * @param methodSelectorTable
  *   updated in place: for each interface function inherited by `contract`, whether the
  *   declaring interface uses method selectors
  * @param contract
  *   the contract being flattened
  * @return
  *   (std-id-enabled flag, std interface id, functions, maps, events, constants, enums)
  * @throws Compiler.Error
  *   if a non-abstract contract leaves interface functions unimplemented, or the chain
  *   contains inconsistent std-id settings
  */
def extractContract(
    parentsCache: mutable.Map[TypeId, Seq[ContractWithState]],
    methodSelectorTable: mutable.Map[(TypeId, FuncId), Boolean],
    contract: Contract
): (
    Boolean,
    Option[StdInterfaceId],
    Seq[FuncDef[StatefulContext]],
    Seq[MapDef],
    Seq[EventDef],
    Seq[ConstantVarDef[StatefulContext]],
    Seq[EnumDef[StatefulContext]]
) = {
  val parents = parentsCache(contract.ident)
  // Split the chain (parents plus the contract itself) by kind with a total pattern match
  // instead of isInstanceOf/asInstanceOf casts.
  val (allContracts, allInterfaces) = (parents :+ contract).partitionMap {
    case c: Contract          => Left(c)
    case i: ContractInterface => Right(i)
    case script: TxScript =>
      // Parent lookup rejects scripts, so this is unreachable; keep the match total anyway.
      throw Compiler.Error(
        s"${script.ident.name} is not a contract or interface",
        script.ident.sourceIndex
      )
  }
  val sortedInterfaces = sortInterfaces(parentsCache, allInterfaces)
  val stdId            = getStdId(sortedInterfaces)
  val stdIdEnabled     = getStdIdEnabled(allContracts, contract.ident)
  val interfaceFuncs = sortedInterfaces.flatMap { interface =>
    interface.funcs.foreach(func =>
      methodSelectorTable.update((contract.ident, func.id), interface.useMethodSelector)
    )
    interface.funcs
  }
  val allFuncs                               = interfaceFuncs ++ allContracts.flatMap(_.funcs)
  val (unimplementedFuncs, allUniqueFuncs) = checkFuncs(allFuncs)
  val constantVars                           = allContracts.flatMap(_.constantVars)
  val enums                                  = mergeEnums(allContracts.flatMap(_.enums))
  // call the `checkFuncs` first to avoid duplicate function definition
  checkInterfaceMethodIndex(sortedInterfaces)
  val contractEvents = allContracts.flatMap(_.events)
  val maps           = allContracts.flatMap(_.maps)
  val events         = sortedInterfaces.flatMap(_.events) ++ contractEvents
  if (!contract.isAbstract && unimplementedFuncs.nonEmpty) {
    val methodNames = unimplementedFuncs.map(_.name).mkString(",")
    throw Compiler.Error(
      s"Contract ${contract.name} has unimplemented methods: $methodNames",
      contract.sourceIndex
    )
  }
  // Only concrete contracts need functions placed at their interface-defined method indexes.
  val funcs = if (contract.isAbstract) {
    allUniqueFuncs
  } else {
    rearrangeFuncs(sortedInterfaces, allUniqueFuncs)
  }
  (stdIdEnabled, stdId, funcs, maps, events, constantVars, enums)
}
// scalastyle:on method.length
/** Orders `funcs` so each function whose interface declaration fixes a method index sits
  * at that index; the remaining functions fill the other slots in their original order.
  *
  * @throws Compiler.Error
  *   if a predefined index is not below the number of functions
  */
private def rearrangeFuncs(
    interfaces: Seq[ContractInterface],
    funcs: Seq[FuncDef[StatefulContext]]
): Seq[FuncDef[StatefulContext]] = {
  val interfaceFuncs = interfaces.flatMap(_.funcs)
  val (remains, preDefinedIndexFuncs) = funcs.partitionMap { func =>
    interfaceFuncs.find(_.id == func.id).flatMap(_.useMethodIndex) match {
      case index @ Some(_) =>
        Right(func.copy(useMethodIndex = index).atSourceIndex(func.sourceIndex))
      case None => Left(func)
    }
  }
  val outOfBound = preDefinedIndexFuncs.filter(_.useMethodIndex.exists(_ >= funcs.length))
  if (outOfBound.nonEmpty) {
    throw Compiler.Error(
      s"The method index of these functions is out of bound: ${outOfBound.map(_.name).mkString(",")}, total number of methods: ${funcs.length}",
      outOfBound.headOption.flatMap(_.id.sourceIndex)
    )
  }
  val remaining = remains.iterator
  funcs.indices.map { slot =>
    preDefinedIndexFuncs
      .find(_.useMethodIndex.contains(slot))
      .getOrElse(remaining.next())
  }
}
/** Verifies that the given interfaces form a single inheritance chain.
  *
  * The sequence is walked pairwise. Every interface must list the interface immediately before
  * it among its `inheritances`, compared by name.
  *
  * @param sortedInterfaces interfaces in the order they are expected to inherit from each other
  * @throws Compiler.Error for the first interface that does not inherit from its predecessor
  */
@tailrec
def ensureChainedInterfaces(sortedInterfaces: Seq[ContractInterface]): Unit =
  sortedInterfaces match {
    case parent +: (rest @ (child +: _)) =>
      val inheritsFromParent =
        child.inheritances.exists(_.parentId.name == parent.ident.name)
      if (!inheritsFromParent) {
        throw Compiler.Error(
          s"Interface ${child.name} does not inherit from ${parent.name}, " +
            s"please annotate ${child.name} with @using(methodSelector = true) annotation",
          child.sourceIndex
        )
      }
      ensureChainedInterfaces(rest)
    case _ => // fewer than two interfaces left: nothing to check
  }
/** Simulates method-index assignment for `sortedInterfaces` and rejects conflicting indexes.
  *
  * Interfaces are processed in order. Within each interface:
  *   - functions with a predefined `useMethodIndex` claim that slot first;
  *   - the remaining functions then take the lowest free slots, never going back before the slot
  *     after the previously auto-assigned function.
  *
  * An interface without any function is accepted. It contributes no slots.
  *
  * @param sortedInterfaces interfaces in assignment order
  * @throws Compiler.Error if a predefined index is already taken
  */
def checkInterfaceMethodIndex(sortedInterfaces: Seq[ContractInterface]): Unit = {
  val methodLength = sortedInterfaces.map(_.funcs.length).sum
  // Highest explicitly requested index over all interfaces, or -1 if there is none.
  // Flattening first avoids calling `.max` on an empty function list, which throws.
  val predefinedMethodIndexMax = sortedInterfaces
    .flatMap(_.funcs.flatMap(_.useMethodIndex))
    .maxOption
    .getOrElse(-1)
  // Enough slots for every function and for the highest requested index.
  val methodLengthMax = math.max(methodLength, predefinedMethodIndexMax + 1)
  // At most 0xff + 1 = 256 slots are allowed (presumably so an index fits in one byte).
  assume(methodLengthMax <= 0xff + 1)
  val usedMethodIndexes = mutable.ArrayBuffer.fill(methodLengthMax)(false)
  var fromMethodIndex = 0
  sortedInterfaces.foreach { interface =>
    val (preDefinedMethodIndexFuncs, remains) =
      interface.funcs.partition(_.useMethodIndex.nonEmpty)
    // Explicit indexes are claimed first; a slot that is already taken is a conflict.
    preDefinedMethodIndexFuncs.foreach { func =>
      func.useMethodIndex.foreach { index =>
        if (usedMethodIndexes(index)) {
          throw Compiler.Error(
            s"Function ${interface.name}.${func.id.name} have invalid predefined method index $index",
            func.id.sourceIndex
          )
        }
        usedMethodIndexes(index) = true
      }
    }
    // Remaining functions take the next free slot at or after `fromMethodIndex`.
    remains.foreach { _ =>
      val methodIndex = usedMethodIndexes.indexOf(false, fromMethodIndex)
      assume(methodIndex != -1)
      usedMethodIndexes(methodIndex) = true
      fromMethodIndex = methodIndex + 1
    }
  }
}
/** Orders interfaces for method-index assignment.
  *
  * Interfaces that do not use method selectors come first. They are sorted by the length of
  * their parent list in `parentsCache` and must form a single inheritance chain (checked by
  * `ensureChainedInterfaces`). Interfaces that use method selectors follow, sorted by the same
  * length and then by name.
  *
  * NOTE(review): assumes every interface's `ident` is present in `parentsCache`; a missing key
  * throws `NoSuchElementException`.
  *
  * @throws Compiler.Error if the non-selector interfaces are not chained
  */
private[ralph] def sortInterfaces(
    parentsCache: mutable.Map[TypeId, Seq[ContractWithState]],
    interfaces: Seq[ContractInterface]
): Seq[ContractInterface] = {
  def ancestorCount(interface: ContractInterface): Int = parentsCache(interface.ident).length

  val (selectorBased, indexBased) = interfaces.partition(_.useMethodSelector)
  val chained = indexBased.sortBy(ancestorCount)
  ensureChainedInterfaces(chained)
  // Tuple key: compare by ancestor count first, break ties by name.
  val unchained = selectorBased.sortBy(interface => (ancestorCount(interface), interface.name))
  chained ++ unchained
}
/** Merges enum definitions that share the same id into a single definition.
  *
  * Enums whose id occurs exactly once are returned unchanged, first and in input order. Enums
  * defined several times are combined into one [[EnumDef]] whose fields are the concatenation of
  * all of their definitions' fields. These merged enums follow, in the iteration order of a
  * hash-based `mutable.Map`, which is not necessarily source order.
  *
  * @param enums enum definitions, possibly containing several definitions per id
  * @return one definition per enum id
  * @throws Compiler.Error if a later definition's first field type differs from the first field
  *   already collected for that enum, or if it repeats a field name already collected
  */
def mergeEnums(enums: Seq[EnumDef[StatefulContext]]): Seq[EnumDef[StatefulContext]] = {
  // One pass to count definitions per id, instead of an `enums.count` per element (O(n^2)).
  val definitionCounts = enums.groupMapReduce(_.id)(_ => 1)(_ + _)
  val (enums0, enums1) = enums.partition(enumDef => definitionCounts(enumDef.id) == 1)
  val mergedEnums = mutable.Map.empty[TypeId, mutable.ArrayBuffer[EnumField[StatefulContext]]]
  enums1.foreach { enumDef =>
    mergedEnums.get(enumDef.id) match {
      case Some(fields) =>
        // Indexing with (0) assumes enum definitions are never empty (TODO confirm in the parser).
        val expectedType = enumDef.fields(0).value.v.tpe
        val haveType = fields(0).value.v.tpe
        if (expectedType != haveType) {
          throw Compiler.Error(
            s"There are different field types in the enum ${enumDef.id.name}: $expectedType,$haveType",
            fields(0).sourceIndex
          )
        }
        val conflictFields = enumDef.fields.filter(f => fields.exists(_.name == f.name))
        if (conflictFields.nonEmpty) {
          throw Compiler.Error(
            s"There are conflict fields in the enum ${enumDef.id.name}: ${conflictFields.map(_.name).mkString(",")}",
            conflictFields.headOption.flatMap(_.sourceIndex)
          )
        }
        fields.appendAll(enumDef.fields)
      case None =>
        mergedEnums(enumDef.id) = mutable.ArrayBuffer.from(enumDef.fields)
    }
  }
  enums0 ++ mergedEnums.view.map { case (id, fields) => EnumDef(id, fields.toSeq) }.toSeq
}
/** Validates a contract's combined function list and resolves abstract declarations.
  *
  * Functions without a body (`bodyOpt.isEmpty`) are treated as abstract declarations. Each
  * declaration that has a same-named implementation must match its signature exactly.
  *
  * @param allFuncs every function visible to the contract, abstract or not
  * @return a pair `(unimplemented, unique)`:
  *   - `unimplemented`: the abstract declarations that have no implementation;
  *   - `unique`: the de-duplicated function list. Declarations come first, each replaced by its
  *     implementation when one exists, followed by implementations that match no declaration.
  * @throws Compiler.Error if a non-abstract function name is implemented more than once, if an
  *   abstract function name is declared more than once, or if an implementation's signature
  *   differs from its declaration
  */
def checkFuncs(
    allFuncs: Seq[FuncDef[StatefulContext]]
): (Seq[FuncDef[StatefulContext]], Seq[FuncDef[StatefulContext]]) = {
  val (declarations, definitions) = allFuncs.partition(_.bodyOpt.isEmpty)
  val definitionByName = definitions.map(func => func.id.name -> func).toMap
  val declarationByName = declarations.map(func => func.id.name -> func).toMap

  // Duplicate names collapse into one map entry, so a smaller map means duplicates exist.
  if (definitionByName.size != definitions.size) {
    val (duplicates, sourceIndex) = UniqueDef.duplicates(definitions)
    throw Compiler.Error(
      s"These functions are implemented multiple times: $duplicates",
      sourceIndex
    )
  }
  if (declarationByName.size != declarations.size) {
    val (duplicates, sourceIndex) = UniqueDef.duplicates(declarations)
    throw Compiler.Error(
      s"These abstract functions are defined multiple times: $duplicates",
      sourceIndex
    )
  }

  declarations.foreach { declaration =>
    val funcName = declaration.id.name
    definitionByName.get(funcName).foreach { definition =>
      if (definition.signature != declaration.signature) {
        throw Compiler.Error(
          s"Function ${quote(funcName)} is implemented with wrong signature",
          definition.sourceIndex
        )
      }
    }
  }

  val unimplemented =
    declarations.filterNot(declaration => definitionByName.contains(declaration.id.name))
  val resolvedDeclarations =
    declarations.map(declaration => definitionByName.getOrElse(declaration.id.name, declaration))
  val standaloneDefinitions =
    definitions.filterNot(definition => declarationByName.contains(definition.id.name))
  (unimplemented, resolvedDeclarations ++ standaloneDefinitions)
}
}
}
// scalastyle:on number.of.methods number.of.types