All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.alephium.protocol.vm.Contract.scala Maven / Gradle / Ivy

The newest version!
// Copyright 2018 The Alephium Authors
// This file is part of the alephium project.
//
// The library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the library. If not, see .

package org.alephium.protocol.vm

import java.math.BigInteger

import scala.annotation.switch
import scala.collection.mutable

import akka.util.ByteString

import org.alephium.io.IOError
import org.alephium.macros.HashSerde
import org.alephium.protocol.Hash
import org.alephium.protocol.model.{ContractId, HardFork}
import org.alephium.serde
import org.alephium.serde._
import org.alephium.util.{AVector, Bytes, EitherF, Hex}

final case class Method[Ctx <: StatelessContext](
    isPublic: Boolean,
    usePreapprovedAssets: Boolean,
    useContractAssets: Boolean,
    usePayToContractOnly: Boolean,
    argsLength: Int,
    localsLength: Int,
    returnLength: Int,
    instrs: AVector[Instr[Ctx]]
) {
  def usesAssetsFromInputs(): Boolean = usePreapprovedAssets || useContractAssets

  def checkModifierSinceRhone(): ExeResult[Unit] = {
    if (useContractAssets && usePayToContractOnly) {
      failed(InvalidMethodModifierSinceRhone)
    } else {
      okay
    }
  }

  def checkModifierPreRhone(): ExeResult[Unit] = {
    if (!usePayToContractOnly) {
      okay
    } else {
      failed(InvalidMethodModifierBeforeRhone)
    }
  }

  def checkModifierPreLeman(): ExeResult[Unit] = {
    if (usePreapprovedAssets == useContractAssets) {
      okay
    } else {
      failed(InvalidMethodModifierBeforeLeman)
    }
  }

  def toTemplateString(): String = {
    val prefix = Hex.toHexString(
      serialize(isPublic) ++
        Method.serializeAssetModifier(this) ++
        serialize(argsLength) ++
        serialize(localsLength) ++
        serialize(returnLength) ++
        serialize(instrs.length)
    )
    prefix ++ instrs.map(_.toTemplateString()).mkString("")
  }

  def mockup(): Method[Ctx] = this.copy(instrs = instrs.map(_.mockup()))
}

object Method {
  val payToContractOnlyMask: Int = 4

  private def serializeAssetModifier[Ctx <: StatelessContext](method: Method[Ctx]): ByteString = {
    val first2bits = (method.usePreapprovedAssets, method.useContractAssets) match {
      case (false, false) => 0 // isPayble = false before Leman fork
      case (true, true)   => 1 //  isPayable = true before Leman fork
      case (false, true)  => 2
      case (true, false)  => 3
    }
    ByteString(first2bits | (if (method.usePayToContractOnly) payToContractOnlyMask else 0))
  }

  private def deserializeAssetModifier[Ctx <: StatelessContext](
      input: Byte
  ): SerdeResult[(Boolean, Boolean, Boolean)] = {
    val payToContractOnlyFlag = (input & payToContractOnlyMask) != 0
    (input & ~payToContractOnlyMask: @switch) match {
      case 0 => Right((false, false, payToContractOnlyFlag))
      case 1 => Right((true, true, payToContractOnlyFlag))
      case 2 => Right((false, true, payToContractOnlyFlag))
      case 3 => Right((true, false, payToContractOnlyFlag))
      case _ => Left(SerdeError.wrongFormat("Invalid assets modifier"))
    }
  }

  private def serdeGen[Ctx <: StatelessContext](implicit instrsSerde: Serde[AVector[Instr[Ctx]]]) =
    new Serde[Method[Ctx]] {
      override def serialize(method: Method[Ctx]): ByteString = {
        serde.serialize(method.isPublic) ++
          Method.serializeAssetModifier(method) ++
          serde.serialize(method.argsLength) ++
          serde.serialize(method.localsLength) ++
          serde.serialize(method.returnLength) ++
          serde.serialize(method.instrs)
      }

      def _deserialize(input: ByteString): SerdeResult[Staging[Method[Ctx]]] = {
        for {
          isPublicRest      <- serde._deserialize[Boolean](input)
          assetModifierRest <- serde._deserialize[Byte](isPublicRest.rest)
          assetModifier     <- Method.deserializeAssetModifier(assetModifierRest.value)
          argsLengthRest    <- serde._deserialize[Int](assetModifierRest.rest)
          localsLengthRest  <- serde._deserialize[Int](argsLengthRest.rest)
          returnLengthRest  <- serde._deserialize[Int](localsLengthRest.rest)
          instrsRest        <- serde._deserialize[AVector[Instr[Ctx]]](returnLengthRest.rest)
        } yield {
          Staging(
            Method[Ctx](
              isPublicRest.value,
              assetModifier._1,
              assetModifier._2,
              assetModifier._3,
              argsLengthRest.value,
              localsLengthRest.value,
              returnLengthRest.value,
              instrsRest.value
            ),
            instrsRest.rest
          )
        }
      }
    }

  implicit val statelessSerde: Serde[Method[StatelessContext]] = serdeGen[StatelessContext]
  implicit val statefulSerde: Serde[Method[StatefulContext]]   = serdeGen[StatefulContext]

  def validate(method: Method[_]): Boolean =
    method.argsLength >= 0 && method.localsLength >= method.argsLength && method.returnLength >= 0

  def forSMT: Method[StatefulContext] =
    Method[StatefulContext](
      isPublic = false,
      usePreapprovedAssets = false,
      useContractAssets = false,
      usePayToContractOnly = false,
      argsLength = 0,
      localsLength = 0,
      returnLength = 0,
      AVector(Pop)
    )

  final case class Selector(index: Int) extends AnyVal
  object Selector {
    implicit val serde: Serde[Selector] = new Serde[Selector] {
      override def serialize(selector: Selector): ByteString = Bytes.from(selector.index)

      override def _deserialize(input: ByteString): SerdeResult[Staging[Selector]] = {
        if (input.length < 4) {
          Left(SerdeError.validation(s"Invalid int from bytes: $input, expected 4 bytes"))
        } else {
          Right(Staging(Selector(Bytes.toIntUnsafe(input.take(4))), input.drop(4)))
        }
      }
    }
  }

  def extractSelector(methodBytes: ByteString): Option[Selector] = {
    serde._deserialize[Boolean](methodBytes) match {
      case Right(isPublicRest) if isPublicRest.value =>
        val selectorEither = for {
          assetModifierRest <- serde._deserialize[Byte](isPublicRest.rest)
          argsLengthRest    <- serde._deserialize[Int](assetModifierRest.rest)
          localsLengthRest  <- serde._deserialize[Int](argsLengthRest.rest)
          returnLengthRest  <- serde._deserialize[Int](localsLengthRest.rest)
          instrLengthRest   <- serde._deserialize[Int](returnLengthRest.rest)
          selectorInstrRest <-
            if (instrLengthRest.rest.headOption.contains(MethodSelector.code)) {
              MethodSelector.deserialize(instrLengthRest.rest.drop(1))
            } else {
              Left(SerdeError.other("selector does not exist"))
            }
        } yield {
          selectorInstrRest.value.selector
        }
        selectorEither.toOption
      case _ => None
    }
  }
}

sealed trait Contract[Ctx <: StatelessContext] {
  def fieldLength: Int
  def methodsLength: Int
  def getMethod(index: Int): ExeResult[Method[Ctx]]
  def hash: Hash

  def initialStateHash(immFields: AVector[Val], mutFields: AVector[Val]): Hash =
    Hash.doubleHash(
      hash.bytes ++ ContractStorageState.fieldsSerde.serialize(immFields ++ mutFields)
    )

  def checkAssetsModifier(ctx: StatelessContext): ExeResult[Unit] = {
    val hardFork = ctx.getHardFork()
    EitherF.foreachTry(0 until methodsLength) { methodIndex =>
      for {
        method <- getMethod(methodIndex)
        _ <-
          if (hardFork.isRhoneEnabled()) { method.checkModifierSinceRhone() }
          else { method.checkModifierPreRhone() }
        _ <-
          if (hardFork.isLemanEnabled()) { okay }
          else { method.checkModifierPreLeman() }
      } yield ()
    }
  }

  def validate(newImmFields: AVector[Val], newMutFields: AVector[Val]): Boolean = {
    (newImmFields.length + newMutFields.length) == fieldLength
  }

  def check(newImmFields: AVector[Val], newMutFields: AVector[Val]): ExeResult[Unit] = {
    if (validate(newImmFields, newMutFields)) { okay }
    else {
      failed(InvalidFieldLength)
    }
  }

  def mockup(): Contract[Ctx]
}

sealed trait Script[Ctx <: StatelessContext] extends Contract[Ctx] {
  def fieldLength: Int = 0
  def methods: AVector[Method[Ctx]]
  def toObject: ScriptObj[Ctx]

  def methodsLength: Int = methods.length

  def getMethod(index: Int): ExeResult[Method[Ctx]] = {
    methods.get(index).toRight(Right(InvalidMethodIndex(index, methodsLength)))
  }

  def toTemplateString(): String = {
    Hex.toHexString(serialize(methods.length)) ++ methods.map(_.toTemplateString()).mkString("")
  }
}

@HashSerde
final case class StatelessScript private (methods: AVector[Method[StatelessContext]])
    extends Script[StatelessContext] {
  override def toObject: ScriptObj[StatelessContext] = {
    StatelessScriptObject(this)
  }

  def mockup(): StatelessScript = this.copy(methods = methods.map(_.mockup()))
}

object StatelessScript {
  implicit val serde: Serde[StatelessScript] = {
    val serde: Serde[StatelessScript] = Serde.forProduct1(StatelessScript.apply, _.methods)
    serde.validate(script => Either.cond(validate(script.methods), (), s"Invalid script: $script"))
  }

  private def validate(methods: AVector[Method[StatelessContext]]): Boolean = {
    methods.nonEmpty &&
    methods.head.isPublic &&
    methods.forall(m => !m.usePreapprovedAssets && Method.validate(m))
  }

  def from(methods: AVector[Method[StatelessContext]]): Option[StatelessScript] = {
    Option.when(validate(methods))(new StatelessScript(methods))
  }

  def unsafe(methods: AVector[Method[StatelessContext]]): StatelessScript = {
    new StatelessScript(methods)
  }
}

@HashSerde
final case class StatefulScript private (methods: AVector[Method[StatefulContext]])
    extends Script[StatefulContext] {
  def entryMethod: Method[StatefulContext] = methods.head

  override def toObject: ScriptObj[StatefulContext] = {
    StatefulScriptObject(this)
  }

  def mockup(): StatefulScript = this.copy(methods = methods.map(_.mockup()))
}

object StatefulScript {
  implicit val serde: Serde[StatefulScript] = Serde
    .forProduct1[AVector[Method[StatefulContext]], StatefulScript](StatefulScript.unsafe, _.methods)
    .validate(script => if (validate(script.methods)) Right(()) else Left("Invalid TxScript"))

  def unsafe(methods: AVector[Method[StatefulContext]]): StatefulScript = {
    new StatefulScript(methods)
  }

  def from(methods: AVector[Method[StatefulContext]]): Option[StatefulScript] = {
    Option.when(validate(methods))(new StatefulScript(methods))
  }

  def validate(methods: AVector[Method[StatefulContext]]): Boolean = {
    methods.nonEmpty && methods.head.isPublic && methods.forall(Method.validate)
  }

  def alwaysFail: StatefulScript =
    StatefulScript(
      AVector(
        Method[StatefulContext](
          isPublic = true,
          usePreapprovedAssets = true,
          useContractAssets = true,
          usePayToContractOnly = false,
          argsLength = 0,
          localsLength = 0,
          returnLength = 0,
          instrs = AVector(ConstFalse, Assert)
        )
      )
    )
}

@HashSerde
final case class StatefulContract(
    fieldLength: Int,
    methods: AVector[Method[StatefulContext]]
) extends Contract[StatefulContext] {
  def methodsLength: Int = methods.length

  def getMethod(index: Int): ExeResult[Method[StatefulContext]] = {
    methods.get(index).toRight(Right(InvalidMethodIndex(index, methodsLength)))
  }

  def toHalfDecoded(): StatefulContract.HalfDecoded = {
    val methodsBytes = methods.map(Method.statefulSerde.serialize)
    var count        = 0
    val methodIndexes = AVector.tabulate(methods.length) { k =>
      count += methodsBytes(k).length
      count
    }
    StatefulContract.HalfDecoded(
      fieldLength,
      methodIndexes,
      methodsBytes.fold(ByteString.empty)(_ ++ _)
    )
  }

  def mockup(): StatefulContract = this.copy(methods = methods.map(_.mockup()))
}

object StatefulContract {
  final case class SelectorSearchResult(methodIndex: Int, methodSearched: Int)
  @HashSerde
  // We don't need to deserialize the whole contract if we only access part of the methods
  final case class HalfDecoded(
      fieldLength: Int,
      methodIndexes: AVector[Int], // end positions of methods in methodBytes
      methodsBytes: ByteString
  ) extends Contract[StatefulContext] {
    def methodsLength: Int = methodIndexes.length

    private[vm] lazy val methods = Array.ofDim[Method[StatefulContext]](methodsLength)

    def getMethod(index: Int): ExeResult[Method[StatefulContext]] = {
      if (index >= 0 && index < methodsLength) {
        val method = methods(index)
        if (method == null) {
          deserializeMethod(index) match {
            case Left(error) => ioFailed(IOErrorLoadContract(IOError.Serde(error)))
            case Right(method) =>
              methods(index) = method
              Right(method)
          }
        } else {
          Right(method)
        }
      } else {
        failed(InvalidMethodIndex(index, methodsLength))
      }
    }

    @inline
    private def getMethodBytes(index: Int): ByteString = {
      if (index == 0) {
        methodsBytes.take(methodIndexes(0))
      } else {
        methodsBytes.slice(methodIndexes(index - 1), methodIndexes(index))
      }
    }

    private def deserializeMethod(index: Int): SerdeResult[Method[StatefulContext]] = {
      val methodBytes = getMethodBytes(index)
      Method.statefulSerde.deserialize(methodBytes)
    }

    private[vm] var searchedMethodIndex  = -1
    private[vm] lazy val cachedSelectors = mutable.HashMap.empty[Method.Selector, Int]
    def getMethodBySelector(selector: Method.Selector): ExeResult[SelectorSearchResult] = {
      cachedSelectors.get(selector) match {
        case Some(methodIndex) => Right(SelectorSearchResult(methodIndex, 0))
        case None              => findUncachedMethodBySelector(selector)
      }
    }
    def getMethodSelector(methodIndex: Int): Option[Method.Selector] = {
      val methodBytes = getMethodBytes(methodIndex)
      Method.extractSelector(methodBytes)
    }
    private def findUncachedMethodBySelector(
        selector: Method.Selector
    ): ExeResult[SelectorSearchResult] = {
      var found       = false
      var methodIndex = searchedMethodIndex + 1
      while (!found && methodIndex < methodsLength) {
        getMethodSelector(methodIndex) match {
          case Some(foundSelector) =>
            cachedSelectors(foundSelector) = methodIndex
            if (foundSelector == selector) {
              found = true
            } else {
              methodIndex = methodIndex + 1
            }
          case None => methodIndex = methodIndex + 1
        }
      }
      val result = if (found) {
        Right(SelectorSearchResult(methodIndex, methodIndex - searchedMethodIndex))
      } else {
        failed(InvalidMethodSelector(selector))
      }
      searchedMethodIndex = methodIndex
      result
    }

    def toContract(): SerdeResult[StatefulContract] = {
      AVector
        .tabulateE(methodsLength)(deserializeMethod)
        .map(StatefulContract(fieldLength, _))
    }

    // For testing purpose
    def toObjectUnsafeTestOnly(
        contractId: ContractId,
        immFields: AVector[Val],
        mutFields: AVector[Val]
    ): StatefulContractObject = {
      StatefulContractObject.unsafe(
        this.hash,
        this,
        this.initialStateHash(immFields, mutFields),
        immFields,
        mutFields,
        contractId
      )
    }

    def mockup(): HalfDecoded = ???
  }

  object HalfDecoded {
    private val intsSerde: Serde[AVector[Int]] = avectorSerde[Int]
    implicit val serde: Serde[HalfDecoded] = new Serde[HalfDecoded] {
      override def serialize(input: HalfDecoded): ByteString = {
        intSerde.serialize(input.fieldLength) ++
          intsSerde.serialize(input.methodIndexes) ++
          input.methodsBytes
      }

      override def _deserialize(input: ByteString): SerdeResult[Staging[HalfDecoded]] = {
        for {
          fieldLengthRest   <- intSerde._deserialize(input)
          methodIndexesRest <- intsSerde._deserialize(fieldLengthRest.rest)
        } yield {
          // all the `take` and `drop` calls are safe since contracts are checked in the creation
          val length      = methodIndexesRest.value.lastOption.getOrElse(0)
          val data        = methodIndexesRest.rest
          val methodBytes = data.take(length)
          val rest        = data.drop(length)
          Staging(
            HalfDecoded(
              fieldLengthRest.value,
              methodIndexesRest.value,
              methodBytes
            ),
            rest
          )
        }
      }
    }
  }

  implicit val serde: Serde[StatefulContract] = new Serde[StatefulContract] {
    override def serialize(input: StatefulContract): ByteString = {
      HalfDecoded.serde.serialize(input.toHalfDecoded())
    }

    override def _deserialize(input: ByteString): SerdeResult[Staging[StatefulContract]] = {
      for {
        halfDecodedRest <- HalfDecoded.serde._deserialize(input)
        contract        <- halfDecodedRest.value.toContract()
      } yield Staging(contract, halfDecodedRest.rest)
    }
  }

  def check(contract: StatefulContract, hardFork: HardFork): ExeResult[Unit] = {
    if (contract.fieldLength < 0) {
      failed(InvalidFieldLength)
    } else if (hardFork.isLemanEnabled() && contract.fieldLength > 0xff) {
      failed(TooManyFields)
    } else if (contract.methods.isEmpty) {
      failed(EmptyMethods)
    } else if (!contract.methods.forall(Method.validate)) {
      failed(InvalidMethod)
    } else {
      okay
    }
  }

  val forSMT: StatefulContract.HalfDecoded =
    StatefulContract(0, AVector(Method.forSMT)).toHalfDecoded()
}

sealed trait ContractObj[Ctx <: StatelessContext] {
  def contractIdOpt: Option[ContractId]
  def code: Contract[Ctx]
  def immFields: AVector[Val]
  def mutFields: mutable.ArraySeq[Val]

  def getContractId(): ExeResult[ContractId] = contractIdOpt.toRight(Right(ExpectAContract))

  def getAddress(): ExeResult[Val.Address] =
    getContractId().map(id => Val.Address(LockupScript.p2c(id)))

  def isScript(): Boolean = contractIdOpt.isEmpty

  def getMethod(index: Int): ExeResult[Method[Ctx]] = code.getMethod(index)

  def getImmField(index: Int): ExeResult[Val] = {
    immFields.get(index) match {
      case Some(v) => Right(v)
      case None    => failed(InvalidImmFieldIndex(index, immFields.length))
    }
  }

  def getMutField(index: Int): ExeResult[Val] = {
    if (mutFields.isDefinedAt(index)) {
      Right(mutFields(index))
    } else {
      failed(InvalidMutFieldIndex(BigInteger.valueOf(index.toLong), immFields.length))
    }
  }

  def setMutField(index: Int, v: Val): ExeResult[Unit] = {
    if (!mutFields.isDefinedAt(index)) {
      failed(InvalidMutFieldIndex(BigInteger.valueOf(index.toLong), immFields.length))
    } else if (mutFields(index).tpe != v.tpe) {
      failed(InvalidMutFieldType)
    } else {
      Right(mutFields.update(index, v))
    }
  }
}

sealed trait ScriptObj[Ctx <: StatelessContext] extends ContractObj[Ctx] {
  val contractIdOpt: Option[ContractId] = None
  val immFields: AVector[Val]           = ScriptObj.emptyImmFields
  val mutFields: mutable.ArraySeq[Val]  = mutable.ArraySeq.empty
}

object ScriptObj {
  val emptyImmFields: AVector[Val] = AVector.empty
}

final case class StatelessScriptObject(code: StatelessScript) extends ScriptObj[StatelessContext]

final case class StatefulScriptObject(code: StatefulScript) extends ScriptObj[StatefulContext]

final case class StatefulContractObject private (
    codeHash: Hash,
    code: StatefulContract.HalfDecoded,
    initialStateHash: Hash,         // the state hash when the contract is created
    initialMutFields: AVector[Val], // the initial mutable field values when the contract is loaded
    immFields: AVector[Val],
    mutFields: mutable.ArraySeq[Val],
    contractId: ContractId
) extends ContractObj[StatefulContext] {
  def contractIdOpt: Option[ContractId] = Some(contractId)

  def getInitialStateHash(): Val.ByteVec = Val.ByteVec(initialStateHash.bytes)

  def getCodeHash(): Val.ByteVec = Val.ByteVec(codeHash.bytes)

  def isUpdated: Boolean =
    !mutFields.indices.forall(index => mutFields(index) == initialMutFields(index))

  def estimateContractLoadByteSize(): Int = {
    immFields.fold(0)(_ + _.estimateByteSize()) +
      mutFields.foldLeft(0)(_ + _.estimateByteSize()) +
      code.methodsBytes.length
  }
}

object StatefulContractObject {
  def unsafe(
      codeHash: Hash,
      code: StatefulContract.HalfDecoded,
      initialStateHash: Hash,
      immFields: AVector[Val],
      initialMutFields: AVector[Val],
      contractId: ContractId
  ): StatefulContractObject = {
    assume(code.hash == codeHash)
    new StatefulContractObject(
      codeHash,
      code,
      initialStateHash,
      initialMutFields,
      immFields,
      initialMutFields.toArray,
      contractId
    )
  }

  def from(
      contract: StatefulContract,
      immFields: AVector[Val],
      initialMutFields: AVector[Val],
      contractId: ContractId
  ): StatefulContractObject = {
    val code             = contract.toHalfDecoded()
    val codeHash         = code.hash
    val initialStateHash = code.initialStateHash(immFields, initialMutFields)
    unsafe(codeHash, code, initialStateHash, immFields, initialMutFields, contractId)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy