Skip to content

Commit

Permalink
[jimple2cpg] - Enable Decompilation with CFR (joernio#4214)
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreiDreyer authored Feb 21, 2024
1 parent 66852fb commit 389c1f0
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 11 deletions.
7 changes: 4 additions & 3 deletions joern-cli/frontends/jimple2cpg/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ name := "jimple2cpg"
dependsOn(Projects.dataflowengineoss % "compile->compile;test->test", Projects.x2cpg % "compile->compile;test->test")

libraryDependencies ++= Seq(
"io.shiftleft" %% "codepropertygraph" % Versions.cpg,
"org.soot-oss" % "soot" % "4.4.1",
"org.scalatest" %% "scalatest" % Versions.scalatest % Test
"io.shiftleft" %% "codepropertygraph" % Versions.cpg,
"org.soot-oss" % "soot" % "4.4.1",
"org.scalatest" %% "scalatest" % Versions.scalatest % Test,
"org.benf" % "cfr" % "0.152",
)

enablePlugins(JavaAppPackaging, LauncherJarPlugin)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.joern.jimple2cpg

import better.files.File
import io.joern.jimple2cpg.passes.{AstCreationPass, DeclarationRefPass, SootAstCreationPass}
import io.joern.jimple2cpg.util.Decompiler
import io.joern.jimple2cpg.util.ProgramHandlingUtil.{ClassFile, extractClassesInPackageLayout}
import io.joern.x2cpg.X2Cpg.withNewEmptyCpg
import io.joern.x2cpg.X2CpgFrontend
Expand Down Expand Up @@ -107,6 +108,8 @@ class Jimple2Cpg extends X2CpgFrontend[Config] {
}
case _ =>
val classFiles = sootLoad(input, tmpDir, config.recurse, config.depth)
decompileClassFiles(classFiles, !config.disableFileContent)

{ () =>
val astCreator = AstCreationPass(classFiles, cpg, config)
astCreator.createAndApply()
Expand All @@ -125,6 +128,23 @@ class Jimple2Cpg extends X2CpgFrontend[Config] {
DeclarationRefPass(cpg).createAndApply()
}

private def decompileClassFiles(classFiles: List[ClassFile], decompileJava: Boolean): Unit = {
Option.when(decompileJava) {
val decompiler = new Decompiler(classFiles.map(_.file))
val decompiledJava = decompiler.decompile()

classFiles.foreach(x => {
val decompiledJavaSrc = decompiledJava.get(x.fullyQualifiedClassName.get)
decompiledJavaSrc match {
case Some(src) =>
val outputFile = File(s"${x.file.pathAsString.replace(".class", ".java")}")
outputFile.write(src)
case None => // Do Nothing
}
})
}
}

override def createCpg(config: Config): Try[Cpg] =
try {
withNewEmptyCpg(config.outputPath, config: Config) { (cpg, config) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ import soot.{Unit as SUnit, Local as _, *}

import scala.collection.immutable.Seq
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters.CollectionHasAsScala
import scala.util.Try

class AstCreator(protected val filename: String, protected val cls: SootClass, global: Global)(implicit
withSchemaValidation: ValidationMode
) extends AstCreatorBase(filename)
class AstCreator(
protected val filename: String,
protected val cls: SootClass,
global: Global,
fileContent: Option[String] = None
)(implicit withSchemaValidation: ValidationMode)
extends AstCreatorBase(filename)
with AstForDeclarationsCreator
with AstForStatementsCreator
with AstForExpressionsCreator
Expand All @@ -46,6 +48,13 @@ class AstCreator(protected val filename: String, protected val cls: SootClass, g
def createAst(): DiffGraphBuilder = {
val astRoot = astForCompilationUnit(cls)
storeInDiffGraph(astRoot, diffGraph)

if (fileContent.isDefined) {
val fileNode = NewFile().name(filename).order(0)
fileContent.foreach(fileNode.content(_))
diffGraph.addNode(fileNode)
}

diffGraph
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ import io.joern.x2cpg.datastructures.Global
import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.passes.ConcurrentWriterCpgPass
import org.slf4j.LoggerFactory
import better.files.{DefaultCharset, File}
import io.shiftleft.utils.IOUtils
import soot.Scene

import java.nio.charset.StandardCharsets
import scala.util.Try

/** Creates the AST layer from the given class file and stores all types in the given global parameter.
* @param classFiles
* List of class files and their fully qualified class names
Expand All @@ -27,7 +32,21 @@ class AstCreationPass(classFiles: List[ClassFile], cpg: Cpg, config: Config)
try {
val sootClass = Scene.v().loadClassAndSupport(classFile.fullyQualifiedClassName.get)
sootClass.setApplicationClass()
val localDiff = AstCreator(classFile.file.canonicalPath, sootClass, global)(config.schemaValidation).createAst()

val file = File(classFile.file.pathAsString.replace(".class", ".java"))

val fileContent = Option
.when(!config.disableFileContent && file.exists) {
Try(IOUtils.readEntireFile(file.path))
.orElse(Try(file.contentAsString(DefaultCharset)))
.orElse(Try(file.contentAsString(StandardCharsets.ISO_8859_1)))
.toOption
}
.flatten

val localDiff =
AstCreator(classFile.file.canonicalPath, sootClass, global, fileContent = fileContent)(config.schemaValidation)
.createAst()
builder.absorb(localDiff)
} catch {
case e: Exception =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package io.joern.jimple2cpg.util

import better.files.File
import org.benf.cfr.reader.api.OutputSinkFactory.{Sink, SinkClass, SinkType}
import org.benf.cfr.reader.api.SinkReturns.Decompiled
import org.benf.cfr.reader.api.{CfrDriver, OutputSinkFactory}
import org.slf4j.LoggerFactory

import java.util
import java.util.{Collection, Collections}
import scala.collection.mutable
import scala.jdk.CollectionConverters.*

class Decompiler(classFile: List[File]) {

private val logger = LoggerFactory.getLogger(getClass)
private val classToDecompiledSource: mutable.HashMap[String, String] = mutable.HashMap.empty;

/** Decompiles the class files and returns a map of the method name to its source code contents.
*/
def decompile(): mutable.HashMap[String, String] = {
val driver = new CfrDriver.Builder().withOutputSink(outputSink).build()
driver.analyse(SeqHasAsJava(classFile.map(_.pathAsString)).asJava)
classToDecompiledSource
}

private val outputSink: OutputSinkFactory = new OutputSinkFactory() {

override def getSupportedSinks(sinkType: SinkType, collection: util.Collection[SinkClass]): util.List[SinkClass] =
if (sinkType == SinkType.JAVA && collection.contains(SinkClass.DECOMPILED)) {
util.Arrays.asList(SinkClass.DECOMPILED)
} else {
Collections.singletonList(SinkClass.STRING)
}

override def getSink[T](sinkType: SinkType, sinkClass: SinkClass): OutputSinkFactory.Sink[T] = new Sink[T]() {
override def write(s: T): Unit = {
sinkType match
case OutputSinkFactory.SinkType.JAVA =>
s match
case x: Decompiled =>
val className = x.getClassName
val packageName = x.getPackageName
val classFullName = Seq(packageName, className).filterNot(_.isBlank).mkString(".")
logger.debug(s"Decompiled '$classFullName', parsing...")

classToDecompiledSource.put(classFullName, x.getJava)
case _ =>
logger.error(s"Unhandled decompilation type ${s.getClass}")
case OutputSinkFactory.SinkType.PROGRESS =>
val className = s.toString.split(" ").last
logger.debug(s"Decompiling class '$className'")
case OutputSinkFactory.SinkType.EXCEPTION =>
logger.error(s.toString)
case _ => // ignore
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package io.joern.jimple2cpg.querying

import io.joern.jimple2cpg.Config
import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture
import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.semanticcpg.language._

class CodeDumperTests extends JimpleCode2CpgFixture {
private val config = Config().withDisableFileContent(false)

"a Java source code CPG" should {
implicit val finder: NodeExtensionFinder = DefaultNodeExtensionFinder
val cpg: Cpg = code(
"""
|public class Foo {
|
| public void test() {
| var a = 1;
| var b = 2;
| var c = a + b;
| }
|
|}
|""".stripMargin,
"Foo.java"
)
.withConfig(config)
.cpg

"allow one to get decompiled java code in one file" in {
inside(cpg.file.name(".*Foo.class").l) {
case decompiledJava :: Nil =>
decompiledJava.content.linesIterator.map(_.strip).l shouldBe List(
"/*",
"* Decompiled with CFR 0.152.",
"*/",
"public class Foo {",
"public void test() {",
"int a = 1;",
"int b = 2;",
"int c = a + b;",
"}",
"}"
)
case content => fail(s"Expected exactly 1 file")
}
}
}

"Java Source CPG across multiple files" should {
implicit val finder: NodeExtensionFinder = DefaultNodeExtensionFinder
val cpg: Cpg = code(
"""
|package bar;
|public class Foo {
| public void test() {
| var a = 1;
| var b = 2;
| var c = a + b;
| }
|}
|""".stripMargin,
"Foo.java"
).moreCode(
"""
|package bar;
|public class Baz {
| public void bazTest() {
| var fooObj = new Foo();
| var b = 2;
| }
|}
|""".stripMargin,
"Baz.java"
).withConfig(config)
.cpg

"allow one to get java decompiled code for all classes" in {
inside(cpg.file.name(".*Foo.class").l) {
case decompiledJavaFoo :: Nil =>
decompiledJavaFoo.content.linesIterator.map(_.strip).filter(_.nonEmpty).l shouldBe List(
"/*",
"* Decompiled with CFR 0.152.",
"*/",
"package bar;",
"public class Foo {",
"public void test() {",
"int a = 1;",
"int b = 2;",
"int c = a + b;",
"}",
"}"
)

case _ => fail("Expected exactly 1 file")
}

inside(cpg.file.name(".*Baz.class").l) {
case decompiledJavaBaz :: Nil =>
decompiledJavaBaz.content.linesIterator.map(_.strip).filter(_.nonEmpty).l shouldBe List(
"/*",
"* Decompiled with CFR 0.152.",
"*/",
"package bar;",
"import bar.Foo;",
"public class Baz {",
"public void bazTest() {",
"Foo fooObj = new Foo();",
"int b = 2;",
"}",
"}"
)
case _ => fail("Expected exactly 1 file")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ trait Jimple2CpgFrontend extends LanguageFrontend {
override val fileSuffix: String = ".java"

override def execute(sourceCodeFile: File): Cpg = {
implicit val defaultConfig: Config = Config()
new Jimple2Cpg().createCpg(sourceCodeFile.getAbsolutePath).get
val config = getConfig().map(_.asInstanceOf[Config]).getOrElse(Config())
new Jimple2Cpg().createCpg(sourceCodeFile.getAbsolutePath)(config).get
}
}

Expand Down

0 comments on commit 389c1f0

Please sign in to comment.