From 08ff6725aad15b0471268944e1772a9936314014 Mon Sep 17 00:00:00 2001 From: James Hamilton Date: Wed, 12 Jan 2022 21:30:29 +0100 Subject: [PATCH] Refactor to check files individually --- .../kotlin/eu/jameshamilton/log4shell/Main.kt | 147 ++++-------------- src/test/kotlin/Log4ShellDetectorTest.kt | 58 +------ 2 files changed, 30 insertions(+), 175 deletions(-) diff --git a/src/main/kotlin/eu/jameshamilton/log4shell/Main.kt b/src/main/kotlin/eu/jameshamilton/log4shell/Main.kt index 8ba36fb..8323466 100644 --- a/src/main/kotlin/eu/jameshamilton/log4shell/Main.kt +++ b/src/main/kotlin/eu/jameshamilton/log4shell/Main.kt @@ -2,38 +2,26 @@ package eu.jameshamilton.log4shell import proguard.classfile.AccessConstants.PRIVATE import proguard.classfile.ClassPool -import proguard.classfile.Clazz -import proguard.classfile.Member -import proguard.classfile.util.ClassReferenceInitializer -import proguard.classfile.util.ClassSubHierarchyInitializer -import proguard.classfile.util.ClassSuperHierarchyInitializer -import proguard.classfile.util.WarningPrinter import proguard.classfile.visitor.AllMemberVisitor import proguard.classfile.visitor.ClassCounter import proguard.classfile.visitor.ClassNameFilter -import proguard.classfile.visitor.ClassVisitor +import proguard.classfile.visitor.ClassPoolFiller import proguard.classfile.visitor.ConstructorMethodFilter import proguard.classfile.visitor.MemberAccessFilter import proguard.classfile.visitor.MemberCounter import proguard.classfile.visitor.MemberDescriptorFilter -import proguard.classfile.visitor.MemberVisitor import proguard.classfile.visitor.MethodFilter -import proguard.classfile.visitor.MultiMemberVisitor -import proguard.io.DataEntry +import proguard.io.ClassReader import proguard.io.DataEntryNameFilter import proguard.io.DataEntryReader import proguard.io.Dex2JarReader import proguard.io.DirectorySource -import proguard.io.FileDataEntry import proguard.io.FilteredDataEntryReader import proguard.io.JarReader import proguard.io.NameFilteredDataEntryReader -import proguard.io.ZipFileDataEntry import proguard.util.ExtensionMatcher import proguard.util.OrMatcher import java.io.File -import proguard.classfile.visitor.ClassPoolFiller as ProGuardClassPoolFiller -import proguard.io.ClassReader as ProGuardClassReader fun main(args: Array) { if (args.isEmpty()) { @@ -41,46 +29,29 @@ fun main(args: Array) { return } - val input = File(args.first()) - val programClassPool = readInput(input) + val vulnerableFiles = when (val input = File(args.first())) { + input -> input.walk() + .filter { it.isFile && it.extension in listOf("jar", "war", "dex", "apk", "aar", "class", "zip") } + .map { if (check(it)) it else null }.filterNotNull().toList() + else -> if (check(input)) listOf(input) else emptyList() + } - check(programClassPool) { locations -> + if (vulnerableFiles.isEmpty()) { + println("No log4shell found") + } else { println( """ - |WARNING: log4j < 2.15.0 vulnerable to CVE-2021-44228 found in: - |${locations.joinToString(separator = "\n") { "\t- $it" }} - | - |For more information see: https://logging.apache.org/log4j/2.x/security.html - """.trimMargin() + |WARNING: log4j < 2.15.0 vulnerable to CVE-2021-44228 found in: + | + |${vulnerableFiles.joinToString(separator = "") { "\t- ${it.name}\n" }} + |For more information see: https://logging.apache.org/log4j/2.x/security.html + """.trimMargin() ) } } -fun check(programClassPool: ClassPool, onDetected: (Set) -> Unit) = check( - programClassPool, - object : MemberVisitor { - @Suppress("UNCHECKED_CAST") - private fun processingInfoToLocation(clazz: Clazz): Set = when (clazz.processingInfo) { - is DataEntry -> with(clazz.processingInfo as DataEntry) { - setOf(this.parent?.originalName ?: this.originalName) - } - is Set<*> -> (clazz.processingInfo as Set).map { - when (it) { - is FileDataEntry -> it.file.absolutePath - is ZipFileDataEntry -> it.parent.originalName - else -> it.originalName - } - }.toSortedSet() - else -> setOf("unknown") - } - - override fun visitAnyMember(clazz: Clazz, member: Member) { - onDetected(processingInfoToLocation(clazz)) - } - } -) - -fun check(programClassPool: ClassPool, jndiManagerOldConstructorVisitor: MemberVisitor) { +fun check(file: File): Boolean = check(readInput(file)) +fun check(programClassPool: ClassPool): Boolean { val jndiLookupCounter = ClassCounter() val jndiManagerOldConstructorCounter = MemberCounter() @@ -109,10 +80,7 @@ fun check(programClassPool: ClassPool, jndiManagerOldConstructorVisitor: MemberV /* requiredSetAccessFlags = */ PRIVATE, /* requiredUnsetAccessFlags = */ 0, MemberDescriptorFilter( "(Ljava/lang/String;Ljavax/naming/Context;)V", - MultiMemberVisitor( - jndiManagerOldConstructorCounter, - jndiManagerOldConstructorVisitor - ) + jndiManagerOldConstructorCounter ) ) ) @@ -122,13 +90,7 @@ fun check(programClassPool: ClassPool, jndiManagerOldConstructorVisitor: MemberV ) } - if (jndiManagerOldConstructorCounter.count == 0 && jndiLookupCounter.count > 0) { - println( - """ - JndiLookup class found, but no pre-2.15.0 constructor found. - """.trimIndent() - ) - } + return jndiLookupCounter.count > 0 && jndiManagerOldConstructorCounter.count > 0 } private fun readInput(inputFile: File): ClassPool { @@ -136,12 +98,12 @@ private fun readInput(inputFile: File): ClassPool { var classReader: DataEntryReader = NameFilteredDataEntryReader( "**.class", ClassReader( - isLibrary = false, - skipNonPublicLibraryClasses = false, - skipNonPublicLibraryClassMembers = false, - ignoreStackMapAttributes = false, - warningPrinter = null, - classVisitor = ProcessingInfoMergingClassPoolFiller(programClassPool) + /* isLibrary = */ false, + /* skipNonPublicLibraryClasses = */ false, + /* skipNonPublicLibraryClassMembers = */ false, + /* ignoreStackMapAttributes = */ false, + /* warningPrinter = */ null, + ClassPoolFiller(programClassPool) ) ) @@ -186,60 +148,3 @@ private fun readInput(inputFile: File): ClassPool { return programClassPool } - -class ClassReader( - isLibrary: Boolean, - skipNonPublicLibraryClasses: Boolean, - skipNonPublicLibraryClassMembers: Boolean, - ignoreStackMapAttributes: Boolean, - warningPrinter: WarningPrinter?, - private val classVisitor: ClassVisitor -) : DataEntryReader { - private lateinit var currentDataEntry: DataEntry - - private val proguardClassReader = ProGuardClassReader( - isLibrary, - skipNonPublicLibraryClasses, - skipNonPublicLibraryClassMembers, - ignoreStackMapAttributes, - warningPrinter - ) { - it.processingInfo = currentDataEntry - it.accept(classVisitor) - } - - override fun read(dataEntry: DataEntry) { - currentDataEntry = dataEntry - proguardClassReader.read(dataEntry) - } -} - -class ProcessingInfoMergingClassPoolFiller(private val classPool: ClassPool) : ProGuardClassPoolFiller(classPool) { - @Suppress("UNCHECKED_CAST") - override fun visitAnyClass(clazz: Clazz) { - when (val existingClazz = classPool.getClass(clazz.name)) { - is Clazz -> { - val oldProcessingInfo = existingClazz.processingInfo as MutableSet - val newProcessingInfo = clazz.processingInfo - oldProcessingInfo.add(newProcessingInfo as DataEntry) - } - else -> classPool.addClass(clazz.apply { processingInfo = mutableSetOf(processingInfo as DataEntry) }) - } - } -} - -@Suppress("unused") // not required in this app -fun initialize(programClassPool: ClassPool, libraryClassPool: ClassPool) { - val classReferenceInitializer = ClassReferenceInitializer(programClassPool, libraryClassPool) - val classSuperHierarchyInitializer = ClassSuperHierarchyInitializer(programClassPool, libraryClassPool) - val classSubHierarchyInitializer = ClassSubHierarchyInitializer() - - programClassPool.classesAccept(classSuperHierarchyInitializer) - libraryClassPool.classesAccept(classSuperHierarchyInitializer) - - programClassPool.classesAccept(classReferenceInitializer) - libraryClassPool.classesAccept(classReferenceInitializer) - - programClassPool.accept(classSubHierarchyInitializer) - libraryClassPool.accept(classSubHierarchyInitializer) -} diff --git a/src/test/kotlin/Log4ShellDetectorTest.kt b/src/test/kotlin/Log4ShellDetectorTest.kt index d964bbf..2506e1d 100644 --- a/src/test/kotlin/Log4ShellDetectorTest.kt +++ b/src/test/kotlin/Log4ShellDetectorTest.kt @@ -1,19 +1,12 @@ import eu.jameshamilton.log4shell.check import io.kotest.core.spec.style.FunSpec import io.kotest.matchers.shouldBe -import io.mockk.spyk -import io.mockk.verify import proguard.classfile.AccessConstants.PRIVATE import proguard.classfile.AccessConstants.PUBLIC import proguard.classfile.ClassPool -import proguard.classfile.Clazz -import proguard.classfile.Member -import proguard.classfile.ProgramClass -import proguard.classfile.ProgramMember import proguard.classfile.VersionConstants.CLASS_VERSION_1_6 import proguard.classfile.editor.ClassBuilder import proguard.classfile.util.ClassRenamer -import proguard.classfile.visitor.MemberVisitor class Log4ShellDetectorTest : FunSpec({ val jndiLookup = ClassBuilder( @@ -34,34 +27,12 @@ class Log4ShellDetectorTest : FunSpec({ test("Should not detect Log4Shell if JndiLookup is not present") { val programClassPool = ClassPool() - val visitor = spyk(object : MemberVisitor { - override fun visitAnyMember(clazz: Clazz, member: Member) {} - }) - - check(programClassPool, visitor) - - verify(exactly = 0) { - visitor.visitProgramMember( - ofType(ProgramClass::class), - ofType(ProgramMember::class) - ) - } + check(programClassPool) shouldBe false } test("Should detect Log4Shell if JndiLookup and old constructor is present") { val programClassPool = ClassPool(jndiLookup, jndiManager) - val visitor = spyk(object : MemberVisitor { - override fun visitAnyMember(clazz: Clazz, member: Member) {} - }) - - check(programClassPool, visitor) - - verify(exactly = 1) { - visitor.visitProgramMember( - ofType(ProgramClass::class), - ofType(ProgramMember::class) - ) - } + check(programClassPool) shouldBe true } test("Should detect shadowed log4j if JndiLookup and old constructor is present") { @@ -79,33 +50,12 @@ class Log4ShellDetectorTest : FunSpec({ programClassPool.getClass("org/apache/logging/log4j/core/net/JndiManager") .name shouldBe "com/example/shadow/org/apache/logging/log4j/core/net/JndiManager" - val visitor = spyk(object : MemberVisitor { - override fun visitAnyMember(clazz: Clazz, member: Member) {} - }) - check(programClassPool, visitor) - - verify(exactly = 1) { - visitor.visitProgramMember( - ofType(ProgramClass::class), - ofType(ProgramMember::class) - ) - } + check(programClassPool) shouldBe true } test("Should not detect Log4Shell if old constructor is present but JndiLookup is not present") { // Removing JndiLookup is a workaround for log4shell val programClassPool = ClassPool(jndiManager) - val visitor = spyk(object : MemberVisitor { - override fun visitAnyMember(clazz: Clazz, member: Member) {} - }) - - check(programClassPool, visitor) - - verify(exactly = 0) { - visitor.visitProgramMember( - ofType(ProgramClass::class), - ofType(ProgramMember::class) - ) - } + check(programClassPool) shouldBe false } })