All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.shiftleft.codepropertygraph.cpgloading.ProtoToCpg.scala Maven / Gradle / Ivy

The newest version!
package io.shiftleft.codepropertygraph.cpgloading

import java.lang.{Boolean => JBoolean, Integer => JInt, Long => JLong}
import java.util.{NoSuchElementException, Collection => JCollection, Iterator => JIterator}

import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.proto.cpg.Cpg.CpgStruct.{Edge, Node}
import io.shiftleft.proto.cpg.Cpg.PropertyValue
import org.apache.commons.configuration.Configuration
import org.apache.logging.log4j.{LogManager, Logger}
import org.apache.tinkerpop.gremlin.structure.{T, Vertex}
import io.shiftleft.overflowdb.OdbGraph

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import io.shiftleft.overflowdb.OdbConfig
import io.shiftleft.utils.StringInterner

object ProtoToCpg {
  val logger: Logger = LogManager.getLogger(classOf[ProtoToCpg])

  def addProperties(keyValues: ArrayBuffer[AnyRef],
                    name: String,
                    value: PropertyValue,
                    interner: StringInterner = StringInterner.noop): Unit = {
    import io.shiftleft.proto.cpg.Cpg.PropertyValue.ValueCase._
    value.getValueCase match {
      case INT_VALUE =>
        keyValues += interner.intern(name)
        keyValues += (value.getIntValue: JInt)
      case STRING_VALUE =>
        keyValues += interner.intern(name)
        keyValues += interner.intern(value.getStringValue)
      case BOOL_VALUE =>
        keyValues += interner.intern(name)
        keyValues += (value.getBoolValue: JBoolean)
      case STRING_LIST =>
        value.getStringList.getValuesList.asScala.foreach { elem: String =>
          keyValues += interner.intern(name)
          keyValues += interner.intern(elem)
        }
      case VALUE_NOT_SET => ()
      case _ =>
        throw new RuntimeException("Error: unsupported property case: " + value.getValueCase.name)
    }
  }
}

class ProtoToCpg(overflowConfig: OdbConfig = OdbConfig.withoutOverflow) {
  import ProtoToCpg._
  private val nodeFilter = new NodeFilter
  private val odbGraph =
    OdbGraph.open(overflowConfig,
                  io.shiftleft.codepropertygraph.generated.nodes.Factories.AllAsJava,
                  io.shiftleft.codepropertygraph.generated.edges.Factories.AllAsJava)
  private val interner: StringInterner = StringInterner.makeStrongInterner()

  def addNodes(nodes: JCollection[Node]): Unit =
    addNodes(nodes.asScala)

  def addNodes(nodes: Iterable[Node]): Unit =
    nodes
      .filter(nodeFilter.filterNode)
      .foreach(addVertexToOdbGraph)

  private def addVertexToOdbGraph(node: Node) = {
    try odbGraph.addVertex(nodeToArray(node): _*)
    catch {
      case e: Exception =>
        throw new RuntimeException("Failed to insert a vertex. proto:\n" + node, e)
    }
  }

  def addEdges(protoEdges: JCollection[Edge]): Unit =
    addEdges(protoEdges.asScala)

  def addEdges(protoEdges: Iterable[Edge]): Unit = {
    for (edge <- protoEdges) {
      val srcVertex = findVertexById(edge, edge.getSrc)
      val dstVertex = findVertexById(edge, edge.getDst)
      val properties: Seq[Edge.Property] = edge.getPropertyList.asScala
      val keyValues = new ArrayBuffer[AnyRef](2 * properties.size)
      for (edgeProperty <- properties) {
        addProperties(keyValues, edgeProperty.getName.name(), edgeProperty.getValue, interner)
      }
      try {
        srcVertex.addEdge(edge.getType.name(), dstVertex, keyValues.toArray: _*)
      } catch {
        case e: IllegalArgumentException =>
          val context = "label=" + edge.getType.name +
            ", srcNodeId=" + edge.getSrc +
            ", dstNodeId=" + edge.getDst +
            ", srcVertex=" + srcVertex +
            ", dstVertex=" + dstVertex
          logger.warn("Failed to insert an edge. context: " + context, e)
      }
    }
  }

  def build(): Cpg = new Cpg(odbGraph)

  private def findVertexById(edge: Edge, nodeId: Long): Vertex = {
    if (nodeId == -1)
      throw new IllegalArgumentException(
        "edge " + edge + " has illegal src|dst node. something seems wrong with the cpg")
    val vertices: JIterator[Vertex] = odbGraph.vertices(nodeId: JLong)
    if (!vertices.hasNext)
      throw new NoSuchElementException(
        "Couldn't find src|dst node " + nodeId + " for edge " + edge + " of type " + edge.getType.name)
    vertices.next
  }

  private def nodeToArray(node: Node): Array[AnyRef] = {
    val props = node.getPropertyList
    val keyValues = new ArrayBuffer[AnyRef](4 + (2 * props.size()))
    keyValues += T.id
    keyValues += (node.getKey: JLong)
    keyValues += T.label
    keyValues += node.getType.name()
    for (prop <- props.asScala) {
      addProperties(keyValues, prop.getName.name, prop.getValue, interner)
    }
    keyValues.toArray
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy