Skip to content

Commit

Permalink
Finish stack walker
Browse files Browse the repository at this point in the history
  • Loading branch information
cretz committed Jul 27, 2018
1 parent 67e914d commit 75de1d7
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 57 deletions.
183 changes: 130 additions & 53 deletions compiler/src/main/kotlin/asmble/ast/Stack.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,65 @@ data class Stack(
// Null if not tracking the current stack and all pops succeed
val current: List<Node.Type.Value>? = null,
val insnApplies: List<InsnApply> = emptyList(),
val strictPop: Boolean = false
val strict: Boolean = false,
val unreachableUntilNextEndCount: Int = 0
) {

fun next(v: Node.Instr, callFuncType: Node.Type.Func? = null) = insnApply(v) {
fun next(v: Node.Instr, callFuncTypeOverride: Node.Type.Func? = null) = insnApply(v) {
// If we're unreachable, and not an end, we skip and move on
if (unreachableUntilNextEndCount > 0 && v !is Node.Instr.End) {
// If it's a block, we increase it because we'll see another end
return@insnApply if (v is Node.Instr.Args.Type) unreachable(unreachableUntilNextEndCount + 1) else nop()
}
when (v) {
is Node.Instr.Unreachable, is Node.Instr.Nop, is Node.Instr.Block,
is Node.Instr.Loop, is Node.Instr.If, is Node.Instr.Else,
is Node.Instr.End, is Node.Instr.Br, is Node.Instr.BrIf,
is Node.Instr.Return -> nop()
is Node.Instr.BrTable -> popI32()
is Node.Instr.Call -> (callFuncType ?: error("Call func type missing")).let {
is Node.Instr.Nop, is Node.Instr.Block, is Node.Instr.Loop -> nop()
is Node.Instr.If, is Node.Instr.BrIf -> popI32()
is Node.Instr.Return -> (func?.type?.ret?.let { pop(it) } ?: nop()) + unreachable(1)
is Node.Instr.Unreachable -> unreachable(1)
is Node.Instr.End, is Node.Instr.Else -> {
// Put back what was before the last block and add the block's type
// Go backwards to find the starting block
var currDepth = 0
val found = insnApplies.findLast {
when (it.insn) {
is Node.Instr.End -> { currDepth++; false }
is Node.Instr.Args.Type -> if (currDepth > 0) { currDepth--; false } else true
else -> false
}
}?.takeIf {
// When it's else, needs to be if
v !is Node.Instr.Else || it.insn is Node.Instr.If
}
val changes = when {
found != null && found.insn is Node.Instr.Args.Type &&
found.stackAtBeginning != null && this != null -> {
// Pop everything from before the block's start, then push if necessary...
// The If block includes an int at the beginning we must not include when subtracting
var preBlockStackSize = found.stackAtBeginning.size
if (found.insn is Node.Instr.If) preBlockStackSize--
val popped =
if (unreachableUntilNextEndCount > 1) nop()
else (0 until (size - preBlockStackSize)).flatMap { pop() }
// Only push if this is not an else
val pushed =
if (unreachableUntilNextEndCount > 1 || v is Node.Instr.Else) nop()
else (found.insn.type?.let { push(it) } ?: nop())
popped + pushed
}
strict -> error("Unable to find starting block for end")
else -> nop()
}
if (unreachableUntilNextEndCount > 0) changes + unreachable(unreachableUntilNextEndCount - 1)
else changes
}
is Node.Instr.Br -> unreachable(v.relativeDepth + 1)
is Node.Instr.BrTable -> popI32() + unreachable(1)
is Node.Instr.Call -> (callFuncTypeOverride ?: func(v.index)).let {
if (it == null) error("Call func type missing")
it.params.reversed().flatMap { pop(it) } + (it.ret?.let { push(it) } ?: nop())
}
is Node.Instr.CallIndirect -> (callFuncType ?: error("Call func type missing")).let {
is Node.Instr.CallIndirect -> (callFuncTypeOverride ?: mod?.mod?.types?.getOrNull(v.index)).let {
if (it == null) error("Call func type missing")
// We add one for the table index
popI32() + it.params.reversed().flatMap { pop(it) } + (it.ret?.let { push(it) } ?: nop())
}
Expand All @@ -35,13 +80,13 @@ data class Stack(
is Node.Instr.I32Load16U, is Node.Instr.I32Load16S -> popI32() + pushI32()
is Node.Instr.I64Load, is Node.Instr.I64Load8S, is Node.Instr.I64Load8U, is Node.Instr.I64Load16U,
is Node.Instr.I64Load16S, is Node.Instr.I64Load32S, is Node.Instr.I64Load32U -> popI32() + pushI64()
is Node.Instr.F32Load -> popI32() + popF32()
is Node.Instr.F32Load -> popI32() + pushF32()
is Node.Instr.F64Load -> popI32() + pushF64()
is Node.Instr.I32Store, is Node.Instr.I32Store8, is Node.Instr.I32Store16 -> popI32() + popI32()
is Node.Instr.I64Store, is Node.Instr.I64Store8,
is Node.Instr.I64Store16, is Node.Instr.I64Store32 -> popI32() + popI64()
is Node.Instr.F32Store -> popI32() + pushF32()
is Node.Instr.F64Store -> popI32() + pushF64()
is Node.Instr.I64Store16, is Node.Instr.I64Store32 -> popI64() + popI32()
is Node.Instr.F32Store -> popF32() + popI32()
is Node.Instr.F64Store -> popF64() + popI32()
is Node.Instr.MemorySize -> pushI32()
is Node.Instr.MemoryGrow -> popI32() + pushI32()
is Node.Instr.I32Const -> pushI32()
Expand All @@ -60,56 +105,60 @@ data class Stack(
is Node.Instr.I64Add, is Node.Instr.I64Sub, is Node.Instr.I64Mul, is Node.Instr.I64DivS,
is Node.Instr.I64DivU, is Node.Instr.I64RemS, is Node.Instr.I64RemU, is Node.Instr.I64And,
is Node.Instr.I64Or, is Node.Instr.I64Xor, is Node.Instr.I64Shl, is Node.Instr.I64ShrS,
is Node.Instr.I64ShrU, is Node.Instr.I64Rotl, is Node.Instr.I64Rotr, is Node.Instr.I64Eq,
is Node.Instr.I64Ne, is Node.Instr.I64LtS, is Node.Instr.I64LeS, is Node.Instr.I64LtU,
is Node.Instr.I64LeU, is Node.Instr.I64GtS, is Node.Instr.I64GeS, is Node.Instr.I64GtU,
is Node.Instr.I64GeU -> popI64() + popI64() + pushI64()
is Node.Instr.I64Clz, is Node.Instr.I64Ctz, is Node.Instr.I64Popcnt,
is Node.Instr.I64Eqz -> popI64() + pushI64()
is Node.Instr.I64ShrU, is Node.Instr.I64Rotl, is Node.Instr.I64Rotr -> popI64() + popI64() + pushI64()
is Node.Instr.I64Eq, is Node.Instr.I64Ne, is Node.Instr.I64LtS, is Node.Instr.I64LeS,
is Node.Instr.I64LtU, is Node.Instr.I64LeU, is Node.Instr.I64GtS,
is Node.Instr.I64GeS, is Node.Instr.I64GtU, is Node.Instr.I64GeU -> popI64() + popI64() + pushI32()
is Node.Instr.I64Clz, is Node.Instr.I64Ctz, is Node.Instr.I64Popcnt -> popI64() + pushI64()
is Node.Instr.I64Eqz -> popI64() + pushI32()
is Node.Instr.F32Add, is Node.Instr.F32Sub, is Node.Instr.F32Mul, is Node.Instr.F32Div,
is Node.Instr.F32Eq, is Node.Instr.F32Ne, is Node.Instr.F32Lt, is Node.Instr.F32Le,
is Node.Instr.F32Gt, is Node.Instr.F32Ge, is Node.Instr.F32Min,
is Node.Instr.F32Max, is Node.Instr.F32CopySign -> popF32() + popF32() + pushF32()
is Node.Instr.F32Min, is Node.Instr.F32Max, is Node.Instr.F32CopySign -> popF32() + popF32() + pushF32()
is Node.Instr.F32Eq, is Node.Instr.F32Ne, is Node.Instr.F32Lt, is Node.Instr.F32Le,
is Node.Instr.F32Gt, is Node.Instr.F32Ge -> popF32() + popF32() + pushI32()
is Node.Instr.F32Abs, is Node.Instr.F32Neg, is Node.Instr.F32Ceil, is Node.Instr.F32Floor,
is Node.Instr.F32Trunc, is Node.Instr.F32Nearest, is Node.Instr.F32Sqrt -> popF32() + pushF32()
is Node.Instr.F64Add, is Node.Instr.F64Sub, is Node.Instr.F64Mul, is Node.Instr.F64Div,
is Node.Instr.F64Eq, is Node.Instr.F64Ne, is Node.Instr.F64Lt, is Node.Instr.F64Le,
is Node.Instr.F64Gt, is Node.Instr.F64Ge, is Node.Instr.F64Min,
is Node.Instr.F64Max, is Node.Instr.F64CopySign -> popF64() + popF64() + pushF64()
is Node.Instr.F64Min, is Node.Instr.F64Max, is Node.Instr.F64CopySign -> popF64() + popF64() + pushF64()
is Node.Instr.F64Eq, is Node.Instr.F64Ne, is Node.Instr.F64Lt, is Node.Instr.F64Le,
is Node.Instr.F64Gt, is Node.Instr.F64Ge -> popF64() + popF64() + pushI32()
is Node.Instr.F64Abs, is Node.Instr.F64Neg, is Node.Instr.F64Ceil, is Node.Instr.F64Floor,
is Node.Instr.F64Trunc, is Node.Instr.F64Nearest, is Node.Instr.F64Sqrt -> popF64() + popF64()
is Node.Instr.I32WrapI64 -> popI32() + pushI64()
is Node.Instr.F64Trunc, is Node.Instr.F64Nearest, is Node.Instr.F64Sqrt -> popF64() + pushF64()
is Node.Instr.I32WrapI64 -> popI64() + pushI32()
is Node.Instr.I32TruncSF32, is Node.Instr.I32TruncUF32,
is Node.Instr.I32ReinterpretF32 -> popI32() + pushF32()
is Node.Instr.I32TruncSF64, is Node.Instr.I32TruncUF64 -> popI32() + pushF64()
is Node.Instr.I64ExtendSI32, is Node.Instr.I64ExtendUI32 -> popI64() + pushI32()
is Node.Instr.I64TruncSF32, is Node.Instr.I64TruncUF32 -> popI64() + pushF32()
is Node.Instr.I32ReinterpretF32 -> popF32() + pushI32()
is Node.Instr.I32TruncSF64, is Node.Instr.I32TruncUF64 -> popF64() + pushI32()
is Node.Instr.I64ExtendSI32, is Node.Instr.I64ExtendUI32 -> popI32() + pushI64()
is Node.Instr.I64TruncSF32, is Node.Instr.I64TruncUF32 -> popF32() + pushI64()
is Node.Instr.I64TruncSF64, is Node.Instr.I64TruncUF64,
is Node.Instr.I64ReinterpretF64 -> popI64() + pushF64()
is Node.Instr.I64ReinterpretF64 -> popF64() + pushI64()
is Node.Instr.F32ConvertSI32, is Node.Instr.F32ConvertUI32,
is Node.Instr.F32ReinterpretI32 -> popF32() + pushI32()
is Node.Instr.F32ConvertSI64, is Node.Instr.F32ConvertUI64,
is Node.Instr.F64ReinterpretI64 -> popF32() + pushI64()
is Node.Instr.F32DemoteF64 -> popF32() + pushF64()
is Node.Instr.F64ConvertSI32, is Node.Instr.F64ConvertUI32 -> popF64() + pushI32()
is Node.Instr.F64ConvertSI64, is Node.Instr.F64ConvertUI64 -> popF64() + pushI64()
is Node.Instr.F64PromoteF32 -> popF64() + pushF32()
is Node.Instr.F32ReinterpretI32 -> popI32() + pushF32()
is Node.Instr.F32ConvertSI64, is Node.Instr.F32ConvertUI64 -> popI64() + pushF32()
is Node.Instr.F32DemoteF64 -> popF64() + pushF32()
is Node.Instr.F64ConvertSI32, is Node.Instr.F64ConvertUI32 -> popI32() + pushF64()
is Node.Instr.F64ConvertSI64, is Node.Instr.F64ConvertUI64,
is Node.Instr.F64ReinterpretI64 -> popI64() + pushF64()
is Node.Instr.F64PromoteF32 -> popF32() + pushF64()
}
}

protected fun insnApply(v: Node.Instr, fn: MutableList<Node.Type.Value>?.() -> List<StackChange>): Stack {
protected fun insnApply(v: Node.Instr, fn: MutableList<Node.Type.Value>?.() -> List<InsnApplyResponse>): Stack {
val mutStack = current?.toMutableList()
val stackChanges = mutStack.fn()
val applyResp = mutStack.fn()
val newUnreachable = (applyResp.find { it is Unreachable } as? Unreachable)?.untilEndCount
return copy(
current = mutStack,
insnApplies = insnApplies + InsnApply(
insn = v,
stackAtBeginning = current,
stackChanges = stackChanges
)
stackChanges = applyResp.mapNotNull { it as? StackChange },
unreachableUntilEndCount = newUnreachable ?: unreachableUntilNextEndCount
),
unreachableUntilNextEndCount = newUnreachable ?: unreachableUntilNextEndCount
)
}

protected fun unreachable(untilEndCount: Int) = listOf(Unreachable(untilEndCount))
protected fun local(index: Int) = func?.let {
it.type.params.getOrNull(index) ?: it.locals.getOrNull(index - it.type.params.size)
}
Expand All @@ -122,12 +171,16 @@ data class Stack(
it.mod.funcs.getOrNull(index - it.importFuncs.size)?.type
}

protected fun MutableList<Node.Type.Value>?.nop() = emptyList<StackChange>()
protected fun nop() = emptyList<StackChange>()
protected fun MutableList<Node.Type.Value>?.popType(expecting: Node.Type.Value? = null) =
this?.takeIf {
it.isNotEmpty().also { require(!strictPop || it) }
it.isNotEmpty().also {
require(!strict || it) { "Expected $expecting got empty" }
}
}?.let {
removeAt(size - 1).takeIf { (it == expecting).also { require(!strictPop || it) } }
removeAt(size - 1).takeIf { actual -> (expecting == null || actual == expecting).also {
require(!strict || it) { "Expected $expecting got $actual" }
} }
} ?: expecting
protected fun MutableList<Node.Type.Value>?.pop(expecting: Node.Type.Value? = null) =
listOf(StackChange(popType(expecting), true))
Expand All @@ -137,29 +190,53 @@ data class Stack(
protected fun MutableList<Node.Type.Value>?.popF32() = pop(Node.Type.Value.F32)
protected fun MutableList<Node.Type.Value>?.popF64() = pop(Node.Type.Value.F64)

protected fun push(type: Node.Type.Value? = null) = listOf(StackChange(type, false))
protected fun pushI32() = push(Node.Type.Value.I32)
protected fun pushI64() = push(Node.Type.Value.I64)
protected fun pushF32() = push(Node.Type.Value.F32)
protected fun pushF64() = push(Node.Type.Value.F64)
protected fun MutableList<Node.Type.Value>?.push(type: Node.Type.Value? = null) =
listOf(StackChange(type, false)).also { if (this != null && type != null) add(type) }
protected fun MutableList<Node.Type.Value>?.pushI32() = push(Node.Type.Value.I32)
protected fun MutableList<Node.Type.Value>?.pushI64() = push(Node.Type.Value.I64)
protected fun MutableList<Node.Type.Value>?.pushF32() = push(Node.Type.Value.F32)
protected fun MutableList<Node.Type.Value>?.pushF64() = push(Node.Type.Value.F64)

data class InsnApply(
val insn: Node.Instr,
val stackAtBeginning: List<Node.Type.Value>?,
val stackChanges: List<StackChange>
val stackChanges: List<StackChange>,
val unreachableUntilEndCount: Int
)

protected interface InsnApplyResponse

data class StackChange(
val type: Node.Type.Value?,
val pop: Boolean
)
) : InsnApplyResponse

data class Unreachable(
val untilEndCount: Int
) : InsnApplyResponse

class CachedModule(val mod: Node.Module) {
val importFuncs by lazy { mod.imports.mapNotNull { it.kind as? Node.Import.Kind.Func } }
val importGlobals by lazy { mod.imports.mapNotNull { it.kind as? Node.Import.Kind.Global } }
}

companion object {
fun walkStrict(mod: Node.Module, func: Node.Func, afterInsn: ((Stack, Node.Instr) -> Unit)? = null) =
func.instructions.fold(Stack(
mod = CachedModule(mod),
func = func,
current = emptyList(),
strict = true
)) { stack, insn -> stack.next(insn).also { afterInsn?.invoke(it, insn) } }.also { stack ->
// We expect to be in an unreachable state at the end or have the single return value on the stack
if (stack.unreachableUntilNextEndCount == 0) {
val expectedStack = (func.type.ret?.let { listOf(it) } ?: emptyList())
require(expectedStack == stack.current) {
"Expected end to be $expectedStack, got ${stack.current}"
}
}
}

fun stackChanges(v: Node.Instr, callFuncType: Node.Type.Func? = null) =
Stack().next(v, callFuncType).insnApplies.last().stackChanges
fun stackDiff(v: Node.Instr, callFuncType: Node.Type.Func? = null) =
Expand Down
33 changes: 33 additions & 0 deletions compiler/src/test/kotlin/asmble/ast/StackTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package asmble.ast

import asmble.SpecTestUnit
import asmble.TestBase
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

@RunWith(Parameterized::class)
class StackTest(val unit: SpecTestUnit) : TestBase() {

@Test
fun testStack() {
// If it's not a module expecting an error, we'll try to walk the stack on each function
unit.script.commands.mapNotNull { it as? Script.Cmd.Module }.forEach { mod ->
mod.module.funcs.filter { it.instructions.isNotEmpty() }.forEach { func ->
debug { "Func: ${func.type}" }
var indexCount = 0
Stack.walkStrict(mod.module, func) { stack, insn ->
debug { " After $insn (next: ${func.instructions.getOrNull(++indexCount)}, " +
"unreach depth: ${stack.unreachableUntilNextEndCount})" }
debug { " " + stack.current }
}
}
}
}

companion object {
// Only tests that shouldn't fail
@JvmStatic @Parameterized.Parameters(name = "{0}")
fun data() = SpecTestUnit.allUnits.filterNot { it.shouldFail }//.filter { it.name == "loop" }
}
}
8 changes: 4 additions & 4 deletions compiler/src/test/kotlin/asmble/io/IoTest.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package asmble.io

import asmble.SpecTestUnit
import asmble.TestBase
import asmble.ast.Node
import asmble.ast.Script
import asmble.util.Logger
Expand All @@ -13,12 +14,10 @@ import java.io.ByteArrayOutputStream
import kotlin.test.assertEquals

@RunWith(Parameterized::class)
class IoTest(val unit: SpecTestUnit) : Logger by Logger.Print(Logger.Level.INFO) {
class IoTest(val unit: SpecTestUnit) : TestBase() {

@Test
fun testIo() {
// Ignore things that are supposed to fail
if (unit.shouldFail) return
// Go from the AST to binary then back to AST then back to binary and confirm values are as expected
val ast1 = unit.script.commands.mapNotNull { (it as? Script.Cmd.Module)?.module?.also {
trace { "AST from script:\n" + SExprToStr.fromSExpr(AstToSExpr.fromModule(it)) }
Expand Down Expand Up @@ -46,7 +45,8 @@ class IoTest(val unit: SpecTestUnit) : Logger by Logger.Print(Logger.Level.INFO)
}

companion object {
// Only tests that shouldn't fail
@JvmStatic @Parameterized.Parameters(name = "{0}")
fun data() = SpecTestUnit.allUnits
fun data() = SpecTestUnit.allUnits.filterNot { it.shouldFail }
}
}

0 comments on commit 75de1d7

Please sign in to comment.