Skip to content

Commit

Permalink
[c#] Type Map (joernio#4011)
Browse files Browse the repository at this point in the history
Implemented a pre-pass that will allow one to maximize the type resolution during the AST creator phase.

In C#, it is common to import whole namespaces rather than individual classes, so this type map will allow one to fetch all declared classes from an imported namespace.

Resolves joernio#3856
  • Loading branch information
DavidBakerEffendi authored Dec 23, 2023
1 parent 38ba540 commit fdf2531
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import io.joern.csharpsrc2cpg.parser.DotNetJsonParser
import io.joern.csharpsrc2cpg.passes.AstCreationPass
import io.joern.csharpsrc2cpg.utils.DotNetAstGenRunner
import io.joern.x2cpg.X2Cpg.withNewEmptyCpg
import io.joern.x2cpg.astgen.AstGenRunner.AstGenRunnerResult
import io.joern.x2cpg.passes.callgraph.NaiveCallLinker
import io.joern.x2cpg.passes.frontend.MetaDataPass
import io.joern.x2cpg.utils.{Environment, HashUtil, Report}
Expand All @@ -27,7 +28,8 @@ class CSharpSrc2Cpg extends X2CpgFrontend[Config] {
withNewEmptyCpg(config.outputPath, config) { (cpg, config) =>
File.usingTemporaryDirectory("csharpsrc2cpgOut") { tmpDir =>
val astGenResult = new DotNetAstGenRunner(config).execute(tmpDir)
val astCreators = CSharpSrc2Cpg.processAstGenRunnerResults(astGenResult.parsedFiles, config)
val typeMap = new TypeMap(astGenResult)
val astCreators = CSharpSrc2Cpg.processAstGenRunnerResults(astGenResult.parsedFiles, config, typeMap)

val hash = HashUtil.sha256(astCreators.map(_.parserResult).map(x => Paths.get(x.fullPath)))
new MetaDataPass(cpg, Languages.CSHARPSRC, config.inputPath, Option(hash)).createAndApply()
Expand All @@ -45,7 +47,7 @@ object CSharpSrc2Cpg {

/** Parses the generated AST Gen files in parallel and produces AstCreators from each.
*/
def processAstGenRunnerResults(astFiles: List[String], config: Config): Seq[AstCreator] = {
def processAstGenRunnerResults(astFiles: List[String], config: Config, typeMap: TypeMap): Seq[AstCreator] = {
Await.result(
Future.sequence(
astFiles
Expand All @@ -65,7 +67,7 @@ object CSharpSrc2Cpg {
} else {
SourceFiles.toRelativePath(parserResult.fullPath, config.inputPath)
}
new AstCreator(relativeFileName, parserResult)(config.schemaValidation)
new AstCreator(relativeFileName, parserResult, typeMap)(config.schemaValidation)
}
)
),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package io.joern.csharpsrc2cpg

import io.joern.csharpsrc2cpg.astcreation.AstCreatorHelper
import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.{
ClassDeclaration,
FieldDeclaration,
MethodDeclaration,
NamespaceDeclaration
}
import io.joern.csharpsrc2cpg.parser.{DotNetJsonAst, DotNetJsonParser, DotNetNodeInfo, ParserKeys}
import io.joern.x2cpg.astgen.AstGenRunner.AstGenRunnerResult
import io.joern.x2cpg.datastructures.Stack.Stack
import io.shiftleft.codepropertygraph.generated.nodes.NewNode

import java.nio.file.Paths
import scala.collection.mutable

/** A lookup from namespace name to the classes declared directly inside that namespace, built by a pre-pass over every
  * parsed AST file in the given result.
  *
  * Only namespace declarations that are immediate members of a compilation unit are considered, and only class
  * declarations that are immediate members of such a namespace are recorded. For each class, its method and field
  * names are collected.
  *
  * @param astGenResult
  *   the AST generator output whose `parsedFiles` are read and summarized.
  */
class TypeMap(astGenResult: AstGenRunnerResult) {

  private val namespaceToType: Map[String, Set[CSharpType]] = astGenResult.parsedFiles
    .map { file =>
      val parserResult    = DotNetJsonParser.readFile(Paths.get(file))
      val compilationUnit = AstCreatorHelper.createDotNetNodeInfo(parserResult.json(ParserKeys.AstRoot))
      () => parseCompilationUnit(compilationUnit)
    }
    .map(task => task()) // TODO: To be parallelized with https://github.com/joernio/joern/pull/4009
    .foldLeft(Map.empty[String, Set[CSharpType]])(mergeNamespaces)

  /** Unions two namespace-to-types mappings. A namespace present in only one of the two maps is kept as is; a namespace
    * present in both has its type sets combined. This matters because the same namespace is commonly declared across
    * many files, and every file's contribution must survive the merge.
    */
  private def mergeNamespaces(
    merged: Map[String, Set[CSharpType]],
    incoming: Map[String, Set[CSharpType]]
  ): Map[String, Set[CSharpType]] =
    incoming.foldLeft(merged) { case (acc, (namespace, types)) =>
      acc.updated(namespace, acc.getOrElse(namespace, Set.empty[CSharpType]) ++ types)
    }

  /** For the given namespace, returns the declared classes.
    *
    * @param namespace
    *   the namespace name, as produced by `AstCreatorHelper.nameFromNode` for the namespace declaration.
    * @return
    *   the classes recorded for that namespace, or an empty set if the namespace is unknown.
    */
  def classesIn(namespace: String): Set[CSharpType] = namespaceToType.getOrElse(namespace, Set.empty)

  /** Parses a compilation unit and returns a mapping from all the contained namespaces and the immediate children
    * types.
    */
  private def parseCompilationUnit(cu: DotNetNodeInfo): Map[String, Set[CSharpType]] = {
    cu.json(ParserKeys.Members)
      .arr
      .map(AstCreatorHelper.createDotNetNodeInfo(_))
      .filter { x =>
        x.node match
          case NamespaceDeclaration => true
          case _                    => false
      }
      .map(parseNamespace)
      // The same namespace may be declared more than once in a single file; combine the blocks instead of letting the
      // last one win.
      .groupMapReduce(_._1)(_._2)(_ ++ _)
  }

  /** Extracts the namespace's name together with the class declarations found among its immediate members. */
  private def parseNamespace(namespace: DotNetNodeInfo): (String, Set[CSharpType]) = {
    val namespaceName = AstCreatorHelper.nameFromNode(namespace)
    val classes = namespace
      .json(ParserKeys.Members)
      .arr
      .map(AstCreatorHelper.createDotNetNodeInfo(_))
      .filter { x =>
        x.node match
          case ClassDeclaration => true
          case _                => false
      }
      .map(parseClassDeclaration)
      .toSet
    namespaceName -> classes
  }

  /** Summarizes a class declaration as its name plus the methods and fields found among its immediate members. Other
    * member kinds are ignored.
    */
  private def parseClassDeclaration(classDecl: DotNetNodeInfo): CSharpType = {
    val className = AstCreatorHelper.nameFromNode(classDecl)
    val members = classDecl
      .json(ParserKeys.Members)
      .arr
      .map(AstCreatorHelper.createDotNetNodeInfo(_))
      .flatMap { x =>
        x.node match
          case MethodDeclaration => parseMethodDeclaration(x)
          case FieldDeclaration  => parseFieldDeclaration(x)
          case _                 => List.empty
      }
      .toList
    CSharpType(className, members)
  }

  /** Records a method declaration by name only. */
  private def parseMethodDeclaration(methodDecl: DotNetNodeInfo): List[CSharpMethod] = {
    List(CSharpMethod(AstCreatorHelper.nameFromNode(methodDecl)))
  }

  /** Records one field per variable listed in the field declaration, since a single declaration can introduce several
    * variables.
    */
  private def parseFieldDeclaration(fieldDecl: DotNetNodeInfo): List[CSharpField] = {
    val declarationNode = AstCreatorHelper.createDotNetNodeInfo(fieldDecl.json(ParserKeys.Declaration))
    declarationNode
      .json(ParserKeys.Variables)
      .arr
      .map(AstCreatorHelper.createDotNetNodeInfo(_))
      .map(AstCreatorHelper.nameFromNode)
      .map(CSharpField.apply)
      .toList
  }

}

/** Common shape of the entries recorded by the type map: every entry is identified by a name. */
sealed trait CSharpMember {
/** The name of this member. */
def name: String
}

/** A field, identified by the name of a variable listed in a field declaration. */
case class CSharpField(name: String) extends CSharpMember

/** A method, identified by the name of its method declaration. */
case class CSharpMethod(name: String) extends CSharpMember

/** A class declaration: its name together with the members collected for it. It is itself a [[CSharpMember]]. */
case class CSharpType(name: String, members: List[CSharpMember]) extends CSharpMember
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.joern.csharpsrc2cpg.astcreation

import io.joern.csharpsrc2cpg.TypeMap
import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.*
import io.joern.csharpsrc2cpg.parser.{DotNetNodeInfo, ParserKeys}
import io.joern.x2cpg.astgen.{AstGenNodeBuilder, ParserResult}
Expand All @@ -13,7 +14,7 @@ import ujson.Value
import java.math.BigInteger
import java.security.MessageDigest

class AstCreator(val relativeFileName: String, val parserResult: ParserResult)(implicit
class AstCreator(val relativeFileName: String, val parserResult: ParserResult, val typeMap: TypeMap)(implicit
withSchemaValidation: ValidationMode
) extends AstCreatorBase(relativeFileName)
with AstCreatorHelper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,8 @@ import ujson.Value

trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

protected def createDotNetNodeInfo(json: Value): DotNetNodeInfo = {
val metaData = json(ParserKeys.MetaData)
val ln = metaData(ParserKeys.LineStart).numOpt.map(_.toInt.asInstanceOf[Integer])
val cn = metaData(ParserKeys.ColumnStart).numOpt.map(_.toInt.asInstanceOf[Integer])
val lnEnd = metaData(ParserKeys.LineEnd).numOpt.map(_.toInt.asInstanceOf[Integer])
val cnEnd = metaData(ParserKeys.ColumnEnd).numOpt.map(_.toInt.asInstanceOf[Integer])
val c =
metaData(ParserKeys.Code).strOpt.map(x => x.takeWhile(x => x != '\n' && x != '{')).getOrElse("<empty>").strip()
val node = nodeType(metaData)
DotNetNodeInfo(node, json, c, ln, cn, lnEnd, cnEnd)
}
protected def createDotNetNodeInfo(json: Value): DotNetNodeInfo =
AstCreatorHelper.createDotNetNodeInfo(json, Option(this.relativeFileName))

protected def notHandledYet(node: DotNetNodeInfo): Seq[Ast] = {
val text =
Expand All @@ -33,9 +24,6 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
Seq(Ast(unknownNode(node, node.code)))
}

private def nodeType(node: Value): DotNetParserNode =
DotNetJsonAst.fromString(node(ParserKeys.Kind).str, this.relativeFileName)

protected def astFullName(node: DotNetNodeInfo): String = {
methodAstParentStack.headOption match
case Some(head: NewNamespaceBlock) => s"${head.fullName}.${nameFromNode(node)}"
Expand All @@ -44,27 +32,7 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
case _ => nameFromNode(node)
}

protected def nameFromNode(identifierNode: DotNetNodeInfo): String = {
identifierNode.node match
case IdentifierName | Parameter => nameFromIdentifier(identifierNode)
case QualifiedName => nameFromQualifiedName(identifierNode)
case _: DeclarationExpr => nameFromDeclaration(identifierNode)
case _ => "<empty>"
}

protected def nameFromIdentifier(identifier: DotNetNodeInfo): String = {
identifier.json(ParserKeys.Identifier).obj(ParserKeys.Value).str
}

protected def nameFromDeclaration(node: DotNetNodeInfo): String = {
node.json(ParserKeys.Identifier).obj(ParserKeys.Value).str
}

protected def nameFromQualifiedName(qualifiedName: DotNetNodeInfo): String = {
val rhs = nameFromNode(createDotNetNodeInfo(qualifiedName.json(ParserKeys.Right)))
val lhs = nameFromNode(createDotNetNodeInfo(qualifiedName.json(ParserKeys.Left)))
s"$lhs.$rhs"
}
protected def nameFromNode(identifierNode: DotNetNodeInfo): String = AstCreatorHelper.nameFromNode(identifierNode)

// TODO: Use type map to try resolve full name
protected def nodeTypeFullName(node: DotNetNodeInfo): String = {
Expand All @@ -89,6 +57,55 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As

}

object AstCreatorHelper {

  /** Wraps a raw JSON AST node in a [[DotNetNodeInfo]], reading the positional data, code snippet and node kind from
    * the node's metadata entry.
    *
    * The code snippet is truncated at the first newline or opening brace and then stripped of surrounding whitespace;
    * when the metadata carries no code string, `<empty>` is used.
    *
    * @param json
    *   the json node to convert.
    * @param relativeFileName
    *   optional file name, passed on to the node-kind lookup for debugging purposes.
    * @return
    *   the node info.
    */
  def createDotNetNodeInfo(json: Value, relativeFileName: Option[String] = None): DotNetNodeInfo = {
    // Positions are optional numbers in the metadata and are carried as boxed integers.
    def boxed(position: Value): Option[Integer] = position.numOpt.map(_.toInt.asInstanceOf[Integer])

    val meta        = json(ParserKeys.MetaData)
    val lineStart   = boxed(meta(ParserKeys.LineStart))
    val columnStart = boxed(meta(ParserKeys.ColumnStart))
    val lineEnd     = boxed(meta(ParserKeys.LineEnd))
    val columnEnd   = boxed(meta(ParserKeys.ColumnEnd))
    val snippet = meta(ParserKeys.Code).strOpt match
      case Some(raw) => raw.takeWhile(ch => !(ch == '\n' || ch == '{'))
      case None      => "<empty>"
    val kind = nodeType(meta, relativeFileName)
    DotNetNodeInfo(kind, json, snippet.strip(), lineStart, columnStart, lineEnd, columnEnd)
  }

  /** Resolves the parser node kind named by the metadata's kind entry. */
  private def nodeType(node: Value, relativeFileName: Option[String] = None): DotNetParserNode = {
    val kindName = node(ParserKeys.Kind).str
    DotNetJsonAst.fromString(kindName, relativeFileName)
  }

  /** Determines the name carried by a node. Namespace declarations, identifier-bearing nodes and qualified names are
    * supported; any other node kind yields `<empty>`.
    */
  def nameFromNode(node: DotNetNodeInfo): String = {
    node.node match
      case NamespaceDeclaration => nameFromNamespaceDeclaration(node)
      case IdentifierName       => nameFromIdentifier(node)
      case Parameter            => nameFromIdentifier(node)
      case _: DeclarationExpr   => nameFromIdentifier(node)
      case QualifiedName        => nameFromQualifiedName(node)
      case _                    => "<empty>"
  }

  /** A namespace's name lives in a dedicated child node, which is resolved recursively. */
  private def nameFromNamespaceDeclaration(namespace: DotNetNodeInfo): String =
    nameFromNode(createDotNetNodeInfo(namespace.json(ParserKeys.Name)))

  /** Reads the value of the node's identifier entry. */
  private def nameFromIdentifier(identifier: DotNetNodeInfo): String = {
    val identifierJson = identifier.json(ParserKeys.Identifier)
    identifierJson.obj(ParserKeys.Value).str
  }

  /** Joins the names of the left and right parts of a qualified name with a dot. */
  private def nameFromQualifiedName(qualifiedName: DotNetNodeInfo): String = {
    val rightName = nameFromNode(createDotNetNodeInfo(qualifiedName.json(ParserKeys.Right)))
    val leftName  = nameFromNode(createDotNetNodeInfo(qualifiedName.json(ParserKeys.Left)))
    Seq(leftName, rightName).mkString(".")
  }
}

/** Contains all the C# builtin types, as well as `null` and `void`.
*
* @see
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ import scala.util.Try
trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

protected def astForNamespaceDeclaration(namespace: DotNetNodeInfo): Seq[Ast] = {
val nameNode = createDotNetNodeInfo(namespace.json(ParserKeys.Name))
val fullName = astFullName(nameNode)
val fullName = astFullName(namespace)
val name = fullName.split('.').filterNot(_.isBlank).lastOption.getOrElse(fullName)
val namespaceBlock = NewNamespaceBlock()
.name(name)
.code(code(namespace))
.lineNumber(line(nameNode))
.columnNumber(columnEnd(nameNode))
.lineNumber(line(namespace))
.columnNumber(columnEnd(namespace))
.filename(relativeFileName)
.fullName(fullName)
methodAstParentStack.push(namespaceBlock)
Expand All @@ -30,7 +29,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) {
}

protected def astForClassDeclaration(classDecl: DotNetNodeInfo): Seq[Ast] = {
val name = nameFromIdentifier(classDecl)
val name = nameFromNode(classDecl)
val fullName = astFullName(classDecl)
val typeDecl = typeDeclNode(classDecl, name, fullName, relativeFileName, code(classDecl))
methodAstParentStack.push(typeDecl)
Expand Down Expand Up @@ -86,7 +85,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) {
}

protected def astForMethodDeclaration(methodDecl: DotNetNodeInfo): Seq[Ast] = {
val name = nameFromIdentifier(methodDecl)
val name = nameFromNode(methodDecl)
val params = methodDecl
.json(ParserKeys.ParameterList)
.obj(ParserKeys.Parameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ object DotNetJsonAst {
private val logger = LoggerFactory.getLogger(getClass)
private val QualifiedClassName: String = DotNetJsonAst.getClass.getName

def fromString(nodeName: String, fileName: String): DotNetParserNode = {
def fromString(nodeName: String, fileName: Option[String] = None): DotNetParserNode = {
try {
val clazz = Class.forName(s"$QualifiedClassName${nodeName.stripPrefix("ast.")}$$")
clazz.getField("MODULE$").get(clazz).asInstanceOf[DotNetParserNode]
} catch {
case _: Throwable =>
logger.warn(s"`$nodeName` AST type is not handled. We found this inside '$fileName'")
logger.warn(
s"`$nodeName` AST type is not handled.${fileName.map(x => s" We found this inside '$x'").getOrElse("")}"
)
NotHandledType
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import better.files.File
import io.joern.csharpsrc2cpg.passes.AstCreationPass
import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture
import io.joern.csharpsrc2cpg.utils.DotNetAstGenRunner
import io.joern.csharpsrc2cpg.{CSharpSrc2Cpg, Config}
import io.joern.csharpsrc2cpg.{CSharpSrc2Cpg, Config, TypeMap}
import io.joern.x2cpg.X2Cpg.newEmptyCpg
import io.joern.x2cpg.utils.Report
import io.shiftleft.codepropertygraph.Cpg
import org.scalatest.BeforeAndAfterAll
import io.shiftleft.semanticcpg.language.*
import org.scalatest.BeforeAndAfterAll

class ProjectParseTests extends CSharpCode2CpgFixture with BeforeAndAfterAll {

Expand Down Expand Up @@ -56,7 +56,8 @@ class ProjectParseTests extends CSharpCode2CpgFixture with BeforeAndAfterAll {
val cpg = newEmptyCpg()
val config = Config().withInputPath(projectDir.toString).withOutputPath(tmpDir.toString)
val astGenResult = new DotNetAstGenRunner(config).execute(tmpDir)
val astCreators = CSharpSrc2Cpg.processAstGenRunnerResults(astGenResult.parsedFiles, config)
val typeMap = new TypeMap(astGenResult)
val astCreators = CSharpSrc2Cpg.processAstGenRunnerResults(astGenResult.parsedFiles, config, typeMap)
new AstCreationPass(cpg, astCreators, new Report()).createAndApply()
f(cpg)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class NamespaceTests extends CSharpCode2CpgFixture {
helloWorld.code shouldBe "namespace HelloWorld"
helloWorld.filename shouldBe "Program.cs"
helloWorld.lineNumber shouldBe Some(2)
helloWorld.columnNumber shouldBe Some(20)
helloWorld.columnNumber shouldBe Some(1)
helloWorld.fullName shouldBe "HelloWorld"
}

Expand Down

0 comments on commit fdf2531

Please sign in to comment.