All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.joern.ghidra2cpg.passes.FunctionPass.scala Maven / Gradle / Ivy

The newest version!
package io.joern.ghidra2cpg.passes

import ghidra.app.util.template.TemplateSimplifier
import ghidra.program.model.address.GenericAddress
import ghidra.program.model.lang.Register
import ghidra.program.model.listing.{CodeUnitFormat, CodeUnitFormatOptions, Function, Instruction, Program}
import ghidra.program.model.pcode.{HighFunction, HighSymbol}
import ghidra.program.model.scalar.Scalar
import io.joern.ghidra2cpg.*
import io.joern.ghidra2cpg.processors.*
import io.joern.ghidra2cpg.utils.Decompiler
import io.joern.ghidra2cpg.utils.Utils.*
import io.shiftleft.codepropertygraph.generated.Cpg
import io.shiftleft.codepropertygraph.generated.nodes.{CfgNodeNew, NewBlock, NewMethod}
import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes}
import io.shiftleft.passes.ForkJoinParallelCpgPass

import scala.collection.mutable
import scala.jdk.CollectionConverters.*
import scala.language.implicitConversions

abstract class FunctionPass(
  processor: Processor,
  currentProgram: Program,
  functions: List[Function],
  cpg: Cpg,
  decompiler: Decompiler
) extends ForkJoinParallelCpgPass[Function](cpg) {

  // Name -> function lookup table, filled once at construction time.
  // The first function registered under a given name is kept; every later function
  // with the same name is only reported through the logger and otherwise ignored.
  protected val functionByName: mutable.Map[String, Function] = mutable.HashMap.empty[String, Function]
  functions.foreach { fn =>
    // getOrElseUpdate returns the already stored entry, or stores and returns `fn`
    val other = functionByName.getOrElseUpdate(fn.getName, fn)
    if (other ne fn) {
      baseLogger.warn(s"Multiple functions with same name ${fn.getName}, can't disambiguate: $fn, $other")
    }
  }

  /** Obtains the `HighFunction` for `function` by delegating to the decompiler handed to this pass.
    *
    * @param function
    *   the function to look up
    * @return
    *   exactly what `decompiler.toHighFunction` yields for `function`
    */
  def getHighFunction(function: Function): Option[HighFunction] =
    decompiler.toHighFunction(function)

  /** Materialises the instructions the program listing reports for the body of `function`.
    *
    * @param function
    *   the function whose body address set is queried
    * @return
    *   the instructions as a strict list, in the order the listing iterator produces them
    */
  protected def getInstructions(function: Function): Seq[Instruction] = {
    val listing = currentProgram.getListing
    // NOTE(review): the `true` flag presumably requests forward iteration - confirm against the Ghidra Listing API
    val bodyInstructions = listing.getInstructions(function.getBody, true)
    bodyInstructions.iterator().asScala.toList
  }

  // needed by ghidra for decompiling reasons
  // Shared formatter this pass uses to render instruction operands as strings (see the
  // `getOperandRepresentationString` calls in handleArguments and addCallOrReturnNode).
  // Block names and namespaces are never shown, the string option is empty and every
  // positional boolean flag is enabled.
  // NOTE(review): what each boolean switches on depends on the parameter order of the
  // CodeUnitFormatOptions constructor, which is not visible here - confirm against the
  // Ghidra API before flipping any of them.
  protected val codeUnitFormat = new CodeUnitFormat(
    new CodeUnitFormatOptions(
      CodeUnitFormatOptions.ShowBlockName.NEVER,
      CodeUnitFormatOptions.ShowNamespace.NEVER,
      "",
      true,
      true,
      true,
      true,
      true,
      true,
      true,
      new TemplateSimplifier()
    )
  )

  /** Supplies the work items of this pass: one part per function passed to the constructor, in list order. */
  override def generateParts(): Array[Function] = functions.toArray

  /** Implicit view from `Option[Int]` to `Option[Int]` that maps every contained value to itself.
    *
    * The result is always equal to the argument: `None` stays `None` and `Some(n)` becomes `Some(n)`.
    *
    * @param intOption
    *   the optional integer to convert
    * @return
    *   an option holding the same value as `intOption`
    */
  implicit def intToIntegerOption(intOption: Option[Int]): Option[Int] =
    intOption.map(identity)

  /** Creates one parameter node per parameter of `function` and hangs it below `methodNode` with an AST edge.
    *
    * Thunks take their parameters from the thunked function and are numbered by position (starting at 1). All other
    * functions take them from the parameter symbols of the high function's local symbol map and are numbered by
    * `getCategoryIndex + 1`; if no high function is available, nothing is emitted. For non-thunks the name of the
    * storage register is preferred over the symbol name whenever storage, register and register name are all non-null.
    *
    * @param diffGraphBuilder
    *   receives the new nodes and edges
    * @param function
    *   the function whose parameters are described
    * @param methodNode
    *   the method node that becomes the AST parent of every parameter node
    */
  def handleParameters(diffGraphBuilder: DiffGraphBuilder, function: Function, methodNode: NewMethod): Unit = {
    // Shared tail of both branches. The line number is the function's entry-point offset narrowed to Int.
    def attachParameter(name: String, order: Int, typeName: String): Unit = {
      val parameterNode =
        createParameterNode(name, name, order, typeName, function.getEntryPoint.getOffsetAsBigInteger.intValue())
      diffGraphBuilder.addNode(parameterNode)
      diffGraphBuilder.addEdge(methodNode, parameterNode, EdgeTypes.AST)
    }

    if (function.isThunk) {
      val thunkedParameters = function.getThunkedFunction(true).getParameters
      for ((parameter, position) <- thunkedParameters.zipWithIndex)
        attachParameter(parameter.getName, position + 1, parameter.getDataType.getName)
    } else {
      getHighFunction(function).foreach { highFunction =>
        val parameterSymbols = highFunction.getLocalSymbolMap.getSymbols.asScala.toSeq.filter(_.isParameter)
        for (symbol <- parameterSymbols) {
          // Each step of the storage -> register -> name chain may be null, hence the Option wrapping.
          val registerName = for {
            storage  <- Option(symbol.getStorage)
            register <- Option(storage.getRegister)
            name     <- Option(register.getName)
          } yield name
          attachParameter(
            registerName.getOrElse(symbol.getName),
            symbol.getCategoryIndex + 1,
            symbol.getDataType.getName
          )
        }
      }
    }
  }

  /** Emits, for every local variable of `function`, a local node and an identifier that references it.
    *
    * Both nodes become AST children of `blockNode`, and the identifier is linked to the local with a REF edge. The
    * identifier is created with `-1` for both of its numeric arguments.
    *
    * @param diffGraphBuilder
    *   receives the new nodes and edges
    * @param function
    *   the function whose local variables are described
    * @param blockNode
    *   the block that becomes the AST parent of the created nodes
    */
  def handleLocals(diffGraphBuilder: DiffGraphBuilder, function: Function, blockNode: NewBlock): Unit = {
    for (variable <- function.getLocalVariables) {
      val declaration = nodes
        .NewLocal()
        .name(variable.getName)
        .code(variable.toString)
        .typeFullName(Types.registerType(variable.getDataType.toString))
      // The identifier carries the raw data type string, whereas the local goes through Types.registerType.
      val reference =
        createIdentifier(variable.getName, variable.getSymbol.getName, -1, variable.getDataType.toString, -1)

      // nodes first ...
      diffGraphBuilder.addNode(declaration)
      diffGraphBuilder.addNode(reference)
      // ... then the edges that tie them into the graph
      diffGraphBuilder.addEdge(blockNode, declaration, EdgeTypes.AST)
      diffGraphBuilder.addEdge(blockNode, reference, EdgeTypes.AST)
      diffGraphBuilder.addEdge(reference, declaration, EdgeTypes.REF)
    }
  }

  /** Turns the instructions of `function` into a linear chain of call/return nodes.
    *
    * Every instruction node becomes an AST child of `blockNode`. The first node is reached from `methodNode` via a CFG
    * edge, and each later node is reached from its predecessor via a CFG edge, in instruction order. A function
    * without instructions produces nothing.
    *
    * @param diffGraphBuilder
    *   receives the new nodes and edges
    * @param function
    *   the function whose body is translated
    * @param methodNode
    *   CFG predecessor of the first instruction node
    * @param blockNode
    *   AST parent of all instruction nodes
    */
  def handleBody(
    diffGraphBuilder: DiffGraphBuilder,
    function: Function,
    methodNode: NewMethod,
    blockNode: NewBlock
  ): Unit = {
    getInstructions(function) match {
      case firstInstruction +: remainingInstructions =>
        val entryNode = addCallOrReturnNode(firstInstruction)
        // NOTE(review): unlike the nodes below, the first node is never handed to addNode explicitly;
        // presumably the builder registers it through the edges that mention it - confirm.
        handleArguments(diffGraphBuilder, firstInstruction, entryNode, function)
        diffGraphBuilder.addEdge(blockNode, entryNode, EdgeTypes.AST)
        diffGraphBuilder.addEdge(methodNode, entryNode, EdgeTypes.CFG)

        // Thread the previously created node through the fold so each node links to its predecessor.
        remainingInstructions.foldLeft(entryNode) { (previousNode, instruction) =>
          val currentNode = addCallOrReturnNode(instruction)
          diffGraphBuilder.addNode(currentNode)
          handleArguments(diffGraphBuilder, instruction, currentNode, function)
          diffGraphBuilder.addEdge(blockNode, currentNode, EdgeTypes.AST)
          diffGraphBuilder.addEdge(previousNode, currentNode, EdgeTypes.CFG)
          currentNode
        }
        ()
      case _ =>
        ()
    }
  }

  /** Creates the argument nodes of `instruction` and attaches them to `callNode`.
    *
    * If the processor maps the instruction's mnemonic to `CALL`, the first operand is looked up in `functionByName`.
    * When a callee is found, one identifier per callee parameter is attached; when none is found, no arguments are
    * created. For every other instruction each operand is translated into identifier or literal nodes.
    *
    * @param diffGraphBuilder
    *   receives the new nodes and edges
    * @param instruction
    *   the instruction whose operands are translated
    * @param callNode
    *   the node that represents `instruction` and becomes the parent of all arguments
    * @param function
    *   the function containing `instruction`; not consulted by this implementation
    */
  def handleArguments(
    diffGraphBuilder: DiffGraphBuilder,
    instruction: Instruction,
    callNode: CfgNodeNew,
    function: Function
  ): Unit = {
    val mnemonicString = processor.getInstructions.getOrElse(instruction.getMnemonicString, "UNKNOWN")
    if (mnemonicString == "CALL") {
      val calledFunction = codeUnitFormat.getOperandRepresentationString(instruction, 0)
      // foreach rather than map: the lookup result is only used for its side effects
      functionByName.get(calledFunction).foreach { callee =>
        calleeParameters(callee).foreach { case (checkedParameter, index, dataType) =>
          val node = createIdentifier(
            checkedParameter,
            checkedParameter,
            index,
            Types.registerType(dataType),
            instruction.getMinAddress.getOffsetAsBigInteger.intValue
          )
          connectCallToArgument(diffGraphBuilder, callNode, node)
        }
      }
    } else {
      (0 until instruction.getNumOperands).foreach { index =>
        addOperandArguments(diffGraphBuilder, instruction, callNode, index)
      }
    }
  }

  /** Describes the parameters of `callee` as (checked parameter name, 1-based index, data type name) triples.
    *
    * The checked name is the name of the register backing the parameter, falling back to the parameter's own name
    * whenever any link of the lookup chain is null.
    */
  private def calleeParameters(callee: Function): Seq[(String, Int, String)] =
    if (callee.isThunk) {
      // thunk functions contain parameters already
      callee.getParameters.toSeq.map { parameter =>
        val checkedParameter = Option(parameter.getRegister)
          .flatMap(register => Option(register.getName))
          .getOrElse(parameter.getName)
        (checkedParameter, parameter.getOrdinal + 1, parameter.getDataType.getName)
      }
    } else {
      // non thunk functions do not contain function parameters by default
      // need to decompile function to get parameter information
      // decompilation for a function is cached so subsequent calls to decompile should be free
      // TODO: replace this later on
      decompiler
        .toHighFunction(callee)
        .toSeq
        .flatMap(_.getLocalSymbolMap.getSymbols.asScala.filter(_.isParameter))
        .map { (parameter: HighSymbol) =>
          // Same null-safe storage -> register -> name chain that handleParameters uses.
          val checkedParameter = Option(parameter.getStorage)
            .flatMap(storage => Option(storage.getRegister))
            .flatMap(register => Option(register.getName))
            .getOrElse(parameter.getName)
          (checkedParameter, parameter.getCategoryIndex + 1, parameter.getDataType.getName)
        }
    }

  /** Translates operand number `index` of a non-call instruction into argument nodes of `callNode`.
    *
    * An operand made of several objects becomes a single identifier named after its default representation.
    * Otherwise a register becomes an identifier, while scalars and generic addresses become literals; any other
    * operand object is reported through the logger and skipped.
    */
  private def addOperandArguments(
    diffGraphBuilder: DiffGraphBuilder,
    instruction: Instruction,
    callNode: CfgNodeNew,
    index: Int
  ): Unit = {
    val lineNumber = instruction.getMinAddress.getOffsetAsBigInteger.intValue
    val opObjects  = instruction.getOpObjects(index)
    if (opObjects.length > 1) {
      val argument = String.valueOf(instruction.getDefaultOperandRepresentation(index))
      val node     = createIdentifier(argument, argument, index + 1, Types.registerType(argument), lineNumber)
      connectCallToArgument(diffGraphBuilder, callNode, node)
    } else {
      opObjects.foreach { opObject =>
        val className = opObject.getClass.getSimpleName
        // The simple name keeps the dispatch restricted to exactly these classes (subclasses stay unsupported),
        // while the type pattern replaces the unchecked cast that could fail for a same-named foreign class.
        (className, opObject) match {
          case ("Register", register: Register) =>
            val node = createIdentifier(
              register.getName,
              register.getName,
              index + 1,
              Types.registerType(register.getName),
              lineNumber
            )
            connectCallToArgument(diffGraphBuilder, callNode, node)
          case ("Scalar", scalarObject: Scalar) =>
            val scalar = scalarObject.toString(16, false, false, "", "")
            val node = nodes
              .NewLiteral()
              .code(s"0x${scalar}")
              .order(index + 1)
              .typeFullName(scalar)
              .lineNumber(Some(lineNumber))
            connectCallToArgument(diffGraphBuilder, callNode, node)
          case ("GenericAddress", address: GenericAddress) =>
            // TODO: try to resolve the address
            val genericAddress = address.toString()
            val node = nodes
              .NewLiteral()
              .code(s"0x${genericAddress}")
              .order(index + 1)
              .typeFullName(genericAddress)
              .lineNumber(Some(lineNumber))
            connectCallToArgument(diffGraphBuilder, callNode, node)
          case _ =>
            // reported through the pass logger instead of stdout
            baseLogger.warn(s"""Unsupported argument: $opObject $className""")
        }
      }
    }
  }

  /** Registers `argument` with the diff graph and links it to `call`.
    *
    * @param diffGraphBuilder
    *   receives the node and the edges
    * @param call
    *   source node of both edges
    * @param argument
    *   node to add; it becomes the target of both edges
    */
  def connectCallToArgument(diffGraphBuilder: DiffGraphBuilder, call: CfgNodeNew, argument: CfgNodeNew): Unit = {
    diffGraphBuilder.addNode(argument)
    // the argument is reachable from the call via an ARGUMENT edge and via an AST edge
    diffGraphBuilder.addEdge(call, argument, EdgeTypes.ARGUMENT)
    diffGraphBuilder.addEdge(call, argument, EdgeTypes.AST)
  }

  /** Builds the node that represents `instruction`, chosen by the processor's translation of its mnemonic.
    *
    *   - `RET` yields a return node whose code is the instruction text.
    *   - `CALL` and `LEAVE` yield a call node whose code and name are the sanitized first operand.
    *   - Any other translation, including the `UNKNOWN` fallback for unmapped mnemonics, yields a call node whose code
    *     is the instruction text and whose name is the translated mnemonic.
    *
    * The line number of every node is the instruction's minimum address offset narrowed to Int.
    *
    * @param instruction
    *   the instruction to represent
    * @return
    *   the newly created node; it is not added to any diff graph here
    */
  def addCallOrReturnNode(instruction: Instruction): CfgNodeNew = {
    // a def, so the address is only read once the other arguments of the node have been computed
    def lineNumber: Int = instruction.getMinAddress.getOffsetAsBigInteger.intValue()

    val translatedMnemonic = processor.getInstructions.getOrElse(instruction.getMnemonicString, "UNKNOWN")
    translatedMnemonic match {
      case "RET" =>
        createReturnNode(instruction.toString, lineNumber)
      case "CALL" | "LEAVE" =>
        val target = sanitizeMethodName(codeUnitFormat.getOperandRepresentationString(instruction, 0))
        createCallNode(target, target, lineNumber)
      case operator =>
        // also covers "UNKNOWN": the operator string itself is used as the call name
        createCallNode(instruction.toString, operator, lineNumber)
    }
  }

  /** Reduces an operand string to a bare method name.
    *
    * Keeps only the text after the last `>` separator and strips every `[` and `]` from it. Trailing `>` characters
    * are ignored by the split, and a string consisting solely of `>` characters is returned as is.
    *
    * @param methodName
    *   the raw operand representation
    * @return
    *   the sanitized name
    */
  def sanitizeMethodName(methodName: String): String = {
    val segments = methodName.split(">")
    // split drops trailing empty segments, so an input made only of '>' leaves nothing to pick from
    val lastSegment = if (segments.isEmpty) methodName else segments.last
    lastSegment.filterNot(character => character == '[' || character == ']')
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy