From 75de1d76e313977fddbdd1210b19df80648a5c3a Mon Sep 17 00:00:00 2001 From: Chad Retz Date: Fri, 27 Jul 2018 01:15:48 -0500 Subject: [PATCH] Finish stack walker --- compiler/src/main/kotlin/asmble/ast/Stack.kt | 183 +++++++++++++----- .../src/test/kotlin/asmble/ast/StackTest.kt | 33 ++++ compiler/src/test/kotlin/asmble/io/IoTest.kt | 8 +- 3 files changed, 167 insertions(+), 57 deletions(-) create mode 100644 compiler/src/test/kotlin/asmble/ast/StackTest.kt diff --git a/compiler/src/main/kotlin/asmble/ast/Stack.kt b/compiler/src/main/kotlin/asmble/ast/Stack.kt index 48e7c85..f01e5ce 100644 --- a/compiler/src/main/kotlin/asmble/ast/Stack.kt +++ b/compiler/src/main/kotlin/asmble/ast/Stack.kt @@ -7,20 +7,65 @@ data class Stack( // Null if not tracking the current stack and all pops succeed val current: List? = null, val insnApplies: List = emptyList(), - val strictPop: Boolean = false + val strict: Boolean = false, + val unreachableUntilNextEndCount: Int = 0 ) { - fun next(v: Node.Instr, callFuncType: Node.Type.Func? = null) = insnApply(v) { + fun next(v: Node.Instr, callFuncTypeOverride: Node.Type.Func? = null) = insnApply(v) { + // If we're unreachable, and not an end, we skip and move on + if (unreachableUntilNextEndCount > 0 && v !is Node.Instr.End) { + // If it's a block, we increase it because we'll see another end + return@insnApply if (v is Node.Instr.Args.Type) unreachable(unreachableUntilNextEndCount + 1) else nop() + } when (v) { - is Node.Instr.Unreachable, is Node.Instr.Nop, is Node.Instr.Block, - is Node.Instr.Loop, is Node.Instr.If, is Node.Instr.Else, - is Node.Instr.End, is Node.Instr.Br, is Node.Instr.BrIf, - is Node.Instr.Return -> nop() - is Node.Instr.BrTable -> popI32() - is Node.Instr.Call -> (callFuncType ?: error("Call func type missing")).let { + is Node.Instr.Nop, is Node.Instr.Block, is Node.Instr.Loop -> nop() + is Node.Instr.If, is Node.Instr.BrIf -> popI32() + is Node.Instr.Return -> (func?.type?.ret?.let { pop(it) } ?: nop()) + unreachable(1) + is Node.Instr.Unreachable -> unreachable(1) + is Node.Instr.End, is Node.Instr.Else -> { + // Put back what was before the last block and add the block's type + // Go backwards to find the starting block + var currDepth = 0 + val found = insnApplies.findLast { + when (it.insn) { + is Node.Instr.End -> { currDepth++; false } + is Node.Instr.Args.Type -> if (currDepth > 0) { currDepth--; false } else true + else -> false + } + }?.takeIf { + // When it's else, needs to be if + v !is Node.Instr.Else || it.insn is Node.Instr.If + } + val changes = when { + found != null && found.insn is Node.Instr.Args.Type && + found.stackAtBeginning != null && this != null -> { + // Pop everything from before the block's start, then push if necessary... + // The If block includes an int at the beginning we must not include when subtracting + var preBlockStackSize = found.stackAtBeginning.size + if (found.insn is Node.Instr.If) preBlockStackSize-- + val popped = + if (unreachableUntilNextEndCount > 1) nop() + else (0 until (size - preBlockStackSize)).flatMap { pop() } + // Only push if this is not an else + val pushed = + if (unreachableUntilNextEndCount > 1 || v is Node.Instr.Else) nop() + else (found.insn.type?.let { push(it) } ?: nop()) + popped + pushed + } + strict -> error("Unable to find starting block for end") + else -> nop() + } + if (unreachableUntilNextEndCount > 0) changes + unreachable(unreachableUntilNextEndCount - 1) + else changes + } + is Node.Instr.Br -> unreachable(v.relativeDepth + 1) + is Node.Instr.BrTable -> popI32() + unreachable(1) + is Node.Instr.Call -> (callFuncTypeOverride ?: func(v.index)).let { + if (it == null) error("Call func type missing") it.params.reversed().flatMap { pop(it) } + (it.ret?.let { push(it) } ?: nop()) } - is Node.Instr.CallIndirect -> (callFuncType ?: error("Call func type missing")).let { + is Node.Instr.CallIndirect -> (callFuncTypeOverride ?: mod?.mod?.types?.getOrNull(v.index)).let { + if (it == null) error("Call func type missing") // We add one for the table index popI32() + it.params.reversed().flatMap { pop(it) } + (it.ret?.let { push(it) } ?: nop()) } @@ -35,13 +80,13 @@ data class Stack( is Node.Instr.I32Load16U, is Node.Instr.I32Load16S -> popI32() + pushI32() is Node.Instr.I64Load, is Node.Instr.I64Load8S, is Node.Instr.I64Load8U, is Node.Instr.I64Load16U, is Node.Instr.I64Load16S, is Node.Instr.I64Load32S, is Node.Instr.I64Load32U -> popI32() + pushI64() - is Node.Instr.F32Load -> popI32() + popF32() + is Node.Instr.F32Load -> popI32() + pushF32() is Node.Instr.F64Load -> popI32() + pushF64() is Node.Instr.I32Store, is Node.Instr.I32Store8, is Node.Instr.I32Store16 -> popI32() + popI32() is Node.Instr.I64Store, is Node.Instr.I64Store8, - is Node.Instr.I64Store16, is Node.Instr.I64Store32 -> popI32() + popI64() - is Node.Instr.F32Store -> popI32() + pushF32() - is Node.Instr.F64Store -> popI32() + pushF64() + is Node.Instr.I64Store16, is Node.Instr.I64Store32 -> popI64() + popI32() + is Node.Instr.F32Store -> popF32() + popI32() + is Node.Instr.F64Store -> popF64() + popI32() is Node.Instr.MemorySize -> pushI32() is Node.Instr.MemoryGrow -> popI32() + pushI32() is Node.Instr.I32Const -> pushI32() @@ -60,56 +105,60 @@ data class Stack( is Node.Instr.I64Add, is Node.Instr.I64Sub, is Node.Instr.I64Mul, is Node.Instr.I64DivS, is Node.Instr.I64DivU, is Node.Instr.I64RemS, is Node.Instr.I64RemU, is Node.Instr.I64And, is Node.Instr.I64Or, is Node.Instr.I64Xor, is Node.Instr.I64Shl, is Node.Instr.I64ShrS, - is Node.Instr.I64ShrU, is Node.Instr.I64Rotl, is Node.Instr.I64Rotr, is Node.Instr.I64Eq, - is Node.Instr.I64Ne, is Node.Instr.I64LtS, is Node.Instr.I64LeS, is Node.Instr.I64LtU, - is Node.Instr.I64LeU, is Node.Instr.I64GtS, is Node.Instr.I64GeS, is Node.Instr.I64GtU, - is Node.Instr.I64GeU -> popI64() + popI64() + pushI64() - is Node.Instr.I64Clz, is Node.Instr.I64Ctz, is Node.Instr.I64Popcnt, - is Node.Instr.I64Eqz -> popI64() + pushI64() + is Node.Instr.I64ShrU, is Node.Instr.I64Rotl, is Node.Instr.I64Rotr -> popI64() + popI64() + pushI64() + is Node.Instr.I64Eq, is Node.Instr.I64Ne, is Node.Instr.I64LtS, is Node.Instr.I64LeS, + is Node.Instr.I64LtU, is Node.Instr.I64LeU, is Node.Instr.I64GtS, + is Node.Instr.I64GeS, is Node.Instr.I64GtU, is Node.Instr.I64GeU -> popI64() + popI64() + pushI32() + is Node.Instr.I64Clz, is Node.Instr.I64Ctz, is Node.Instr.I64Popcnt -> popI64() + pushI64() + is Node.Instr.I64Eqz -> popI64() + pushI32() is Node.Instr.F32Add, is Node.Instr.F32Sub, is Node.Instr.F32Mul, is Node.Instr.F32Div, - is Node.Instr.F32Eq, is Node.Instr.F32Ne, is Node.Instr.F32Lt, is Node.Instr.F32Le, - is Node.Instr.F32Gt, is Node.Instr.F32Ge, is Node.Instr.F32Min, - is Node.Instr.F32Max, is Node.Instr.F32CopySign -> popF32() + popF32() + pushF32() + is Node.Instr.F32Min, is Node.Instr.F32Max, is Node.Instr.F32CopySign -> popF32() + popF32() + pushF32() + is Node.Instr.F32Eq, is Node.Instr.F32Ne, is Node.Instr.F32Lt, is Node.Instr.F32Le, + is Node.Instr.F32Gt, is Node.Instr.F32Ge -> popF32() + popF32() + pushI32() is Node.Instr.F32Abs, is Node.Instr.F32Neg, is Node.Instr.F32Ceil, is Node.Instr.F32Floor, is Node.Instr.F32Trunc, is Node.Instr.F32Nearest, is Node.Instr.F32Sqrt -> popF32() + pushF32() is Node.Instr.F64Add, is Node.Instr.F64Sub, is Node.Instr.F64Mul, is Node.Instr.F64Div, - is Node.Instr.F64Eq, is Node.Instr.F64Ne, is Node.Instr.F64Lt, is Node.Instr.F64Le, - is Node.Instr.F64Gt, is Node.Instr.F64Ge, is Node.Instr.F64Min, - is Node.Instr.F64Max, is Node.Instr.F64CopySign -> popF64() + popF64() + pushF64() + is Node.Instr.F64Min, is Node.Instr.F64Max, is Node.Instr.F64CopySign -> popF64() + popF64() + pushF64() + is Node.Instr.F64Eq, is Node.Instr.F64Ne, is Node.Instr.F64Lt, is Node.Instr.F64Le, + is Node.Instr.F64Gt, is Node.Instr.F64Ge -> popF64() + popF64() + pushI32() is Node.Instr.F64Abs, is Node.Instr.F64Neg, is Node.Instr.F64Ceil, is Node.Instr.F64Floor, - is Node.Instr.F64Trunc, is Node.Instr.F64Nearest, is Node.Instr.F64Sqrt -> popF64() + popF64() - is Node.Instr.I32WrapI64 -> popI32() + pushI64() + is Node.Instr.F64Trunc, is Node.Instr.F64Nearest, is Node.Instr.F64Sqrt -> popF64() + pushF64() + is Node.Instr.I32WrapI64 -> popI64() + pushI32() is Node.Instr.I32TruncSF32, is Node.Instr.I32TruncUF32, - is Node.Instr.I32ReinterpretF32 -> popI32() + pushF32() - is Node.Instr.I32TruncSF64, is Node.Instr.I32TruncUF64 -> popI32() + pushF64() - is Node.Instr.I64ExtendSI32, is Node.Instr.I64ExtendUI32 -> popI64() + pushI32() - is Node.Instr.I64TruncSF32, is Node.Instr.I64TruncUF32 -> popI64() + pushF32() + is Node.Instr.I32ReinterpretF32 -> popF32() + pushI32() + is Node.Instr.I32TruncSF64, is Node.Instr.I32TruncUF64 -> popF64() + pushI32() + is Node.Instr.I64ExtendSI32, is Node.Instr.I64ExtendUI32 -> popI32() + pushI64() + is Node.Instr.I64TruncSF32, is Node.Instr.I64TruncUF32 -> popF32() + pushI64() is Node.Instr.I64TruncSF64, is Node.Instr.I64TruncUF64, - is Node.Instr.I64ReinterpretF64 -> popI64() + pushF64() + is Node.Instr.I64ReinterpretF64 -> popF64() + pushI64() is Node.Instr.F32ConvertSI32, is Node.Instr.F32ConvertUI32, - is Node.Instr.F32ReinterpretI32 -> popF32() + pushI32() - is Node.Instr.F32ConvertSI64, is Node.Instr.F32ConvertUI64, - is Node.Instr.F64ReinterpretI64 -> popF32() + pushI64() - is Node.Instr.F32DemoteF64 -> popF32() + pushF64() - is Node.Instr.F64ConvertSI32, is Node.Instr.F64ConvertUI32 -> popF64() + pushI32() - is Node.Instr.F64ConvertSI64, is Node.Instr.F64ConvertUI64 -> popF64() + pushI64() - is Node.Instr.F64PromoteF32 -> popF64() + pushF32() + is Node.Instr.F32ReinterpretI32 -> popI32() + pushF32() + is Node.Instr.F32ConvertSI64, is Node.Instr.F32ConvertUI64 -> popI64() + pushF32() + is Node.Instr.F32DemoteF64 -> popF64() + pushF32() + is Node.Instr.F64ConvertSI32, is Node.Instr.F64ConvertUI32 -> popI32() + pushF64() + is Node.Instr.F64ConvertSI64, is Node.Instr.F64ConvertUI64, + is Node.Instr.F64ReinterpretI64 -> popI64() + pushF64() + is Node.Instr.F64PromoteF32 -> popF32() + pushF64() } } - protected fun insnApply(v: Node.Instr, fn: MutableList?.() -> List): Stack { + protected fun insnApply(v: Node.Instr, fn: MutableList?.() -> List): Stack { val mutStack = current?.toMutableList() - val stackChanges = mutStack.fn() + val applyResp = mutStack.fn() + val newUnreachable = (applyResp.find { it is Unreachable } as? Unreachable)?.untilEndCount return copy( current = mutStack, insnApplies = insnApplies + InsnApply( insn = v, stackAtBeginning = current, - stackChanges = stackChanges - ) + stackChanges = applyResp.mapNotNull { it as? StackChange }, + unreachableUntilEndCount = newUnreachable ?: unreachableUntilNextEndCount + ), + unreachableUntilNextEndCount = newUnreachable ?: unreachableUntilNextEndCount ) } + protected fun unreachable(untilEndCount: Int) = listOf(Unreachable(untilEndCount)) protected fun local(index: Int) = func?.let { it.type.params.getOrNull(index) ?: it.locals.getOrNull(index - it.type.params.size) } @@ -122,12 +171,16 @@ data class Stack( it.mod.funcs.getOrNull(index - it.importFuncs.size)?.type } - protected fun MutableList?.nop() = emptyList() + protected fun nop() = emptyList() protected fun MutableList?.popType(expecting: Node.Type.Value? = null) = this?.takeIf { - it.isNotEmpty().also { require(!strictPop || it) } + it.isNotEmpty().also { + require(!strict || it) { "Expected $expecting got empty" } + } }?.let { - removeAt(size - 1).takeIf { (it == expecting).also { require(!strictPop || it) } } + removeAt(size - 1).takeIf { actual -> (expecting == null || actual == expecting).also { + require(!strict || it) { "Expected $expecting got $actual" } + } } } ?: expecting protected fun MutableList?.pop(expecting: Node.Type.Value? = null) = listOf(StackChange(popType(expecting), true)) @@ -137,22 +190,30 @@ data class Stack( protected fun MutableList?.popF32() = pop(Node.Type.Value.F32) protected fun MutableList?.popF64() = pop(Node.Type.Value.F64) - protected fun push(type: Node.Type.Value? = null) = listOf(StackChange(type, false)) - protected fun pushI32() = push(Node.Type.Value.I32) - protected fun pushI64() = push(Node.Type.Value.I64) - protected fun pushF32() = push(Node.Type.Value.F32) - protected fun pushF64() = push(Node.Type.Value.F64) + protected fun MutableList?.push(type: Node.Type.Value? = null) = + listOf(StackChange(type, false)).also { if (this != null && type != null) add(type) } + protected fun MutableList?.pushI32() = push(Node.Type.Value.I32) + protected fun MutableList?.pushI64() = push(Node.Type.Value.I64) + protected fun MutableList?.pushF32() = push(Node.Type.Value.F32) + protected fun MutableList?.pushF64() = push(Node.Type.Value.F64) data class InsnApply( val insn: Node.Instr, val stackAtBeginning: List?, - val stackChanges: List + val stackChanges: List, + val unreachableUntilEndCount: Int ) + protected interface InsnApplyResponse + data class StackChange( val type: Node.Type.Value?, val pop: Boolean - ) + ) : InsnApplyResponse + + data class Unreachable( + val untilEndCount: Int + ) : InsnApplyResponse class CachedModule(val mod: Node.Module) { val importFuncs by lazy { mod.imports.mapNotNull { it.kind as? Node.Import.Kind.Func } } @@ -160,6 +221,22 @@ data class Stack( } companion object { + fun walkStrict(mod: Node.Module, func: Node.Func, afterInsn: ((Stack, Node.Instr) -> Unit)? = null) = + func.instructions.fold(Stack( + mod = CachedModule(mod), + func = func, + current = emptyList(), + strict = true + )) { stack, insn -> stack.next(insn).also { afterInsn?.invoke(it, insn) } }.also { stack -> + // We expect to be in an unreachable state at the end or have the single return value on the stack + if (stack.unreachableUntilNextEndCount == 0) { + val expectedStack = (func.type.ret?.let { listOf(it) } ?: emptyList()) + require(expectedStack == stack.current) { + "Expected end to be $expectedStack, got ${stack.current}" + } + } + } + fun stackChanges(v: Node.Instr, callFuncType: Node.Type.Func? = null) = Stack().next(v, callFuncType).insnApplies.last().stackChanges fun stackDiff(v: Node.Instr, callFuncType: Node.Type.Func? = null) = diff --git a/compiler/src/test/kotlin/asmble/ast/StackTest.kt b/compiler/src/test/kotlin/asmble/ast/StackTest.kt new file mode 100644 index 0000000..4fa711c --- /dev/null +++ b/compiler/src/test/kotlin/asmble/ast/StackTest.kt @@ -0,0 +1,33 @@ +package asmble.ast + +import asmble.SpecTestUnit +import asmble.TestBase +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +@RunWith(Parameterized::class) +class StackTest(val unit: SpecTestUnit) : TestBase() { + + @Test + fun testStack() { + // If it's not a module expecting an error, we'll try to walk the stack on each function + unit.script.commands.mapNotNull { it as? Script.Cmd.Module }.forEach { mod -> + mod.module.funcs.filter { it.instructions.isNotEmpty() }.forEach { func -> + debug { "Func: ${func.type}" } + var indexCount = 0 + Stack.walkStrict(mod.module, func) { stack, insn -> + debug { " After $insn (next: ${func.instructions.getOrNull(++indexCount)}, " + + "unreach depth: ${stack.unreachableUntilNextEndCount})" } + debug { " " + stack.current } + } + } + } + } + + companion object { + // Only tests that shouldn't fail + @JvmStatic @Parameterized.Parameters(name = "{0}") + fun data() = SpecTestUnit.allUnits.filterNot { it.shouldFail }//.filter { it.name == "loop" } + } +} \ No newline at end of file diff --git a/compiler/src/test/kotlin/asmble/io/IoTest.kt b/compiler/src/test/kotlin/asmble/io/IoTest.kt index c74a856..091d087 100644 --- a/compiler/src/test/kotlin/asmble/io/IoTest.kt +++ b/compiler/src/test/kotlin/asmble/io/IoTest.kt @@ -1,6 +1,7 @@ package asmble.io import asmble.SpecTestUnit +import asmble.TestBase import asmble.ast.Node import asmble.ast.Script import asmble.util.Logger @@ -13,12 +14,10 @@ import java.io.ByteArrayOutputStream import kotlin.test.assertEquals @RunWith(Parameterized::class) -class IoTest(val unit: SpecTestUnit) : Logger by Logger.Print(Logger.Level.INFO) { +class IoTest(val unit: SpecTestUnit) : TestBase() { @Test fun testIo() { - // Ignore things that are supposed to fail - if (unit.shouldFail) return // Go from the AST to binary then back to AST then back to binary and confirm values are as expected val ast1 = unit.script.commands.mapNotNull { (it as? Script.Cmd.Module)?.module?.also { trace { "AST from script:\n" + SExprToStr.fromSExpr(AstToSExpr.fromModule(it)) } @@ -46,7 +45,8 @@ class IoTest(val unit: SpecTestUnit) : Logger by Logger.Print(Logger.Level.INFO) } companion object { + // Only tests that shouldn't fail @JvmStatic @Parameterized.Parameters(name = "{0}") - fun data() = SpecTestUnit.allUnits + fun data() = SpecTestUnit.allUnits.filterNot { it.shouldFail } } }