Skip to content

Commit

Permalink
Refactor to check files individually
Browse files Browse the repository at this point in the history
  • Loading branch information
mrjameshamilton committed Jan 12, 2022
1 parent 6f912de commit 08ff672
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 175 deletions.
147 changes: 26 additions & 121 deletions src/main/kotlin/eu/jameshamilton/log4shell/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,85 +2,56 @@ package eu.jameshamilton.log4shell

import proguard.classfile.AccessConstants.PRIVATE
import proguard.classfile.ClassPool
import proguard.classfile.Clazz
import proguard.classfile.Member
import proguard.classfile.util.ClassReferenceInitializer
import proguard.classfile.util.ClassSubHierarchyInitializer
import proguard.classfile.util.ClassSuperHierarchyInitializer
import proguard.classfile.util.WarningPrinter
import proguard.classfile.visitor.AllMemberVisitor
import proguard.classfile.visitor.ClassCounter
import proguard.classfile.visitor.ClassNameFilter
import proguard.classfile.visitor.ClassVisitor
import proguard.classfile.visitor.ClassPoolFiller
import proguard.classfile.visitor.ConstructorMethodFilter
import proguard.classfile.visitor.MemberAccessFilter
import proguard.classfile.visitor.MemberCounter
import proguard.classfile.visitor.MemberDescriptorFilter
import proguard.classfile.visitor.MemberVisitor
import proguard.classfile.visitor.MethodFilter
import proguard.classfile.visitor.MultiMemberVisitor
import proguard.io.DataEntry
import proguard.io.ClassReader
import proguard.io.DataEntryNameFilter
import proguard.io.DataEntryReader
import proguard.io.Dex2JarReader
import proguard.io.DirectorySource
import proguard.io.FileDataEntry
import proguard.io.FilteredDataEntryReader
import proguard.io.JarReader
import proguard.io.NameFilteredDataEntryReader
import proguard.io.ZipFileDataEntry
import proguard.util.ExtensionMatcher
import proguard.util.OrMatcher
import java.io.File
import proguard.classfile.visitor.ClassPoolFiller as ProGuardClassPoolFiller
import proguard.io.ClassReader as ProGuardClassReader

fun main(args: Array<String>) {
if (args.isEmpty()) {
println("Usage: log4shell-detector <jar-file>")
return
}

val input = File(args.first())
val programClassPool = readInput(input)
val vulnerableFiles = when (val input = File(args.first())) {
input -> input.walk()
.filter { it.isFile && it.extension in listOf("jar", "war", "dex", "apk", "aar", "class", "zip") }
.map { if (check(it)) it else null }.filterNotNull().toList()
else -> if (check(input)) listOf(input) else emptyList()
}

check(programClassPool) { locations ->
if (vulnerableFiles.isEmpty()) {
println("No log4shell found")
} else {
println(
"""
|WARNING: log4j < 2.15.0 vulnerable to CVE-2021-44228 found in:
|${locations.joinToString(separator = "\n") { "\t- $it" }}
|
|For more information see: https://logging.apache.org/log4j/2.x/security.html
""".trimMargin()
|WARNING: log4j < 2.15.0 vulnerable to CVE-2021-44228 found in:
|
|${vulnerableFiles.joinToString(separator = "") { "\t- ${it.name}\n" }}
|For more information see: https://logging.apache.org/log4j/2.x/security.html
""".trimMargin()
)
}
}

fun check(programClassPool: ClassPool, onDetected: (Set<String>) -> Unit) = check(
programClassPool,
object : MemberVisitor {
@Suppress("UNCHECKED_CAST")
private fun processingInfoToLocation(clazz: Clazz): Set<String> = when (clazz.processingInfo) {
is DataEntry -> with(clazz.processingInfo as DataEntry) {
setOf(this.parent?.originalName ?: this.originalName)
}
is Set<*> -> (clazz.processingInfo as Set<DataEntry>).map {
when (it) {
is FileDataEntry -> it.file.absolutePath
is ZipFileDataEntry -> it.parent.originalName
else -> it.originalName
}
}.toSortedSet()
else -> setOf("unknown")
}

override fun visitAnyMember(clazz: Clazz, member: Member) {
onDetected(processingInfoToLocation(clazz))
}
}
)

fun check(programClassPool: ClassPool, jndiManagerOldConstructorVisitor: MemberVisitor) {
fun check(file: File): Boolean = check(readInput(file))
fun check(programClassPool: ClassPool): Boolean {
val jndiLookupCounter = ClassCounter()
val jndiManagerOldConstructorCounter = MemberCounter()

Expand Down Expand Up @@ -109,10 +80,7 @@ fun check(programClassPool: ClassPool, jndiManagerOldConstructorVisitor: MemberV
/* requiredSetAccessFlags = */ PRIVATE, /* requiredUnsetAccessFlags = */ 0,
MemberDescriptorFilter(
"(Ljava/lang/String;Ljavax/naming/Context;)V",
MultiMemberVisitor(
jndiManagerOldConstructorCounter,
jndiManagerOldConstructorVisitor
)
jndiManagerOldConstructorCounter
)
)
)
Expand All @@ -122,26 +90,20 @@ fun check(programClassPool: ClassPool, jndiManagerOldConstructorVisitor: MemberV
)
}

if (jndiManagerOldConstructorCounter.count == 0 && jndiLookupCounter.count > 0) {
println(
"""
JndiLookup class found, but no pre-2.15.0 constructor found.
""".trimIndent()
)
}
return jndiLookupCounter.count > 0 && jndiManagerOldConstructorCounter.count > 0
}

private fun readInput(inputFile: File): ClassPool {
val programClassPool = ClassPool()
var classReader: DataEntryReader = NameFilteredDataEntryReader(
"**.class",
ClassReader(
isLibrary = false,
skipNonPublicLibraryClasses = false,
skipNonPublicLibraryClassMembers = false,
ignoreStackMapAttributes = false,
warningPrinter = null,
classVisitor = ProcessingInfoMergingClassPoolFiller(programClassPool)
/* isLibrary = */ false,
/* skipNonPublicLibraryClasses = */ false,
/* skipNonPublicLibraryClassMembers = */ false,
/* ignoreStackMapAttributes = */ false,
/* warningPrinter = */ null,
ClassPoolFiller(programClassPool)
)
)

Expand Down Expand Up @@ -186,60 +148,3 @@ private fun readInput(inputFile: File): ClassPool {

return programClassPool
}

class ClassReader(
isLibrary: Boolean,
skipNonPublicLibraryClasses: Boolean,
skipNonPublicLibraryClassMembers: Boolean,
ignoreStackMapAttributes: Boolean,
warningPrinter: WarningPrinter?,
private val classVisitor: ClassVisitor
) : DataEntryReader {
private lateinit var currentDataEntry: DataEntry

private val proguardClassReader = ProGuardClassReader(
isLibrary,
skipNonPublicLibraryClasses,
skipNonPublicLibraryClassMembers,
ignoreStackMapAttributes,
warningPrinter
) {
it.processingInfo = currentDataEntry
it.accept(classVisitor)
}

override fun read(dataEntry: DataEntry) {
currentDataEntry = dataEntry
proguardClassReader.read(dataEntry)
}
}

class ProcessingInfoMergingClassPoolFiller(private val classPool: ClassPool) : ProGuardClassPoolFiller(classPool) {
@Suppress("UNCHECKED_CAST")
override fun visitAnyClass(clazz: Clazz) {
when (val existingClazz = classPool.getClass(clazz.name)) {
is Clazz -> {
val oldProcessingInfo = existingClazz.processingInfo as MutableSet<DataEntry>
val newProcessingInfo = clazz.processingInfo
oldProcessingInfo.add(newProcessingInfo as DataEntry)
}
else -> classPool.addClass(clazz.apply { processingInfo = mutableSetOf(processingInfo as DataEntry) })
}
}
}

@Suppress("unused") // not required in this app
fun initialize(programClassPool: ClassPool, libraryClassPool: ClassPool) {
val classReferenceInitializer = ClassReferenceInitializer(programClassPool, libraryClassPool)
val classSuperHierarchyInitializer = ClassSuperHierarchyInitializer(programClassPool, libraryClassPool)
val classSubHierarchyInitializer = ClassSubHierarchyInitializer()

programClassPool.classesAccept(classSuperHierarchyInitializer)
libraryClassPool.classesAccept(classSuperHierarchyInitializer)

programClassPool.classesAccept(classReferenceInitializer)
libraryClassPool.classesAccept(classReferenceInitializer)

programClassPool.accept(classSubHierarchyInitializer)
libraryClassPool.accept(classSubHierarchyInitializer)
}
58 changes: 4 additions & 54 deletions src/test/kotlin/Log4ShellDetectorTest.kt
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
import eu.jameshamilton.log4shell.check
import io.kotest.core.spec.style.FunSpec
import io.kotest.matchers.shouldBe
import io.mockk.spyk
import io.mockk.verify
import proguard.classfile.AccessConstants.PRIVATE
import proguard.classfile.AccessConstants.PUBLIC
import proguard.classfile.ClassPool
import proguard.classfile.Clazz
import proguard.classfile.Member
import proguard.classfile.ProgramClass
import proguard.classfile.ProgramMember
import proguard.classfile.VersionConstants.CLASS_VERSION_1_6
import proguard.classfile.editor.ClassBuilder
import proguard.classfile.util.ClassRenamer
import proguard.classfile.visitor.MemberVisitor

class Log4ShellDetectorTest : FunSpec({
val jndiLookup = ClassBuilder(
Expand All @@ -34,34 +27,12 @@ class Log4ShellDetectorTest : FunSpec({

test("Should not detect Log4Shell if JndiLookup is not present") {
val programClassPool = ClassPool()
val visitor = spyk(object : MemberVisitor {
override fun visitAnyMember(clazz: Clazz, member: Member) {}
})

check(programClassPool, visitor)

verify(exactly = 0) {
visitor.visitProgramMember(
ofType(ProgramClass::class),
ofType(ProgramMember::class)
)
}
check(programClassPool) shouldBe false
}

test("Should detect Log4Shell if JndiLookup and old constructor is present") {
val programClassPool = ClassPool(jndiLookup, jndiManager)
val visitor = spyk(object : MemberVisitor {
override fun visitAnyMember(clazz: Clazz, member: Member) {}
})

check(programClassPool, visitor)

verify(exactly = 1) {
visitor.visitProgramMember(
ofType(ProgramClass::class),
ofType(ProgramMember::class)
)
}
check(programClassPool) shouldBe true
}

test("Should detect shadowed log4j if JndiLookup and old constructor is present") {
Expand All @@ -79,33 +50,12 @@ class Log4ShellDetectorTest : FunSpec({
programClassPool.getClass("org/apache/logging/log4j/core/net/JndiManager")
.name shouldBe "com/example/shadow/org/apache/logging/log4j/core/net/JndiManager"

val visitor = spyk(object : MemberVisitor {
override fun visitAnyMember(clazz: Clazz, member: Member) {}
})
check(programClassPool, visitor)

verify(exactly = 1) {
visitor.visitProgramMember(
ofType(ProgramClass::class),
ofType(ProgramMember::class)
)
}
check(programClassPool) shouldBe true
}

test("Should not detect Log4Shell if old constructor is present but JndiLookup is not present") {
// Removing JndiLookup is a workaround for log4shell
val programClassPool = ClassPool(jndiManager)
val visitor = spyk(object : MemberVisitor {
override fun visitAnyMember(clazz: Clazz, member: Member) {}
})

check(programClassPool, visitor)

verify(exactly = 0) {
visitor.visitProgramMember(
ofType(ProgramClass::class),
ofType(ProgramMember::class)
)
}
check(programClassPool) shouldBe false
}
})

0 comments on commit 08ff672

Please sign in to comment.