diff --git a/compiler/src/main/kotlin/asmble/ast/Node.kt b/compiler/src/main/kotlin/asmble/ast/Node.kt index 9704db3..0479266 100644 --- a/compiler/src/main/kotlin/asmble/ast/Node.kt +++ b/compiler/src/main/kotlin/asmble/ast/Node.kt @@ -30,10 +30,18 @@ sealed class Node { sealed class Type : Node() { sealed class Value : Type() { - object I32 : Value() - object I64 : Value() - object F32 : Value() - object F64 : Value() + object I32 : Value() { + override fun toString() = "I32" + } + object I64 : Value() { + override fun toString() = "I64" + } + object F32 : Value() { + override fun toString() = "F32" + } + object F64 : Value() { + override fun toString() = "F64" + } } data class Func( diff --git a/compiler/src/main/kotlin/asmble/cli/Invoke.kt b/compiler/src/main/kotlin/asmble/cli/Invoke.kt index e8b98ff..1f22eec 100644 --- a/compiler/src/main/kotlin/asmble/cli/Invoke.kt +++ b/compiler/src/main/kotlin/asmble/cli/Invoke.kt @@ -40,8 +40,7 @@ open class Invoke : ScriptCommand() { if (args.module == "") ctx.modules.lastOrNull() ?: error("No modules available") else ctx.registrations[args.module] as? Module.Instance ?: error("Unable to find module registered as ${args.module}") - // Just make sure the module is instantiated here... - module.instance(ctx) + module as Module.Compiled // If an export is provided, call it if (args.export != "") args.export.javaIdent.let { javaName -> val method = module.cls.declaredMethods.find { it.name == javaName } ?: @@ -59,7 +58,7 @@ open class Invoke : ScriptCommand() { else -> error("Unrecognized type for param ${index + 1}: $paramType") } } - val result = method.invoke(module.instance(ctx), *params.toTypedArray()) + val result = method.invoke(module.inst, *params.toTypedArray()) if (args.resultToStdout && method.returnType != Void.TYPE) println(result) } } diff --git a/compiler/src/main/kotlin/asmble/cli/ScriptCommand.kt b/compiler/src/main/kotlin/asmble/cli/ScriptCommand.kt index 5e3b4a1..de29ec4 100644 --- a/compiler/src/main/kotlin/asmble/cli/ScriptCommand.kt +++ b/compiler/src/main/kotlin/asmble/cli/ScriptCommand.kt @@ -3,6 +3,7 @@ package asmble.cli import asmble.ast.Script import asmble.compile.jvm.javaIdent import asmble.run.jvm.Module +import asmble.run.jvm.ModuleBuilder import asmble.run.jvm.ScriptContext import java.io.File import java.util.* @@ -45,21 +46,23 @@ abstract class ScriptCommand : Command() { ) fun prepareContext(args: ScriptArgs): ScriptContext { - var ctx = ScriptContext( + val builder = ModuleBuilder.Compiled( packageName = "asmble.temp" + UUID.randomUUID().toString().replace("-", ""), + logger = logger, defaultMaxMemPages = args.defaultMaxMemPages ) + var ctx = ScriptContext(logger = logger, builder = builder) // Compile everything ctx = args.inFiles.foldIndexed(ctx) { index, ctx, inFile -> try { when (inFile.substringAfterLast('.')) { - "class" -> ctx.classLoader.addClass(File(inFile).readBytes()).let { ctx } + "class" -> builder.classLoader.addClass(File(inFile).readBytes()).let { ctx } else -> Translate.inToAst(inFile, inFile.substringAfterLast('.')).let { inAst -> val (mod, name) = (inAst.commands.singleOrNull() as? Script.Cmd.Module) ?: error("Input file must only contain a single module") val className = name?.javaIdent?.capitalize() ?: "Temp" + UUID.randomUUID().toString().replace("-", "") - ctx.withCompiledModule(mod, className, name).let { ctx -> + ctx.withBuiltModule(mod, className, name).let { ctx -> if (name == null && index != args.inFiles.size - 1) logger.warn { "File '$inFile' not last and has no name so will be unused" } if (name == null || args.disableAutoRegister) ctx @@ -71,8 +74,8 @@ abstract class ScriptCommand : Command() { } // Do registrations ctx = args.registrations.fold(ctx) { ctx, (moduleName, className) -> - ctx.withModuleRegistered(moduleName, - Module.Native(Class.forName(className, true, ctx.classLoader).newInstance())) + ctx.withModuleRegistered( + Module.Native(moduleName, Class.forName(className, true, builder.classLoader).newInstance())) } if (args.specTestRegister) ctx = ctx.withHarnessRegistered() return ctx diff --git a/compiler/src/main/kotlin/asmble/compile/jvm/AsmExt.kt b/compiler/src/main/kotlin/asmble/compile/jvm/AsmExt.kt index 5db68be..094c27b 100644 --- a/compiler/src/main/kotlin/asmble/compile/jvm/AsmExt.kt +++ b/compiler/src/main/kotlin/asmble/compile/jvm/AsmExt.kt @@ -56,10 +56,10 @@ val Class<*>.ref: TypeRef get() = TypeRef(this.asmType) val Class<*>.valueType: Node.Type.Value? get() = when (this) { Void.TYPE -> null - Int::class.java -> Node.Type.Value.I32 - Long::class.java -> Node.Type.Value.I64 - Float::class.java -> Node.Type.Value.F32 - Double::class.java -> Node.Type.Value.F64 + Int::class.java, java.lang.Integer::class.java -> Node.Type.Value.I32 + Long::class.java, java.lang.Long::class.java -> Node.Type.Value.I64 + Float::class.java, java.lang.Float::class.java -> Node.Type.Value.F32 + Double::class.java, java.lang.Double::class.java -> Node.Type.Value.F64 else -> error("Unrecognized value type class: $this") } @@ -113,6 +113,15 @@ val Double.const: AbstractInsnNode get() = when (this) { else -> LdcInsnNode(this) } +val Number?.valueType get() = when (this) { + null -> null + is Int -> Node.Type.Value.I32 + is Long-> Node.Type.Value.I64 + is Float -> Node.Type.Value.F32 + is Double -> Node.Type.Value.F64 + else -> error("Unrecognized value type class: $this") +} + val String.const: AbstractInsnNode get() = LdcInsnNode(this) val javaKeywords = setOf("abstract", "assert", "boolean", @@ -177,7 +186,6 @@ fun MethodNode.addInsns(vararg insn: AbstractInsnNode): MethodNode { return this } - fun MethodNode.cloneWithInsnRange(range: IntRange) = MethodNode(access, name, desc, signature, exceptions.toTypedArray()).also { new -> accept(new) diff --git a/compiler/src/main/kotlin/asmble/compile/jvm/AstToAsm.kt b/compiler/src/main/kotlin/asmble/compile/jvm/AstToAsm.kt index 8f94776..fe597d2 100644 --- a/compiler/src/main/kotlin/asmble/compile/jvm/AstToAsm.kt +++ b/compiler/src/main/kotlin/asmble/compile/jvm/AstToAsm.kt @@ -385,7 +385,7 @@ open class AstToAsm { } // Otherwise, it was imported and we can set the elems on the imported one // from the parameter - // TODO: I think this is a security concern and bad practice, may revisit + // TODO: I think this is a security concern and bad practice, may revisit (TODO: consider cloning the array) val importIndex = ctx.importFuncs.size + ctx.importGlobals.sumBy { // Immutable is 1, mutable is 2 if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == false) 1 else 2 diff --git a/compiler/src/main/kotlin/asmble/run/jvm/ExceptionTranslator.kt b/compiler/src/main/kotlin/asmble/run/jvm/ExceptionTranslator.kt index b617f7b..56cf9d3 100644 --- a/compiler/src/main/kotlin/asmble/run/jvm/ExceptionTranslator.kt +++ b/compiler/src/main/kotlin/asmble/run/jvm/ExceptionTranslator.kt @@ -11,7 +11,8 @@ open class ExceptionTranslator { "/ by zero", "BigInteger divide by zero" -> listOf("integer divide by zero") else -> listOf(ex.message!!.decapitalize()) } - is ArrayIndexOutOfBoundsException -> listOf("undefined element", "elements segment does not fit") + is ArrayIndexOutOfBoundsException -> + listOf("out of bounds memory access", "undefined element", "elements segment does not fit") is AsmErr -> ex.asmErrStrings is IndexOutOfBoundsException -> listOf("out of bounds memory access") is MalformedInputException -> listOf("invalid UTF-8 encoding") diff --git a/compiler/src/main/kotlin/asmble/run/jvm/Module.kt b/compiler/src/main/kotlin/asmble/run/jvm/Module.kt index 8247eea..6c11f13 100644 --- a/compiler/src/main/kotlin/asmble/run/jvm/Module.kt +++ b/compiler/src/main/kotlin/asmble/run/jvm/Module.kt @@ -4,72 +4,81 @@ import asmble.annotation.WasmExport import asmble.annotation.WasmExternalKind import asmble.ast.Node import asmble.compile.jvm.Mem +import asmble.compile.jvm.javaIdent import asmble.compile.jvm.ref import java.lang.invoke.MethodHandle import java.lang.invoke.MethodHandles -import java.lang.invoke.MethodType import java.lang.reflect.Constructor import java.lang.reflect.Modifier interface Module { - fun bindMethod( - ctx: ScriptContext, - wasmName: String, - wasmKind: WasmExternalKind, - javaName: String, - type: MethodType - ): MethodHandle? - - data class Composite(val modules: List) : Module { - override fun bindMethod( - ctx: ScriptContext, - wasmName: String, - wasmKind: WasmExternalKind, - javaName: String, - type: MethodType - ) = modules.asSequence().mapNotNull { it.bindMethod(ctx, wasmName, wasmKind, javaName, type) }.singleOrNull() + val name: String? + + fun exportedFunc(field: String): MethodHandle? + fun exportedGlobal(field: String): Pair? + fun exportedMemory(field: String, memClass: Class): T? + fun exportedTable(field: String): Array? + + interface ImportResolver { + fun resolveImportFunc(module: String, field: String, type: Node.Type.Func): MethodHandle + fun resolveImportGlobal( + module: String, + field: String, + type: Node.Type.Global + ): Pair + fun resolveImportMemory(module: String, field: String, type: Node.Type.Memory, memClass: Class): T + fun resolveImportTable(module: String, field: String, type: Node.Type.Table): Array } interface Instance : Module { val cls: Class<*> - // Guaranteed to be the same instance when there is no error - fun instance(ctx: ScriptContext): Any + val inst: Any - override fun bindMethod( - ctx: ScriptContext, + fun bindMethod( wasmName: String, wasmKind: WasmExternalKind, - javaName: String, - type: MethodType + javaName: String = wasmName.javaIdent, + paramCountRequired: Int? = null ) = cls.methods.filter { // @WasmExport match or just javaName match Modifier.isPublic(it.modifiers) && !Modifier.isStatic(it.modifiers) && + (paramCountRequired == null || it.parameterCount == paramCountRequired) && it.getDeclaredAnnotation(WasmExport::class.java).let { ann -> if (ann == null) it.name == javaName else ann.value == wasmName && ann.kind == wasmKind } - }.mapNotNull { - MethodHandles.lookup().unreflect(it).bindTo(instance(ctx)).takeIf { it.type() == type } - }.singleOrNull() - } + }.mapNotNull { MethodHandles.lookup().unreflect(it).bindTo(inst) }.singleOrNull() - data class Native(override val cls: Class<*>, val inst: Any) : Instance { - constructor(inst: Any) : this(inst::class.java, inst) + override fun exportedFunc(field: String) = bindMethod(field, WasmExternalKind.FUNCTION, field.javaIdent) + override fun exportedGlobal(field: String) = + bindMethod(field, WasmExternalKind.GLOBAL, "get" + field.javaIdent.capitalize(), 0)?.let { + it to bindMethod(field, WasmExternalKind.GLOBAL, "set" + field.javaIdent.capitalize(), 1) + } + @SuppressWarnings("UNCHECKED_CAST") + override fun exportedMemory(field: String, memClass: Class) = + bindMethod(field, WasmExternalKind.MEMORY, "get" + field.javaIdent.capitalize(), 0)?. + takeIf { it.type().returnType() == memClass }?.let { it.invokeWithArguments() as? T } + @SuppressWarnings("UNCHECKED_CAST") + override fun exportedTable(field: String) = + bindMethod(field, WasmExternalKind.TABLE, "get" + field.javaIdent.capitalize(), 0)?. + let { it.invokeWithArguments() as? Array } + } - override fun instance(ctx: ScriptContext) = inst + data class Native(override val cls: Class<*>, override val name: String?, override val inst: Any) : Instance { + constructor(name: String?, inst: Any) : this(inst::class.java, name, inst) } class Compiled( val mod: Node.Module, override val cls: Class<*>, - val name: String?, - val mem: Mem + override val name: String?, + val mem: Mem, + imports: ImportResolver, + val defaultMaxMemPages: Int = 1 ) : Instance { - private var inst: Any? = null - override fun instance(ctx: ScriptContext) = - synchronized(this) { inst ?: createInstance(ctx).also { inst = it } } + override val inst = createInstance(imports) - private fun createInstance(ctx: ScriptContext): Any { + private fun createInstance(imports: ImportResolver): Any { // Find the constructor var constructorParams = emptyList() var constructor: Constructor<*>? @@ -79,7 +88,8 @@ interface Module { val memLimit = if (memImport != null) { constructor = cls.declaredConstructors.find { it.parameterTypes.firstOrNull()?.ref == mem.memType } val memImportKind = memImport.kind as Node.Import.Kind.Memory - val memInst = ctx.resolveImportMemory(memImport, memImportKind.type, mem) + val memInst = imports.resolveImportMemory(memImport.module, memImport.field, + memImportKind.type, Class.forName(mem.memType.asm.className)) constructorParams += memInst val (memLimit, memCap) = mem.limitAndCapacity(memInst) if (memLimit < memImportKind.type.limits.initial * Mem.PAGE_SIZE) @@ -101,7 +111,7 @@ interface Module { // If it is not there, find the one w/ the max mem amount val maybeMem = mod.memories.firstOrNull() if (constructor == null) { - val maxMem = Math.max(maybeMem?.limits?.initial ?: 0, ctx.defaultMaxMemPages) + val maxMem = Math.max(maybeMem?.limits?.initial ?: 0, defaultMaxMemPages) constructor = cls.declaredConstructors.find { it.parameterTypes.firstOrNull() == Int::class.java } constructorParams += maxMem * Mem.PAGE_SIZE } @@ -111,14 +121,16 @@ interface Module { // Function imports constructorParams += mod.imports.mapNotNull { - if (it.kind is Node.Import.Kind.Func) ctx.resolveImportFunc(it, mod.types[it.kind.typeIndex]) + if (it.kind is Node.Import.Kind.Func) + imports.resolveImportFunc(it.module, it.field, mod.types[it.kind.typeIndex]) else null } // Global imports val globalImports = mod.imports.flatMap { - if (it.kind is Node.Import.Kind.Global) ctx.resolveImportGlobals(it, it.kind.type) - else emptyList() + if (it.kind is Node.Import.Kind.Global) { + imports.resolveImportGlobal(it.module, it.field, it.kind.type).toList().mapNotNull { it } + } else emptyList() } constructorParams += globalImports @@ -126,7 +138,7 @@ interface Module { val tableImport = mod.imports.find { it.kind is Node.Import.Kind.Table } val tableSize = if (tableImport != null) { val tableImportKind = tableImport.kind as Node.Import.Kind.Table - val table = ctx.resolveImportTable(tableImport, tableImportKind.type) + val table = imports.resolveImportTable(tableImport.module, tableImport.field, tableImportKind.type) if (table.size < tableImportKind.type.limits.initial) throw RunErr.ImportTableTooSmall(tableImportKind.type.limits.initial, table.size) tableImportKind.type.limits.maximum?.let { @@ -164,7 +176,6 @@ interface Module { } // Construct - ctx.debug { "Instantiating $cls using $constructor with params $constructorParams" } return constructor.newInstance(*constructorParams.toTypedArray()) } } diff --git a/compiler/src/main/kotlin/asmble/run/jvm/ModuleBuilder.kt b/compiler/src/main/kotlin/asmble/run/jvm/ModuleBuilder.kt new file mode 100644 index 0000000..13bfb0f --- /dev/null +++ b/compiler/src/main/kotlin/asmble/run/jvm/ModuleBuilder.kt @@ -0,0 +1,65 @@ +package asmble.run.jvm + +import asmble.ast.Node +import asmble.compile.jvm.* +import asmble.util.Logger +import org.objectweb.asm.ClassReader +import org.objectweb.asm.ClassVisitor +import org.objectweb.asm.Opcodes + +interface ModuleBuilder { + fun build(imports: Module.ImportResolver, mod: Node.Module, className: String, name: String?): T + + class Compiled( + val packageName: String = "", + val logger: Logger = Logger.Print(Logger.Level.OFF), + val classLoader: SimpleClassLoader = SimpleClassLoader(Compiled::class.java.classLoader, logger), + val adjustContext: (ClsContext) -> ClsContext = { it }, + val includeBinaryInCompiledClass: Boolean = false, + val defaultMaxMemPages: Int = 1 + ) : ModuleBuilder { + override fun build( + imports: Module.ImportResolver, + mod: Node.Module, + className: String, + name: String? + ): Module.Compiled { + val ctx = ClsContext( + packageName = packageName, + className = className, + mod = mod, + logger = logger, + includeBinary = includeBinaryInCompiledClass + ).let(adjustContext) + AstToAsm.fromModule(ctx) + return Module.Compiled(mod, classLoader.fromBuiltContext(ctx), name, ctx.mem, imports, defaultMaxMemPages) + } + + open class SimpleClassLoader( + parent: ClassLoader, + logger: Logger, + val splitWhenTooLarge: Boolean = true + ) : ClassLoader(parent), Logger by logger { + fun fromBuiltContext(ctx: ClsContext): Class<*> { + trace { "Computing frames for ASM class:\n" + ctx.cls.toAsmString() } + val writer = if (splitWhenTooLarge) AsmToBinary else AsmToBinary.noSplit + return writer.fromClassNode(ctx.cls).let { bytes -> + debug { "ASM class:\n" + bytes.asClassNode().toAsmString() } + val prefix = if (ctx.packageName.isNotEmpty()) ctx.packageName + "." else "" + defineClass("$prefix${ctx.className}", bytes, 0, bytes.size) + } + } + + fun addClass(bytes: ByteArray) { + // Just get the name + var className = "" + ClassReader(bytes).accept(object : ClassVisitor(Opcodes.ASM5) { + override fun visit(a: Int, b: Int, name: String, c: String?, d: String?, e: Array?) { + className = name.replace('/', '.') + } + }, ClassReader.SKIP_CODE) + defineClass(className, bytes, 0, bytes.size) + } + } + } +} \ No newline at end of file diff --git a/compiler/src/main/kotlin/asmble/run/jvm/RunErr.kt b/compiler/src/main/kotlin/asmble/run/jvm/RunErr.kt index aa81251..8f6e15c 100644 --- a/compiler/src/main/kotlin/asmble/run/jvm/RunErr.kt +++ b/compiler/src/main/kotlin/asmble/run/jvm/RunErr.kt @@ -1,6 +1,7 @@ package asmble.run.jvm import asmble.AsmErr +import asmble.ast.Node sealed class RunErr(message: String, cause: Throwable? = null) : RuntimeException(message, cause), AsmErr { @@ -51,16 +52,17 @@ sealed class RunErr(message: String, cause: Throwable? = null) : RuntimeExceptio class ImportNotFound( val module: String, val field: String - ) : RunErr("Cannot find compatible import for $module::$field") { + ) : RunErr("Cannot find import for $module::$field") { override val asmErrString get() = "unknown import" override val asmErrStrings get() = listOf(asmErrString, "incompatible import type") } - class ImportGlobalInvalidMutability( + class ImportIncompatible( val module: String, val field: String, - val expected: Boolean - ) : RunErr("Expected imported global $module::$field to have mutability as ${!expected}") { + val expected: Node.Type, + val actual: Node.Type + ) : RunErr("Import $module::$field expected type $expected, got $actual") { override val asmErrString get() = "incompatible import type" } } \ No newline at end of file diff --git a/compiler/src/main/kotlin/asmble/run/jvm/ScriptContext.kt b/compiler/src/main/kotlin/asmble/run/jvm/ScriptContext.kt index 945d952..63d9cad 100644 --- a/compiler/src/main/kotlin/asmble/run/jvm/ScriptContext.kt +++ b/compiler/src/main/kotlin/asmble/run/jvm/ScriptContext.kt @@ -1,47 +1,35 @@ package asmble.run.jvm -import asmble.annotation.WasmExternalKind import asmble.ast.Node import asmble.ast.Script -import asmble.compile.jvm.* +import asmble.compile.jvm.valueType import asmble.io.AstToSExpr import asmble.io.SExprToStr import asmble.util.Logger import asmble.util.toRawIntBits import asmble.util.toRawLongBits -import org.objectweb.asm.ClassReader -import org.objectweb.asm.ClassVisitor -import org.objectweb.asm.Opcodes import java.io.PrintWriter import java.lang.invoke.MethodHandle -import java.lang.invoke.MethodHandles -import java.lang.invoke.MethodType import java.lang.reflect.InvocationTargetException import java.util.* data class ScriptContext( - val packageName: String, - val modules: List = emptyList(), + val modules: List = emptyList(), val registrations: Map = emptyMap(), val logger: Logger = Logger.Print(Logger.Level.OFF), - val adjustContext: (ClsContext) -> ClsContext = { it }, - val classLoader: SimpleClassLoader = - ScriptContext.SimpleClassLoader(ScriptContext::class.java.classLoader, logger), val exceptionTranslator: ExceptionTranslator = ExceptionTranslator, - val defaultMaxMemPages: Int = 1, - val includeBinaryInCompiledClass: Boolean = false -) : Logger by logger { + val builder: ModuleBuilder<*> = ModuleBuilder.Compiled(logger = logger), + val assertionExclusionFilter: (Script.Cmd.Assertion) -> Boolean = { false } +) : Module.ImportResolver, Logger by logger { fun withHarnessRegistered(out: PrintWriter = PrintWriter(System.out, true)) = - withModuleRegistered("spectest", Module.Native(TestHarness(out))) + withModuleRegistered(Module.Native("spectest", TestHarness(out))) - fun withModuleRegistered(name: String, mod: Module) = copy(registrations = registrations + (name to mod)) + fun withModuleRegistered(mod: Module) = + copy(registrations = registrations + ((mod.name ?: error("Missing module name")) to mod)) fun runCommand(cmd: Script.Cmd) = when (cmd) { is Script.Cmd.Module -> - // We ask for the module instance because some things are built on expectation - compileModule(cmd.module, "Module${modules.size}", cmd.name).also { it.instance(this) }.let { - copy(modules = modules + it) - } + copy(modules = modules + buildModule(cmd.module, "Module${modules.size}", cmd.name)) is Script.Cmd.Register -> copy(registrations = registrations + ( cmd.string to ( @@ -53,10 +41,14 @@ data class ScriptContext( doAction(cmd).let { this } is Script.Cmd.Assertion -> doAssertion(cmd).let { this } - else -> TODO("BOO: $cmd") + is Script.Cmd.Meta -> throw NotImplementedError("Meta commands cannot be run") } fun doAssertion(cmd: Script.Cmd.Assertion) { + if (assertionExclusionFilter(cmd)) { + debug { "Ignoring assertion: " + SExprToStr.fromSExpr(AstToSExpr.fromAssertion(cmd)) } + return + } debug { "Performing assertion: " + SExprToStr.fromSExpr(AstToSExpr.fromAssertion(cmd)) } when (cmd) { is Script.Cmd.Assertion.Return -> assertReturn(cmd) @@ -67,13 +59,14 @@ data class ScriptContext( is Script.Cmd.Assertion.Unlinkable -> assertUnlinkable(cmd) is Script.Cmd.Assertion.TrapModule -> assertTrapModule(cmd) is Script.Cmd.Assertion.Exhaustion -> assertExhaustion(cmd) - else -> TODO("Assertion misssing: $cmd") + else -> TODO("Assertion missing: $cmd") } } fun assertReturn(ret: Script.Cmd.Assertion.Return) { require(ret.exprs.size < 2) - val (retType, retVal) = doAction(ret.action) + val retVal = doAction(ret.action) + val retType = retVal?.javaClass?.valueType when (retType) { null -> if (ret.exprs.isNotEmpty()) @@ -106,8 +99,8 @@ data class ScriptContext( } fun assertReturnNan(ret: Script.Cmd.Assertion.ReturnNan) { - val (retType, retVal) = doAction(ret.action) - when (retType) { + val retVal = doAction(ret.action) + when (retVal?.javaClass?.valueType) { Node.Type.Value.F32 -> if (!(retVal as Float).isNaN()) throw ScriptAssertionError(ret, "Expected NaN, got $retVal", retVal) Node.Type.Value.F64 -> @@ -129,7 +122,7 @@ data class ScriptContext( try { debug { "Compiling malformed: " + SExprToStr.Compact.fromSExpr(AstToSExpr.fromModule(malformed.module.value)) } val className = "malformed" + UUID.randomUUID().toString().replace("-", "") - compileModule(malformed.module.value, className, null) + buildModule(malformed.module.value, className, null) throw ScriptAssertionError( malformed, "Expected malformed module with error '${malformed.failure}', was valid" @@ -141,7 +134,7 @@ data class ScriptContext( try { debug { "Compiling invalid: " + SExprToStr.Compact.fromSExpr(AstToSExpr.fromModule(invalid.module.value)) } val className = "invalid" + UUID.randomUUID().toString().replace("-", "") - compileModule(invalid.module.value, className, null) + buildModule(invalid.module.value, className, null) throw ScriptAssertionError(invalid, "Expected invalid module with error '${invalid.failure}', was valid") } catch (e: Exception) { assertFailure(invalid, e, invalid.failure) } } @@ -149,7 +142,7 @@ data class ScriptContext( fun assertUnlinkable(unlink: Script.Cmd.Assertion.Unlinkable) { try { val className = "unlinkable" + UUID.randomUUID().toString().replace("-", "") - compileModule(unlink.module, className, null).instance(this) + buildModule(unlink.module, className, null) throw ScriptAssertionError(unlink, "Expected module link error with '${unlink.failure}', was valid") } catch (e: Throwable) { assertFailure(unlink, e, unlink.failure) } } @@ -157,7 +150,7 @@ data class ScriptContext( fun assertTrapModule(trap: Script.Cmd.Assertion.TrapModule) { try { val className = "trapmod" + UUID.randomUUID().toString().replace("-", "") - compileModule(trap.module, className, null).instance(this) + buildModule(trap.module, className, null) throw ScriptAssertionError(trap, "Expected module init error with '${trap.failure}', was valid") } catch (e: Throwable) { assertFailure(trap, e, trap.failure) } } @@ -190,50 +183,29 @@ data class ScriptContext( is Script.Cmd.Action.Get -> doGet(cmd) } - fun doGet(cmd: Script.Cmd.Action.Get): Pair { + fun doGet(cmd: Script.Cmd.Action.Get): Number { // Grab last module or named one val module = if (cmd.name == null) modules.last() else modules.first { it.name == cmd.name } - // Just call the getter - val getter = module.cls.getDeclaredMethod("get" + cmd.string.javaIdent.capitalize()) - return getter.returnType.valueType!! to getter.invoke(module.instance(this)) + return module.exportedGlobal(cmd.string)!!.first.invokeWithArguments() as Number } - fun doInvoke(cmd: Script.Cmd.Action.Invoke): Pair { + fun doInvoke(cmd: Script.Cmd.Action.Invoke): Number? { // If there is a module name, use that index, otherwise just search. - val (compMod, method) = modules.filter { cmd.name == null || it.name == cmd.name }.flatMap { compMod -> - compMod.cls.declaredMethods.filter { it.name == cmd.string.javaIdent }.map { compMod to it } - }.let { methodPairs -> - // If there are multiple, we get the last one - if (methodPairs.isEmpty()) error("Unable to find method for invoke named ${cmd.string.javaIdent}") - else if (methodPairs.size == 1) methodPairs.single() - else methodPairs.last().also { debug { "Found multiple methods for ${cmd.string.javaIdent}, using last"} } - } - + val module = if (cmd.name == null) modules.last() else modules.first { it.name == cmd.name } // Invoke all parameter expressions - require(cmd.exprs.size == method.parameterTypes.size) - val params = cmd.exprs.zip(method.parameterTypes).map { (expr, paramType) -> - runExpr(expr, paramType.valueType!!) - } - + val mh = module.exportedFunc(cmd.string)!! + val paramTypes = mh.type().parameterList() + require(cmd.exprs.size == paramTypes.size) + val params = cmd.exprs.zip(paramTypes).map { (expr, paramType) -> runExpr(expr, paramType.valueType!!)!! } // Run returning the result - return method.returnType.valueType to method.invoke(compMod.instance(this), *params.toTypedArray()) - } - - fun runExpr(insns: List) { - MethodHandleUtil.invokeVoid(compileExpr(insns, null)) + return mh.invokeWithArguments(*params.toTypedArray()) as Number? } - fun runExpr(insns: List, retType: Node.Type.Value): Any = compileExpr(insns, retType).let { handle -> - when (retType) { - is Node.Type.Value.I32 -> MethodHandleUtil.invokeInt(handle) - is Node.Type.Value.I64 -> MethodHandleUtil.invokeLong(handle) - is Node.Type.Value.F32 -> MethodHandleUtil.invokeFloat(handle) - is Node.Type.Value.F64 -> MethodHandleUtil.invokeDouble(handle) - } - } + fun runExpr(insns: List, retType: Node.Type.Value?) = + buildExpr(insns, retType).exportedFunc("expr")!!.invokeWithArguments() as Number? - fun compileExpr(insns: List, retType: Node.Type.Value?): MethodHandle { - debug { "Compiling expression: $insns" } + fun buildExpr(insns: List, retType: Node.Type.Value?): Module { + debug { "Building expression: $insns" } val mod = Node.Module( exports = listOf(Node.Export("expr", Node.ExternalKind.FUNCTION, 0)), funcs = listOf(Node.Func( @@ -242,92 +214,43 @@ data class ScriptContext( instructions = insns )) ) - val className = "expr" + UUID.randomUUID().toString().replace("-", "") - val compiled = compileModule(mod, className, null) - return MethodHandles.lookup().bind(compiled.instance(this), "expr", - MethodType.methodType(retType?.jclass ?: Void.TYPE)) + return buildModule(mod, "expr" + UUID.randomUUID().toString().replace("-", ""), null) } - fun withCompiledModule(mod: Node.Module, className: String, name: String?) = - copy(modules = modules + compileModule(mod, className, name)) + fun withBuiltModule(mod: Node.Module, className: String, name: String?) = + copy(modules = modules + buildModule(mod, className, name)) - fun compileModule(mod: Node.Module, className: String, name: String?): Module.Compiled { - val ctx = ClsContext( - packageName = packageName, - className = className, - mod = mod, - logger = logger, - includeBinary = includeBinaryInCompiledClass - ).let(adjustContext) - AstToAsm.fromModule(ctx) - return Module.Compiled(mod, classLoader.fromBuiltContext(ctx), name, ctx.mem) - } - - fun bindImport(import: Node.Import, getter: Boolean, methodType: MethodType) = bindImport( - import, if (getter) "get" + import.field.javaIdent.capitalize() else import.field.javaIdent, methodType) + fun buildModule(mod: Node.Module, className: String, name: String?) = builder.build(this, mod, className, name) - fun bindImport(import: Node.Import, javaName: String, methodType: MethodType): MethodHandle { - // Find a method that matches our expectations - val module = registrations[import.module] ?: throw RunErr.ImportNotFound(import.module, import.field) - val kind = when (import.kind) { - is Node.Import.Kind.Func -> WasmExternalKind.FUNCTION - is Node.Import.Kind.Table -> WasmExternalKind.TABLE - is Node.Import.Kind.Memory -> WasmExternalKind.MEMORY - is Node.Import.Kind.Global -> WasmExternalKind.GLOBAL - } - return module.bindMethod(this, import.field, kind, javaName, methodType) ?: - throw RunErr.ImportNotFound(import.module, import.field) + override fun resolveImportFunc(module: String, field: String, type: Node.Type.Func): MethodHandle { + val hnd = registrations[module]?.exportedFunc(field) ?: throw RunErr.ImportNotFound(module, field) + val hndType = Node.Type.Func( + params = hnd.type().parameterList().map { it.valueType!! }, + ret = hnd.type().returnType().valueType + ) + if (hndType != type) throw RunErr.ImportIncompatible(module, field, type, hndType) + return hnd } - fun resolveImportFunc(import: Node.Import, funcType: Node.Type.Func) = - bindImport(import, false, - MethodType.methodType(funcType.ret?.jclass ?: Void.TYPE, funcType.params.map { it.jclass })) - - fun resolveImportGlobals(import: Node.Import, globalType: Node.Type.Global): List { - val getter = bindImport(import, true, MethodType.methodType(globalType.contentType.jclass)) - // Whether the setter is present or not defines whether it is mutable - val setter = try { - bindImport(import, "set" + import.field.javaIdent.capitalize(), - MethodType.methodType(Void.TYPE, globalType.contentType.jclass)) - } catch (e: RunErr.ImportNotFound) { null } - // Mutability must match - if (globalType.mutable == (setter == null)) - throw RunErr.ImportGlobalInvalidMutability(import.module, import.field, globalType.mutable) - return if (setter == null) listOf(getter) else listOf(getter, setter) + override fun resolveImportGlobal( + module: String, + field: String, + type: Node.Type.Global + ): Pair { + val hnd = registrations[module]?.exportedGlobal(field) ?: throw RunErr.ImportNotFound(module, field) + if (!hnd.first.type().returnType().isPrimitive) throw RunErr.ImportNotFound(module, field) + val hndType = Node.Type.Global( + contentType = hnd.first.type().returnType().valueType!!, + mutable = hnd.second != null + ) + if (hndType != type) throw RunErr.ImportIncompatible(module, field, type, hndType) + return hnd } - fun resolveImportMemory(import: Node.Import, memoryType: Node.Type.Memory, mem: Mem) = - bindImport(import, true, MethodType.methodType(Class.forName(mem.memType.asm.className))). - invokeWithArguments()!! - @Suppress("UNCHECKED_CAST") - fun resolveImportTable(import: Node.Import, tableType: Node.Type.Table) = - bindImport(import, true, MethodType.methodType(Array::class.java)). - invokeWithArguments()!! as Array - - open class SimpleClassLoader( - parent: ClassLoader, - logger: Logger, - val splitWhenTooLarge: Boolean = true - ) : ClassLoader(parent), Logger by logger { - fun fromBuiltContext(ctx: ClsContext): Class<*> { - trace { "Computing frames for ASM class:\n" + ctx.cls.toAsmString() } - val writer = if (splitWhenTooLarge) AsmToBinary else AsmToBinary.noSplit - return writer.fromClassNode(ctx.cls).let { bytes -> - debug { "ASM class:\n" + bytes.asClassNode().toAsmString() } - defineClass("${ctx.packageName}.${ctx.className}", bytes, 0, bytes.size) - } - } + override fun resolveImportMemory(module: String, field: String, type: Node.Type.Memory, memClass: Class) = + registrations[module]?.exportedMemory(field, memClass) ?: throw RunErr.ImportNotFound(module, field) - fun addClass(bytes: ByteArray) { - // Just get the name - var className = "" - ClassReader(bytes).accept(object : ClassVisitor(Opcodes.ASM5) { - override fun visit(a: Int, b: Int, name: String, c: String?, d: String?, e: Array?) { - className = name.replace('/', '.') - } - }, ClassReader.SKIP_CODE) - defineClass(className, bytes, 0, bytes.size) - } - } + override fun resolveImportTable(module: String, field: String, type: Node.Type.Table) = + registrations[module]?.exportedTable(field) ?: throw RunErr.ImportNotFound(module, field) } \ No newline at end of file diff --git a/compiler/src/main/kotlin/asmble/run/jvm/interpret/Imports.kt b/compiler/src/main/kotlin/asmble/run/jvm/interpret/Imports.kt new file mode 100644 index 0000000..90e82ae --- /dev/null +++ b/compiler/src/main/kotlin/asmble/run/jvm/interpret/Imports.kt @@ -0,0 +1,35 @@ +package asmble.run.jvm.interpret + +import asmble.ast.Node +import java.lang.invoke.MethodHandle +import java.nio.ByteBuffer + +interface Imports { + fun invokeFunction(module: String, field: String, type: Node.Type.Func, args: List): Number? + fun getGlobal(module: String, field: String, type: Node.Type.Global): Number + fun setGlobal(module: String, field: String, type: Node.Type.Global, value: Number) + fun getMemory(module: String, field: String, type: Node.Type.Memory): ByteBuffer + fun getTable(module: String, field: String, type: Node.Type.Table): Array + + object None : Imports { + override fun invokeFunction( + module: String, + field: String, + type: Node.Type.Func, + args: List + ) = throw NotImplementedError("Import function $module.$field not implemented") + + override fun getGlobal(module: String, field: String, type: Node.Type.Global) = + throw NotImplementedError("Import global $module.$field not implemented") + + override fun setGlobal(module: String, field: String, type: Node.Type.Global, value: Number) { + throw NotImplementedError("Import global $module.$field not implemented") + } + + override fun getMemory(module: String, field: String, type: Node.Type.Memory) = + throw NotImplementedError("Import memory $module.$field not implemented") + + override fun getTable(module: String, field: String, type: Node.Type.Table) = + throw NotImplementedError("Import table $module.$field not implemented") + } +} \ No newline at end of file diff --git a/compiler/src/main/kotlin/asmble/run/jvm/interpret/InterpretErr.kt b/compiler/src/main/kotlin/asmble/run/jvm/interpret/InterpretErr.kt new file mode 100644 index 0000000..77fc942 --- /dev/null +++ b/compiler/src/main/kotlin/asmble/run/jvm/interpret/InterpretErr.kt @@ -0,0 +1,67 @@ +package asmble.run.jvm.interpret + +import asmble.AsmErr +import asmble.ast.Node + +sealed class InterpretErr(message: String, cause: Throwable? = null) : RuntimeException(message, cause), AsmErr { + + class IndirectCallTypeMismatch( + val expected: Node.Type.Func, + val actual: Node.Type.Func + ) : InterpretErr("Expecting func type $expected, got $actual") { + override val asmErrString get() = "indirect call type mismatch" + } + + class InvalidCallResult( + val expected: Node.Type.Value?, + val actual: Number? + ) : InterpretErr("Expected call result to be $expected, got $actual") + + class EndReached(returned: Number?) : InterpretErr("Reached end of invocation") + + class StartFuncParamMismatch( + val expected: List, + val actual: List + ) : InterpretErr("Can't call start func, expected params $expected, got $actual") + + class OutOfBoundsMemory( + val index: Int, + val offset: Long + ) : InterpretErr("Unable to access mem $index + offset $offset") { + override val asmErrString get() = "out of bounds memory access" + } + + class UndefinedElement( + val index: Int + ) : InterpretErr("No table element for index $index") { + override val asmErrString get() = "undefined element" + override val asmErrStrings get() = listOf(asmErrString, "uninitialized element") + } + + class TruncIntegerNaN( + val orig: Number, + val target: Node.Type.Value, + val signed: Boolean + ) : InterpretErr("Invalid to trunc $orig to $target " + if (signed) "signed" else "unsigned") { + override val asmErrString get() = "invalid conversion to integer" + } + + class TruncIntegerOverflow( + val orig: Number, + val target: Node.Type.Value, + val signed: Boolean + ) : InterpretErr("Integer overflow attempting to trunc $orig to $target " + if (signed) "signed" else "unsigned") { + override val asmErrString get() = "integer overflow" + } + + class SignedDivOverflow( + val a: Number, + val b: Number + ) : InterpretErr("Integer overflow attempting $a / $b") { + override val asmErrString get() = "integer overflow" + } + + class StackOverflow(val max: Int) : InterpretErr("Call stack exceeeded $max depth") { + override val asmErrString get() = "call stack exhausted" + } +} \ No newline at end of file diff --git a/compiler/src/main/kotlin/asmble/run/jvm/interpret/Interpreter.kt b/compiler/src/main/kotlin/asmble/run/jvm/interpret/Interpreter.kt new file mode 100644 index 0000000..da647cc --- /dev/null +++ b/compiler/src/main/kotlin/asmble/run/jvm/interpret/Interpreter.kt @@ -0,0 +1,672 @@ +package asmble.run.jvm.interpret + +import asmble.ast.Node +import asmble.compile.jvm.* +import asmble.run.jvm.RunErr +import asmble.util.Either +import asmble.util.Logger +import asmble.util.toUnsignedInt +import asmble.util.toUnsignedLong +import java.lang.invoke.MethodHandle +import java.lang.invoke.MethodHandles +import java.lang.invoke.MethodType +import java.nio.ByteBuffer +import java.nio.ByteOrder + +// This is not intended to be fast, rather clear and easy to read. Very little cached/memoized, lots of extra cycles. +open class Interpreter { + + fun execFunc( + ctx: Context, + funcIndex: Int = ctx.mod.startFuncIndex ?: error("No start func index"), + vararg funcArgs: Number + ): Number? { + // Check params + val funcType = ctx.funcTypeAtIndex(funcIndex) + funcArgs.mapNotNull { it.valueType }.let { + if (it != funcType.params) throw InterpretErr.StartFuncParamMismatch(funcType.params, it) + } + // Import functions are executed inline and returned + ctx.importFuncs.getOrNull(funcIndex)?.also { + return ctx.imports.invokeFunction(it.module, it.field, funcType, funcArgs.toList()) + } + // This is the call stack we need to stop at + val startingCallStackSize = ctx.callStack.size + // Make the call on the context + var lastStep: StepResult = StepResult.Call(funcIndex, funcArgs.toList(), funcType) + // Run until done + while (lastStep !is StepResult.Return || ctx.callStack.size > startingCallStackSize + 1) { + next(ctx, lastStep) + lastStep = step(ctx) + } + ctx.callStack.subList(startingCallStackSize + 1, ctx.callStack.size).clear() + return lastStep.v + } + + fun step(ctx: Context): StepResult = ctx.currFuncCtx.run { + // If the insn is out of bounds, it's an implicit return, otherwise just execute the insn + if (insnIndex >= func.instructions.size) StepResult.Return(func.type.ret?.let { pop(it) }) + else invokeSingle(ctx) + } + + // Errors with InterpretErr.EndReached if there is no next step to be had + fun next(ctx: Context, step: StepResult) { + ctx.logger.trace { + "NEXT: $step " + + "[VAL STACK: ${ctx.maybeCurrFuncCtx?.valueStack}] " + + "[CALL STACK DEPTH: ${ctx.callStack.size}]" + } + when (step) { + // Next just moves the counter + is StepResult.Next -> ctx.currFuncCtx.insnIndex++ + // Branch updates the stack and moves the insn index + is StepResult.Branch -> ctx.currFuncCtx.run { + // A failed if just jumps to the else or end + if (step.failedIf) { + require(step.blockDepth == 0) + val block = blockStack.last() + if (block.elseIndex != null) insnIndex = block.elseIndex + 1 + else { + insnIndex = block.endIndex + 1 + blockStack.removeAt(blockStack.size - 1) + } + } else { + // Remove all blocks until the depth requested + blockStack.subList(blockStack.size - step.blockDepth, blockStack.size).clear() + // This can break out of the entire function + if (blockStack.isEmpty()) { + // Grab the stack item if present, blow away stack, put back, and move to end of func + val retVal = func.type.ret?.let { pop(it) } + valueStack.clear() + retVal?.also { push(it) } + insnIndex = func.instructions.size + } else if (blockStack.last().insn is Node.Instr.Loop && !step.forceEndOnLoop) { + // It's just a loop continuation, go back to top + insnIndex = blockStack.last().startIndex + 1 + } else { + // Remove the one at the depth requested + val block = blockStack.removeAt(blockStack.size - 1) + // Pop value if applicable + val blockVal = block.insn.type?.let { pop(it) } + // Trim the stack down to required size + valueStack.subList(block.stackSizeAtStart, valueStack.size).clear() + // Put the value back on if applicable + blockVal?.also { push(it) } + // Jump past the end + insnIndex = block.endIndex + 1 + } + } + } + // Call, if import, invokes it and puts result on stack. If not, just pushes a new func context. + is StepResult.Call -> ctx.funcAtIndex(step.funcIndex).let { + when (it) { + // If import, call and just put on stack and advance insn if came from insn + is Either.Left -> + ctx.imports.invokeFunction(it.v.module, it.v.field, step.type, step.args).also { + // Make sure result type is accurate + if (it.valueType != step.type.ret) + throw InterpretErr.InvalidCallResult(step.type.ret, it) + it?.also { ctx.currFuncCtx.push(it) } + ctx.currFuncCtx.insnIndex++ + } + // If inside the module, create new context to continue + is Either.Right -> ctx.callStack += FuncContext(it.v).also { funcCtx -> + // Set the args + step.args.forEachIndexed { index, arg -> funcCtx.locals[index] = arg } + } + } + } + // Call indirect is just an MH invocation + is StepResult.CallIndirect -> { + val mh = ctx.table?.getOrNull(step.tableIndex) ?: error("Missing table entry") + val res = mh.invokeWithArguments(step.args) as? Number? + if (res.valueType != step.type.ret) throw InterpretErr.InvalidCallResult(step.type.ret, res) + res?.also { ctx.currFuncCtx.push(it) } + ctx.currFuncCtx.insnIndex++ + } + // Unreachable throws + is StepResult.Unreachable -> throw UnsupportedOperationException("Unreachable") + // Return pops curr func from the call stack, push ret and move insn on prev one + is StepResult.Return -> ctx.callStack.removeAt(ctx.callStack.lastIndex).let { returnedFrom -> + if (ctx.callStack.isEmpty()) throw InterpretErr.EndReached(step.v) + if (returnedFrom.valueStack.isNotEmpty()) + throw CompileErr.UnusedStackOnReturn(returnedFrom.valueStack.map { it::class.ref } ) + step.v?.also { ctx.currFuncCtx.push(it) } + ctx.currFuncCtx.insnIndex++ + } + } + } + + fun invokeSingle(ctx: Context): StepResult = ctx.currFuncCtx.run { + // TODO: validation + func.instructions[insnIndex].let { insn -> + ctx.logger.trace { "INSN #$insnIndex: $insn [STACK: $valueStack]" } + when (insn) { + is Node.Instr.Unreachable -> StepResult.Unreachable + is Node.Instr.Nop -> next { } + is Node.Instr.Block, is Node.Instr.Loop -> next { + blockStack += Block(insnIndex, insn as Node.Instr.Args.Type, valueStack.size, currentBlockEnd()!!) + } + is Node.Instr.If -> { + blockStack += Block(insnIndex, insn, valueStack.size - 1, currentBlockEnd()!!, currentBlockElse()) + if (popInt() == 0) StepResult.Branch(0, failedIf = true) else StepResult.Next + } + is Node.Instr.Else -> + // Jump over the whole thing and to the end, this can only be gotten here via if + StepResult.Branch(0) + is Node.Instr.End -> + // Since we reached the end by manually running through it, jump to end even on loop + StepResult.Branch(0, forceEndOnLoop = true) + is Node.Instr.Br -> StepResult.Branch(insn.relativeDepth) + is Node.Instr.BrIf -> if (popInt() != 0) StepResult.Branch(insn.relativeDepth) else StepResult.Next + is Node.Instr.BrTable -> StepResult.Branch(insn.targetTable.getOrNull(popInt()) + ?: insn.default) + is Node.Instr.Return -> + StepResult.Return(func.type.ret?.let { pop(it) }) + is Node.Instr.Call -> ctx.funcTypeAtIndex(insn.index).let { + ctx.checkNextIsntStackOverflow() + StepResult.Call(insn.index, popCallArgs(it), it) + } + is Node.Instr.CallIndirect -> { + ctx.checkNextIsntStackOverflow() + val tableIndex = popInt() + val expectedType = ctx.typeAtIndex(insn.index).also { + val tableMh = ctx.table?.getOrNull(tableIndex) ?: + throw InterpretErr.UndefinedElement(tableIndex) + val actualType = Node.Type.Func( + params = tableMh.type().parameterList().map { it.valueType!! }, + ret = tableMh.type().returnType().valueType + ) + if (it != actualType) throw InterpretErr.IndirectCallTypeMismatch(it, actualType) + } + StepResult.CallIndirect(tableIndex, popCallArgs(expectedType), expectedType) + } + is Node.Instr.Drop -> next { pop() } + is Node.Instr.Select -> next { + popInt().also { + val v2 = pop() + val v1 = pop() + if (v1::class != v2::class) + throw CompileErr.SelectMismatch(v1.valueType!!.typeRef, v2.valueType!!.typeRef) + if (it != 0) push(v1) else push(v2) + } + } + is Node.Instr.GetLocal -> next { push(locals[insn.index]) } + is Node.Instr.SetLocal -> next { locals[insn.index] = pop()} + is Node.Instr.TeeLocal -> next { locals[insn.index] = peek() } + is Node.Instr.GetGlobal -> next { push(ctx.getGlobal(insn.index)) } + is Node.Instr.SetGlobal -> next { ctx.setGlobal(insn.index, pop()) } + is Node.Instr.I32Load -> next { push(ctx.mem.getInt(insn.popMemAddr())) } + is Node.Instr.I64Load -> next { push(ctx.mem.getLong(insn.popMemAddr())) } + is Node.Instr.F32Load -> next { push(ctx.mem.getFloat(insn.popMemAddr())) } + is Node.Instr.F64Load -> next { push(ctx.mem.getDouble(insn.popMemAddr())) } + is Node.Instr.I32Load8S -> next { push(ctx.mem.get(insn.popMemAddr()).toInt()) } + is Node.Instr.I32Load8U -> next { push(ctx.mem.get(insn.popMemAddr()).toUnsignedInt()) } + is Node.Instr.I32Load16S -> next { push(ctx.mem.getShort(insn.popMemAddr()).toInt()) } + is Node.Instr.I32Load16U -> next { push(ctx.mem.getShort(insn.popMemAddr()).toUnsignedInt()) } + is Node.Instr.I64Load8S -> next { push(ctx.mem.get(insn.popMemAddr()).toLong()) } + is Node.Instr.I64Load8U -> next { push(ctx.mem.get(insn.popMemAddr()).toUnsignedLong()) } + is Node.Instr.I64Load16S -> next { push(ctx.mem.getShort(insn.popMemAddr()).toLong()) } + is Node.Instr.I64Load16U -> next { push(ctx.mem.getShort(insn.popMemAddr()).toUnsignedLong()) } + is Node.Instr.I64Load32S -> next { push(ctx.mem.getInt(insn.popMemAddr()).toLong()) } + is Node.Instr.I64Load32U -> next { push(ctx.mem.getInt(insn.popMemAddr()).toUnsignedLong()) } + is Node.Instr.I32Store -> next { popInt().let { ctx.mem.putInt(insn.popMemAddr(), it) } } + is Node.Instr.I64Store -> next { popLong().let { ctx.mem.putLong(insn.popMemAddr(), it) } } + is Node.Instr.F32Store -> next { popFloat().let { ctx.mem.putFloat(insn.popMemAddr(), it) } } + is Node.Instr.F64Store -> next { popDouble().let { ctx.mem.putDouble(insn.popMemAddr(), it) } } + is Node.Instr.I32Store8 -> next { popInt().let { ctx.mem.put(insn.popMemAddr(), it.toByte()) } } + is Node.Instr.I32Store16 -> next { popInt().let { ctx.mem.putShort(insn.popMemAddr(), it.toShort()) } } + is Node.Instr.I64Store8 -> next { popLong().let { ctx.mem.put(insn.popMemAddr(), it.toByte()) } } + is Node.Instr.I64Store16 -> next { popLong().let { ctx.mem.putShort(insn.popMemAddr(), it.toShort()) } } + is Node.Instr.I64Store32 -> next { popLong().let { ctx.mem.putInt(insn.popMemAddr(), it.toInt()) } } + is Node.Instr.MemorySize -> next { push(ctx.mem.limit() / Mem.PAGE_SIZE) } + is Node.Instr.MemoryGrow -> next { + val newLim = ctx.mem.limit().toLong() + (popInt().toLong() * Mem.PAGE_SIZE) + if (newLim > ctx.mem.capacity()) push(-1) + else (ctx.mem.limit() / Mem.PAGE_SIZE).also { + push(it) + ctx.mem.limit(newLim.toInt()) + } + } + is Node.Instr.I32Const -> next { push(insn.value) } + is Node.Instr.I64Const -> next { push(insn.value) } + is Node.Instr.F32Const -> next { push(insn.value) } + is Node.Instr.F64Const -> next { push(insn.value) } + is Node.Instr.I32Eqz -> next { push(popInt() == 0) } + is Node.Instr.I32Eq -> nextBinOp(popInt(), popInt()) { a, b -> a == b } + is Node.Instr.I32Ne -> nextBinOp(popInt(), popInt()) { a, b -> a != b } + is Node.Instr.I32LtS -> nextBinOp(popInt(), popInt()) { a, b -> a < b } + is Node.Instr.I32LtU -> nextBinOp(popInt(), popInt()) { a, b -> Integer.compareUnsigned(a, b) < 0 } + is Node.Instr.I32GtS -> nextBinOp(popInt(), popInt()) { a, b -> a > b } + is Node.Instr.I32GtU -> nextBinOp(popInt(), popInt()) { a, b -> Integer.compareUnsigned(a, b) > 0 } + is Node.Instr.I32LeS -> nextBinOp(popInt(), popInt()) { a, b -> a <= b } + is Node.Instr.I32LeU -> nextBinOp(popInt(), popInt()) { a, b -> Integer.compareUnsigned(a, b) <= 0 } + is Node.Instr.I32GeS -> nextBinOp(popInt(), popInt()) { a, b -> a >= b } + is Node.Instr.I32GeU -> nextBinOp(popInt(), popInt()) { a, b -> Integer.compareUnsigned(a, b) >= 0 } + is Node.Instr.I64Eqz -> next { push(popLong() == 0L) } + is Node.Instr.I64Eq -> nextBinOp(popLong(), popLong()) { a, b -> a == b } + is Node.Instr.I64Ne -> nextBinOp(popLong(), popLong()) { a, b -> a != b } + is Node.Instr.I64LtS -> nextBinOp(popLong(), popLong()) { a, b -> a < b } + is Node.Instr.I64LtU -> nextBinOp(popLong(), popLong()) { a, b -> java.lang.Long.compareUnsigned(a, b) < 0 } + is Node.Instr.I64GtS -> nextBinOp(popLong(), popLong()) { a, b -> a > b } + is Node.Instr.I64GtU -> nextBinOp(popLong(), popLong()) { a, b -> java.lang.Long.compareUnsigned(a, b) > 0 } + is Node.Instr.I64LeS -> nextBinOp(popLong(), popLong()) { a, b -> a <= b } + is Node.Instr.I64LeU -> nextBinOp(popLong(), popLong()) { a, b -> java.lang.Long.compareUnsigned(a, b) <= 0 } + is Node.Instr.I64GeS -> nextBinOp(popLong(), popLong()) { a, b -> a >= b } + is Node.Instr.I64GeU -> nextBinOp(popLong(), popLong()) { a, b -> java.lang.Long.compareUnsigned(a, b) >= 0 } + is Node.Instr.F32Eq -> nextBinOp(popFloat(), popFloat()) { a, b -> a == b } + is Node.Instr.F32Ne -> nextBinOp(popFloat(), popFloat()) { a, b -> a != b } + is Node.Instr.F32Lt -> nextBinOp(popFloat(), popFloat()) { a, b -> a < b } + is Node.Instr.F32Gt -> nextBinOp(popFloat(), popFloat()) { a, b -> a > b } + is Node.Instr.F32Le -> nextBinOp(popFloat(), popFloat()) { a, b -> a <= b } + is Node.Instr.F32Ge -> nextBinOp(popFloat(), popFloat()) { a, b -> a >= b } + is Node.Instr.F64Eq -> nextBinOp(popDouble(), popDouble()) { a, b -> a == b } + is Node.Instr.F64Ne -> nextBinOp(popDouble(), popDouble()) { a, b -> a != b } + is Node.Instr.F64Lt -> nextBinOp(popDouble(), popDouble()) { a, b -> a < b } + is Node.Instr.F64Gt -> nextBinOp(popDouble(), popDouble()) { a, b -> a > b } + is Node.Instr.F64Le -> nextBinOp(popDouble(), popDouble()) { a, b -> a <= b } + is Node.Instr.F64Ge -> nextBinOp(popDouble(), popDouble()) { a, b -> a >= b } + is Node.Instr.I32Clz -> next { push(Integer.numberOfLeadingZeros(popInt())) } + is Node.Instr.I32Ctz -> next { push(Integer.numberOfTrailingZeros(popInt())) } + is Node.Instr.I32Popcnt -> next { push(Integer.bitCount(popInt())) } + is Node.Instr.I32Add -> nextBinOp(popInt(), popInt()) { a, b -> a + b } + is Node.Instr.I32Sub -> nextBinOp(popInt(), popInt()) { a, b -> a - b } + is Node.Instr.I32Mul -> nextBinOp(popInt(), popInt()) { a, b -> a * b } + is Node.Instr.I32DivS -> nextBinOp(popInt(), popInt()) { a, b -> + ctx.checkedSignedDivInteger(a, b) + a / b + } + is Node.Instr.I32DivU -> nextBinOp(popInt(), popInt()) { a, b -> Integer.divideUnsigned(a, b) } + is Node.Instr.I32RemS -> nextBinOp(popInt(), popInt()) { a, b -> a % b } + is Node.Instr.I32RemU -> nextBinOp(popInt(), popInt()) { a, b -> Integer.remainderUnsigned(a, b) } + is Node.Instr.I32And -> nextBinOp(popInt(), popInt()) { a, b -> a and b } + is Node.Instr.I32Or -> nextBinOp(popInt(), popInt()) { a, b -> a or b } + is Node.Instr.I32Xor -> nextBinOp(popInt(), popInt()) { a, b -> a xor b } + is Node.Instr.I32Shl -> nextBinOp(popInt(), popInt()) { a, b -> a shl b } + is Node.Instr.I32ShrS -> nextBinOp(popInt(), popInt()) { a, b -> a shr b } + is Node.Instr.I32ShrU -> nextBinOp(popInt(), popInt()) { a, b -> a ushr b } + is Node.Instr.I32Rotl -> nextBinOp(popInt(), popInt()) { a, b -> Integer.rotateLeft(a, b) } + is Node.Instr.I32Rotr -> nextBinOp(popInt(), popInt()) { a, b -> Integer.rotateRight(a, b) } + is Node.Instr.I64Clz -> next { push(java.lang.Long.numberOfLeadingZeros(popLong()).toLong()) } + is Node.Instr.I64Ctz -> next { push(java.lang.Long.numberOfTrailingZeros(popLong()).toLong()) } + is Node.Instr.I64Popcnt -> next { push(java.lang.Long.bitCount(popLong()).toLong()) } + is Node.Instr.I64Add -> nextBinOp(popLong(), popLong()) { a, b -> a + b } + is Node.Instr.I64Sub -> nextBinOp(popLong(), popLong()) { a, b -> a - b } + is Node.Instr.I64Mul -> nextBinOp(popLong(), popLong()) { a, b -> a * b } + is Node.Instr.I64DivS -> nextBinOp(popLong(), popLong()) { a, b -> + ctx.checkedSignedDivInteger(a, b) + a / b + } + is Node.Instr.I64DivU -> nextBinOp(popLong(), popLong()) { a, b -> java.lang.Long.divideUnsigned(a, b) } + is Node.Instr.I64RemS -> nextBinOp(popLong(), popLong()) { a, b -> a % b } + is Node.Instr.I64RemU -> nextBinOp(popLong(), popLong()) { a, b -> java.lang.Long.remainderUnsigned(a, b) } + is Node.Instr.I64And -> nextBinOp(popLong(), popLong()) { a, b -> a and b } + is Node.Instr.I64Or -> nextBinOp(popLong(), popLong()) { a, b -> a or b } + is Node.Instr.I64Xor -> nextBinOp(popLong(), popLong()) { a, b -> a xor b } + is Node.Instr.I64Shl -> nextBinOp(popLong(), popLong()) { a, b -> a shl b.toInt() } + is Node.Instr.I64ShrS -> nextBinOp(popLong(), popLong()) { a, b -> a shr b.toInt() } + is Node.Instr.I64ShrU -> nextBinOp(popLong(), popLong()) { a, b -> a ushr b.toInt() } + is Node.Instr.I64Rotl -> nextBinOp(popLong(), popLong()) { a, b -> java.lang.Long.rotateLeft(a, b.toInt()) } + is Node.Instr.I64Rotr -> nextBinOp(popLong(), popLong()) { a, b -> java.lang.Long.rotateRight(a, b.toInt()) } + is Node.Instr.F32Abs -> next { push(Math.abs(popFloat())) } + is Node.Instr.F32Neg -> next { push(-popFloat()) } + is Node.Instr.F32Ceil -> next { push(Math.ceil(popFloat().toDouble()).toFloat()) } + is Node.Instr.F32Floor -> next { push(Math.floor(popFloat().toDouble()).toFloat()) } + is Node.Instr.F32Trunc -> next { + popFloat().toDouble().let { push((if (it >= 0.0) Math.floor(it) else Math.ceil(it)).toFloat()) } + } + is Node.Instr.F32Nearest -> next { push(Math.rint(popFloat().toDouble()).toFloat()) } + is Node.Instr.F32Sqrt -> next { push(Math.sqrt(popFloat().toDouble()).toFloat()) } + is Node.Instr.F32Add -> nextBinOp(popFloat(), popFloat()) { a, b -> a + b } + is Node.Instr.F32Sub -> nextBinOp(popFloat(), popFloat()) { a, b -> a - b } + is Node.Instr.F32Mul -> nextBinOp(popFloat(), popFloat()) { a, b -> a * b } + is Node.Instr.F32Div -> nextBinOp(popFloat(), popFloat()) { a, b -> a / b } + is Node.Instr.F32Min -> nextBinOp(popFloat(), popFloat()) { a, b -> Math.min(a, b) } + is Node.Instr.F32Max -> nextBinOp(popFloat(), popFloat()) { a, b -> Math.max(a, b) } + is Node.Instr.F32CopySign -> nextBinOp(popFloat(), popFloat()) { a, b -> Math.copySign(a, b) } + is Node.Instr.F64Abs -> next { push(Math.abs(popDouble())) } + is Node.Instr.F64Neg -> next { push(-popDouble()) } + is Node.Instr.F64Ceil -> next { push(Math.ceil(popDouble())) } + is Node.Instr.F64Floor -> next { push(Math.floor(popDouble())) } + is Node.Instr.F64Trunc -> next { + popDouble().let { push((if (it >= 0.0) Math.floor(it) else Math.ceil(it))) } + } + is Node.Instr.F64Nearest -> next { push(Math.rint(popDouble())) } + is Node.Instr.F64Sqrt -> next { push(Math.sqrt(popDouble())) } + is Node.Instr.F64Add -> nextBinOp(popDouble(), popDouble()) { a, b -> a + b } + is Node.Instr.F64Sub -> nextBinOp(popDouble(), popDouble()) { a, b -> a - b } + is Node.Instr.F64Mul -> nextBinOp(popDouble(), popDouble()) { a, b -> a * b } + is Node.Instr.F64Div -> nextBinOp(popDouble(), popDouble()) { a, b -> a / b } + is Node.Instr.F64Min -> nextBinOp(popDouble(), popDouble()) { a, b -> Math.min(a, b) } + is Node.Instr.F64Max -> nextBinOp(popDouble(), popDouble()) { a, b -> Math.max(a, b) } + is Node.Instr.F64CopySign -> nextBinOp(popDouble(), popDouble()) { a, b -> Math.copySign(a, b) } + is Node.Instr.I32WrapI64 -> next { push(popLong().toInt()) } + // TODO: trunc traps on overflow! + is Node.Instr.I32TruncSF32 -> next { + push(ctx.checkedTrunc(popFloat(), true) { it.toInt() }) + } + is Node.Instr.I32TruncUF32 -> next { + push(ctx.checkedTrunc(popFloat(), false) { it.toLong().toInt() }) + } + is Node.Instr.I32TruncSF64 -> next { + push(ctx.checkedTrunc(popDouble(), true) { it.toInt() }) + } + is Node.Instr.I32TruncUF64 -> next { + push(ctx.checkedTrunc(popDouble(), false) { it.toLong().toInt() }) + } + is Node.Instr.I64ExtendSI32 -> next { push(popInt().toLong()) } + is Node.Instr.I64ExtendUI32 -> next { push(popInt().toUnsignedLong()) } + is Node.Instr.I64TruncSF32 -> next { + push(ctx.checkedTrunc(popFloat(), true) { it.toLong() }) + } + is Node.Instr.I64TruncUF32 -> next { + push(ctx.checkedTrunc(popFloat(), false) { + // If over max long, subtract and negate + if (it < 9223372036854775807f) it.toLong() + else (-9223372036854775808f + (it - 9223372036854775807f)).toLong() + }) + } + is Node.Instr.I64TruncSF64 -> next { + push(ctx.checkedTrunc(popDouble(), true) { it.toLong() }) + } + is Node.Instr.I64TruncUF64 -> next { + push(ctx.checkedTrunc(popDouble(), false) { + // If over max long, subtract and negate + if (it < 9223372036854775807.0) it.toLong() + else (-9223372036854775808.0 + (it - 9223372036854775807.0)).toLong() + }) + } + is Node.Instr.F32ConvertSI32 -> next { push(popInt().toFloat()) } + is Node.Instr.F32ConvertUI32 -> next { push(popInt().toUnsignedLong().toFloat()) } + is Node.Instr.F32ConvertSI64 -> next { push(popLong().toFloat()) } + is Node.Instr.F32ConvertUI64 -> next { + push(popLong().let { if (it >= 0) it.toFloat() else (it ushr 1).toFloat() * 2f }) + } + is Node.Instr.F32DemoteF64 -> next { push(popDouble().toFloat()) } + is Node.Instr.F64ConvertSI32 -> next { push(popInt().toDouble()) } + is Node.Instr.F64ConvertUI32 -> next { push(popInt().toUnsignedLong().toDouble()) } + is Node.Instr.F64ConvertSI64 -> next { push(popLong().toDouble()) } + is Node.Instr.F64ConvertUI64 -> next { + push(popLong().let { if (it >= 0) it.toDouble() else ((it ushr 1) or (it and 1)) * 2.0 }) + } + is Node.Instr.F64PromoteF32 -> next { push(popFloat().toDouble()) } + is Node.Instr.I32ReinterpretF32 -> next { push(java.lang.Float.floatToRawIntBits(popFloat())) } + is Node.Instr.I64ReinterpretF64 -> next { push(java.lang.Double.doubleToRawLongBits(popDouble())) } + is Node.Instr.F32ReinterpretI32 -> next { push(java.lang.Float.intBitsToFloat(popInt())) } + is Node.Instr.F64ReinterpretI64 -> next { push(java.lang.Double.longBitsToDouble(popLong())) } + } + } + } + + companion object : Interpreter() + + // Creating this does all the initialization except execute the start function + data class Context( + val mod: Node.Module, + val logger: Logger = Logger.Print(Logger.Level.OFF), + val imports: Imports = Imports.None, + val defaultMaxMemPages: Int = 1, + val memByteBufferDirect: Boolean = true, + val checkTruncOverflow: Boolean = true, + val checkSignedDivIntegerOverflow: Boolean = true, + val maximumCallStackDepth: Int = 3000 + ) { + val callStack = mutableListOf() + val currFuncCtx get() = callStack.last() + val maybeCurrFuncCtx get() = callStack.lastOrNull() + + val exportsByName = mod.exports.map { it.field to it }.toMap() + fun exportIndex(field: String, kind: Node.ExternalKind) = + exportsByName[field]?.takeIf { it.kind == kind }?.index + + val importGlobals = mod.imports.filter { it.kind is Node.Import.Kind.Global } + fun singleConstant(instrs: List): Number? = instrs.singleOrNull().let { instr -> + when (instr) { + is Node.Instr.Args.Const<*> -> instr.value + is Node.Instr.GetGlobal -> importGlobals.getOrNull(instr.index).let { + it ?: throw CompileErr.UnknownGlobal(instr.index) + imports.getGlobal(it.module, it.field, (it.kind as Node.Import.Kind.Global).type) + } + else -> null + } + } + + val maybeMem = run { + // Import it if we can, otherwise make it + val memImport = mod.imports.singleOrNull { it.kind is Node.Import.Kind.Memory } + // TODO: validate imported memory + val mem = + if (memImport != null) imports.getMemory( + memImport.module, + memImport.field, + (memImport.kind as Node.Import.Kind.Memory).type + ) else mod.memories.singleOrNull()?.let { memType -> + val max = (memType.limits.maximum ?: defaultMaxMemPages) * Mem.PAGE_SIZE + val mem = if (memByteBufferDirect) ByteBuffer.allocateDirect(max) else ByteBuffer.allocate(max) + mem.apply { + order(ByteOrder.LITTLE_ENDIAN) + limit(memType.limits.initial * Mem.PAGE_SIZE) + } + } + mem?.also { mem -> + // Load all data + mod.data.forEach { data -> + val pos = singleConstant(data.offset) as? Int ?: throw CompileErr.OffsetNotConstant() + if (pos < 0 || pos + data.data.size > mem.limit()) + throw RunErr.InvalidDataIndex(pos, data.data.size, mem.limit()) + mem.duplicate().apply { position(pos) }.put(data.data) + } + } + } + val mem get() = maybeMem ?: throw CompileErr.UnknownMemory(0) + + // TODO: some of this shares with the compiler's context, so how about some code reuse? + val importFuncs = mod.imports.filter { it.kind is Node.Import.Kind.Func } + fun typeAtIndex(index: Int) = mod.types.getOrNull(index) ?: throw CompileErr.UnknownType(index) + fun funcAtIndex(index: Int) = importFuncs.getOrNull(index).let { + when (it) { + null -> Either.Right(mod.funcs.getOrNull(index - importFuncs.size) ?: throw CompileErr.UnknownFunc(index)) + else -> Either.Left(it) + } + } + fun funcTypeAtIndex(index: Int) = funcAtIndex(index).let { + when (it) { + is Either.Left -> typeAtIndex((it.v.kind as Node.Import.Kind.Func).typeIndex) + is Either.Right -> it.v.type + } + } + fun boundFuncMethodHandleAtIndex(index: Int): MethodHandle { + val type = funcTypeAtIndex(index).let { + MethodType.methodType(it.ret?.jclass ?: Void.TYPE, it.params.map { it.jclass }) + } + val origMh = MethodHandles.lookup().bind(Interpreter.Companion, "execFunc", MethodType.methodType( + Number::class.java, Context::class.java, Int::class.java, Array::class.java)) + return MethodHandles.insertArguments(origMh, 0, this, index). + asVarargsCollector(Array::class.java).asType(type) + } + + val moduleGlobals = mod.globals.mapIndexed { index, global -> + // In MVP all globals have an init, it's either a const or an import read + val initVal = singleConstant(global.init) ?: throw CompileErr.GlobalInitNotConstant(index) + if (initVal.valueType != global.type.contentType) + throw CompileErr.GlobalConstantMismatch(index, global.type.contentType.typeRef, initVal::class.ref) + initVal + }.toMutableList() + fun globalTypeAtIndex(index: Int) = + (importGlobals.getOrNull(index)?.kind as? Node.Import.Kind.Global)?.type ?: + mod.globals[index - importGlobals.size].type + fun getGlobal(index: Int): Number = importGlobals.getOrNull(index).let { importGlobal -> + if (importGlobal != null) imports.getGlobal( + importGlobal.module, + importGlobal.field, + (importGlobal.kind as Node.Import.Kind.Global).type + ) else moduleGlobals.getOrNull(index - importGlobals.size) ?: error("No global") + } + fun setGlobal(index: Int, v: Number) { + val importGlobal = importGlobals.getOrNull(index) + if (importGlobal != null) imports.setGlobal( + importGlobal.module, + importGlobal.field, + (importGlobal.kind as Node.Import.Kind.Global).type, + v + ) else (index - importGlobals.size).also { index -> + require(index < moduleGlobals.size) + moduleGlobals[index] = v + } + } + + val table = run { + val importTable = mod.imports.singleOrNull { it.kind is Node.Import.Kind.Table } + val table = (importTable?.kind as? Node.Import.Kind.Table)?.type ?: mod.tables.singleOrNull() + if (table == null && mod.elems.isNotEmpty()) throw CompileErr.UnknownTable(0) + table?.let { table -> + // Create array either cloned from import or fresh + val arr = importTable?.let { imports.getTable(it.module, it.field, table) } ?: + arrayOfNulls(table.limits.initial) + // Now put all the elements in there + mod.elems.forEach { elem -> + require(elem.index == 0) + // Offset index always a constant or import + val offsetVal = singleConstant(elem.offset) as? Int ?: throw CompileErr.OffsetNotConstant() + // Still have to validate offset even if no func indexes + if (offsetVal < 0 || offsetVal + elem.funcIndices.size > arr.size) + throw RunErr.InvalidElemIndex(offsetVal, elem.funcIndices.size, arr.size) + elem.funcIndices.forEachIndexed { index, funcIndex -> + arr[offsetVal + index] = boundFuncMethodHandleAtIndex(funcIndex) + } + } + arr + } + } + + fun checkedTrunc(orig: Float, signed: Boolean, to: (Float) -> T) = to(orig).also { + if (checkTruncOverflow) { + if (orig.isNaN()) throw InterpretErr.TruncIntegerNaN(orig, it.valueType!!, signed) + val invalid = + (it is Int && signed && (orig < -2147483648f || orig >= 2147483648f)) || + (it is Int && !signed && (orig.toInt() < 0 || orig >= 4294967296f)) || + (it is Long && signed && (orig < -9223372036854775807f || orig >= 9223372036854775807f)) || + (it is Long && !signed && (orig.toInt() < 0 || orig >= 18446744073709551616f)) + if (invalid) throw InterpretErr.TruncIntegerOverflow(orig, it.valueType!!, signed) + } + } + + fun checkedTrunc(orig: Double, signed: Boolean, to: (Double) -> T) = to(orig).also { + if (checkTruncOverflow) { + if (orig.isNaN()) throw InterpretErr.TruncIntegerNaN(orig, it.valueType!!, signed) + val invalid = + (it is Int && signed && (orig < -2147483648.0 || orig >= 2147483648.0)) || + (it is Int && !signed && (orig.toInt() < 0 || orig >= 4294967296.0)) || + (it is Long && signed && (orig < -9223372036854775807.0 || orig >= 9223372036854775807.0)) || + (it is Long && !signed && (orig.toInt() < 0 || orig >= 18446744073709551616.0)) + if (invalid) throw InterpretErr.TruncIntegerOverflow(orig, it.valueType!!, signed) + } + } + + fun checkedSignedDivInteger(a: Int, b: Int) { + if (checkSignedDivIntegerOverflow && (a == Int.MIN_VALUE && b == -1)) + throw InterpretErr.SignedDivOverflow(a, b) + } + + fun checkedSignedDivInteger(a: Long, b: Long) { + if (checkSignedDivIntegerOverflow && (a == Long.MIN_VALUE && b == -1L)) + throw InterpretErr.SignedDivOverflow(a, b) + } + + fun checkNextIsntStackOverflow() { + // TODO: note this doesn't keep count of imports and their call stack + if (callStack.size + 1 >= maximumCallStackDepth) { + // We blow away the entire stack here so code can continue...could provide stack to + // exception if we wanted + callStack.clear() + throw InterpretErr.StackOverflow(maximumCallStackDepth) + } + } + } + + data class FuncContext( + val func: Node.Func, + val valueStack: MutableList = mutableListOf(), + val blockStack: MutableList = mutableListOf(), + var insnIndex: Int = 0 + ) { + val locals = (func.type.params + func.locals).map { + when (it) { + is Node.Type.Value.I32 -> 0 as Number + is Node.Type.Value.I64 -> 0L as Number + is Node.Type.Value.F32 -> 0f as Number + is Node.Type.Value.F64 -> 0.0 as Number + } + }.toMutableList() + + fun peek() = valueStack.last() + fun pop() = valueStack.removeAt(valueStack.size - 1) + fun popInt() = pop() as Int + fun popLong() = pop() as Long + fun popFloat() = pop() as Float + fun popDouble() = pop() as Double + fun Node.Instr.Args.AlignOffset.popMemAddr(): Int { + val v = popInt() + if (offset > Int.MAX_VALUE || offset + v > Int.MAX_VALUE) throw InterpretErr.OutOfBoundsMemory(v, offset) + return v + offset.toInt() + } + fun pop(type: Node.Type.Value): Number = when (type) { + is Node.Type.Value.I32 -> popInt() + is Node.Type.Value.I64 -> popLong() + is Node.Type.Value.F32 -> popFloat() + is Node.Type.Value.F64 -> popDouble() + } + fun popCallArgs(type: Node.Type.Func) = type.params.reversed().map(::pop).reversed() + fun push(v: Number) { valueStack += v } + fun push(v: Boolean) { valueStack += if (v) 1 else 0 } + + fun currentBlockEndOrElse(end: Boolean): Int? { + // Find the next end/else + var blockDepth = 0 + val index = func.instructions.drop(insnIndex + 1).indexOfFirst { insn -> + when (insn) { is Node.Instr.Block, is Node.Instr.Loop, is Node.Instr.If -> blockDepth++ } + val found = blockDepth == 0 && ((end && insn is Node.Instr.End) || (!end && insn is Node.Instr.Else)) + if (blockDepth > 0 && insn is Node.Instr.End) blockDepth-- + found + } + return if (index == -1) null else index + insnIndex + 1 + } + fun currentBlockEnd() = currentBlockEndOrElse(true) + fun currentBlockElse() = currentBlockEndOrElse(false) + + inline fun next(crossinline f: () -> Unit) = StepResult.Next.also { f() } + inline fun nextBinOp(second: T, first: U, crossinline f: (U, T) -> Any) = StepResult.Next.also { + val v = f(first, second) + if (v is Boolean) push(v) else push(v as Number) + } + } + + data class Block( + val startIndex: Int, + val insn: Node.Instr.Args.Type, + val stackSizeAtStart: Int, + val endIndex: Int, + val elseIndex: Int? = null + ) + + sealed class StepResult { + object Next : StepResult() + data class Branch( + val blockDepth: Int, + val failedIf: Boolean = false, + val forceEndOnLoop: Boolean = false + ) : StepResult() + data class Call( + val funcIndex: Int, + val args: List, + val type: Node.Type.Func + ) : StepResult() + data class CallIndirect( + val tableIndex: Int, + val args: List, + val type: Node.Type.Func + ) : StepResult() + object Unreachable : StepResult() + data class Return(val v: Number?) : StepResult() + } +} \ No newline at end of file diff --git a/compiler/src/main/kotlin/asmble/run/jvm/interpret/RunModule.kt b/compiler/src/main/kotlin/asmble/run/jvm/interpret/RunModule.kt new file mode 100644 index 0000000..7689adc --- /dev/null +++ b/compiler/src/main/kotlin/asmble/run/jvm/interpret/RunModule.kt @@ -0,0 +1,93 @@ +package asmble.run.jvm.interpret + +import asmble.ast.Node +import asmble.compile.jvm.jclass +import asmble.run.jvm.Module +import asmble.run.jvm.ModuleBuilder +import asmble.util.Logger +import java.lang.invoke.MethodHandles +import java.lang.invoke.MethodType +import java.nio.ByteBuffer + +class RunModule( + override val name: String?, + val ctx: Interpreter.Context +) : Module { + override fun exportedFunc(field: String) = + ctx.exportIndex(field, Node.ExternalKind.FUNCTION)?.let { ctx.boundFuncMethodHandleAtIndex(it) } + + override fun exportedGlobal(field: String) = ctx.exportIndex(field, Node.ExternalKind.GLOBAL)?.let { index -> + val type = ctx.globalTypeAtIndex(index) + val lookup = MethodHandles.lookup() + var getter = lookup.bind(ctx, "getGlobal", + MethodType.methodType(Number::class.java, Int::class.javaPrimitiveType)) + var setter = if (!type.mutable) null else lookup.bind(ctx, "setGlobal", MethodType.methodType( + Void::class.javaPrimitiveType, Int::class.javaPrimitiveType, Number::class.java)) + // Cast number to specific type + getter = MethodHandles.explicitCastArguments(getter, + MethodType.methodType(type.contentType.jclass, Int::class.javaPrimitiveType)) + if (setter != null) + setter = MethodHandles.explicitCastArguments(setter, MethodType.methodType( + Void::class.javaPrimitiveType, Int::class.javaPrimitiveType, type.contentType.jclass)) + // Insert the index argument up front + getter = MethodHandles.insertArguments(getter, 0, index) + if (setter != null) setter = MethodHandles.insertArguments(setter, 0, index) + getter to setter + } + + @SuppressWarnings("UNCHECKED_CAST") + override fun exportedMemory(field: String, memClass: Class) = + ctx.exportIndex(field, Node.ExternalKind.MEMORY)?.let { index -> + require(index == 0 && memClass == ByteBuffer::class.java) + ctx.maybeMem as? T? + } + + override fun exportedTable(field: String) = + ctx.exportIndex(field, Node.ExternalKind.TABLE)?.let { index -> + require(index == 0) + ctx.table + } + + class Builder( + val logger: Logger = Logger.Print(Logger.Level.OFF), + val defaultMaxMemPages: Int = 1, + val memByteBufferDirect: Boolean = true + ) : ModuleBuilder { + override fun build( + imports: Module.ImportResolver, + mod: Node.Module, + className: String, + name: String? + ) = RunModule( + name = name, + ctx = Interpreter.Context( + mod = mod, + logger = logger, + imports = ResolverImports(imports), + defaultMaxMemPages = defaultMaxMemPages, + memByteBufferDirect = memByteBufferDirect + ).also { ctx -> + // Run start function if present + mod.startFuncIndex?.also { Interpreter.execFunc(ctx, it) } + } + ) + } + + class ResolverImports(val res: Module.ImportResolver) : Imports { + override fun invokeFunction(module: String, field: String, type: Node.Type.Func, args: List) = + res.resolveImportFunc(module, field, type).invokeWithArguments(args) as Number? + + override fun getGlobal(module: String, field: String, type: Node.Type.Global) = + res.resolveImportGlobal(module, field, type).first.invokeWithArguments() as Number + + override fun setGlobal(module: String, field: String, type: Node.Type.Global, value: Number) { + res.resolveImportGlobal(module, field, type).second!!.invokeWithArguments(value) + } + + override fun getMemory(module: String, field: String, type: Node.Type.Memory) = + res.resolveImportMemory(module, field, type, ByteBuffer::class.java) + + override fun getTable(module: String, field: String, type: Node.Type.Table) = + res.resolveImportTable(module, field, type) + } +} \ No newline at end of file diff --git a/compiler/src/main/kotlin/asmble/util/NumExt.kt b/compiler/src/main/kotlin/asmble/util/NumExt.kt index 3dcf69f..35f0442 100644 --- a/compiler/src/main/kotlin/asmble/util/NumExt.kt +++ b/compiler/src/main/kotlin/asmble/util/NumExt.kt @@ -9,6 +9,8 @@ internal val MIN_INT32 = BigInteger.valueOf(Int.MIN_VALUE.toLong()) internal val MAX_UINT64 = BigInteger("ffffffffffffffff", 16) internal val MIN_INT64 = BigInteger.valueOf(Long.MIN_VALUE) +fun Byte.toUnsignedInt() = java.lang.Byte.toUnsignedInt(this) +fun Byte.toUnsignedLong() = java.lang.Byte.toUnsignedLong(this) fun Byte.toUnsignedShort() = (this.toInt() and 0xff).toShort() fun BigInteger.unsignedToSignedLong(): Long { @@ -40,4 +42,7 @@ fun Long.unsignedToSignedInt(): Int { return this.toInt() } -fun Long.Companion.valueOf(s: String, radix: Int = 10) = java.lang.Long.valueOf(s, radix) \ No newline at end of file +fun Long.Companion.valueOf(s: String, radix: Int = 10) = java.lang.Long.valueOf(s, radix) + +fun Short.toUnsignedInt() = java.lang.Short.toUnsignedInt(this) +fun Short.toUnsignedLong() = java.lang.Short.toUnsignedLong(this) \ No newline at end of file diff --git a/compiler/src/test/kotlin/asmble/BaseTestUnit.kt b/compiler/src/test/kotlin/asmble/BaseTestUnit.kt index 2cda40c..b8c7eef 100644 --- a/compiler/src/test/kotlin/asmble/BaseTestUnit.kt +++ b/compiler/src/test/kotlin/asmble/BaseTestUnit.kt @@ -22,5 +22,4 @@ open class BaseTestUnit(val name: String, val wast: String, val expectedOutput: } open val ast: List get() = parseResult.vals open val script: Script by lazy { SExprToAst.toScript(SExpr.Multi(ast)) } - open fun warningInsteadOfErrReason(t: Throwable): String? = null } \ No newline at end of file diff --git a/compiler/src/test/kotlin/asmble/SpecTestUnit.kt b/compiler/src/test/kotlin/asmble/SpecTestUnit.kt index 62fa78a..c8a1b48 100644 --- a/compiler/src/test/kotlin/asmble/SpecTestUnit.kt +++ b/compiler/src/test/kotlin/asmble/SpecTestUnit.kt @@ -1,8 +1,5 @@ package asmble -import asmble.ast.Node -import asmble.ast.Script -import asmble.run.jvm.ScriptAssertionError import java.nio.file.FileSystems import java.nio.file.Files import java.nio.file.Paths @@ -20,46 +17,6 @@ class SpecTestUnit(name: String, wast: String, expectedOutput: String?) : BaseTe else -> 2 } - override fun warningInsteadOfErrReason(t: Throwable) = when (name) { - "binary" -> { - val expectedFailure = ((t as? ScriptAssertionError)?.assertion as? Script.Cmd.Assertion.Malformed)?.failure - // TODO: Pending answer to https://github.com/WebAssembly/spec/pull/882#issuecomment-426349365 - if (expectedFailure == "integer too large") "Binary test changed" else null - } - // NaN bit patterns can be off - "float_literals", "float_exprs", "float_misc" -> - if (isNanMismatch(t)) "NaN JVM bit patterns can be off" else null - // We don't hold table capacity right now - // TODO: Figure out how we want to store/retrieve table capacity. Right now - // a table is an array, so there is only size not capacity. Since we want to - // stay w/ the stdlib of the JVM, the best option may be to store the capacity - // as a separate int value and query it or pass it around via import as - // necessary. I guess I could use a vector, but it's not worth it just for - // capacity since you lose speed. - "imports" -> { - val isTableMaxErr = t is ScriptAssertionError && (t.assertion as? Script.Cmd.Assertion.Unlinkable).let { - it != null && it.failure == "incompatible import type" && - it.module.imports.singleOrNull()?.kind is Node.Import.Kind.Table - } - if (isTableMaxErr) "Table max capacities are not validated" else null - } - else -> null - } - - private fun isNanMismatch(t: Throwable) = t is ScriptAssertionError && ( - t.assertion is Script.Cmd.Assertion.ReturnNan || - (t.assertion is Script.Cmd.Assertion.Return && (t.assertion as Script.Cmd.Assertion.Return).let { - it.exprs.any { it.any(this::insnIsNanConst) } || - ((it.action as? Script.Cmd.Action.Invoke)?.string?.contains("nan") ?: false) - }) - ) - - private fun insnIsNanConst(i: Node.Instr) = when (i) { - is Node.Instr.F32Const -> i.value.isNaN() - is Node.Instr.F64Const -> i.value.isNaN() - else -> false - } - companion object { val unitsPath = "/spec/test/core" diff --git a/compiler/src/test/kotlin/asmble/compile/jvm/LargeDataTest.kt b/compiler/src/test/kotlin/asmble/compile/jvm/LargeDataTest.kt index b938d39..6019bdd 100644 --- a/compiler/src/test/kotlin/asmble/compile/jvm/LargeDataTest.kt +++ b/compiler/src/test/kotlin/asmble/compile/jvm/LargeDataTest.kt @@ -2,7 +2,7 @@ package asmble.compile.jvm import asmble.TestBase import asmble.ast.Node -import asmble.run.jvm.ScriptContext +import asmble.run.jvm.ModuleBuilder import asmble.util.get import org.junit.Test import java.nio.ByteBuffer @@ -32,7 +32,7 @@ class LargeDataTest : TestBase() { logger = logger ) AstToAsm.fromModule(ctx) - val cls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx) + val cls = ModuleBuilder.Compiled.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx) // Instantiate it, get the memory out, and check it val field = cls.getDeclaredField("memory").apply { isAccessible = true } val buf = field[cls.newInstance()] as ByteBuffer diff --git a/compiler/src/test/kotlin/asmble/compile/jvm/NamesTest.kt b/compiler/src/test/kotlin/asmble/compile/jvm/NamesTest.kt index 51f327e..6ac9b3c 100644 --- a/compiler/src/test/kotlin/asmble/compile/jvm/NamesTest.kt +++ b/compiler/src/test/kotlin/asmble/compile/jvm/NamesTest.kt @@ -3,7 +3,7 @@ package asmble.compile.jvm import asmble.TestBase import asmble.io.SExprToAst import asmble.io.StrToSExpr -import asmble.run.jvm.ScriptContext +import asmble.run.jvm.ModuleBuilder import org.junit.Test import java.util.* @@ -30,7 +30,7 @@ class NamesTest : TestBase() { logger = logger ) AstToAsm.fromModule(ctx) - val cls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx) + val cls = ModuleBuilder.Compiled.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx) // Make sure the import field and the func are present named cls.getDeclaredField("import_func") cls.getDeclaredMethod("some_func", Integer.TYPE) diff --git a/compiler/src/test/kotlin/asmble/run/jvm/LargeFuncTest.kt b/compiler/src/test/kotlin/asmble/run/jvm/LargeFuncTest.kt index 8555de3..87804da 100644 --- a/compiler/src/test/kotlin/asmble/run/jvm/LargeFuncTest.kt +++ b/compiler/src/test/kotlin/asmble/run/jvm/LargeFuncTest.kt @@ -51,12 +51,12 @@ class LargeFuncTest : TestBase() { AstToAsm.fromModule(ctx) // Confirm the method size is too large try { - ScriptContext.SimpleClassLoader(javaClass.classLoader, logger, splitWhenTooLarge = false). + ModuleBuilder.Compiled.SimpleClassLoader(javaClass.classLoader, logger, splitWhenTooLarge = false). fromBuiltContext(ctx) Assert.fail() } catch (e: MethodTooLargeException) { } // Try again with split - val cls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx) + val cls = ModuleBuilder.Compiled.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx) // Create it and check that it still does what we expect val inst = cls.newInstance() // Run someFunc diff --git a/compiler/src/test/kotlin/asmble/run/jvm/RunTest.kt b/compiler/src/test/kotlin/asmble/run/jvm/RunTest.kt index 4c86fac..349a7e8 100644 --- a/compiler/src/test/kotlin/asmble/run/jvm/RunTest.kt +++ b/compiler/src/test/kotlin/asmble/run/jvm/RunTest.kt @@ -1,11 +1,39 @@ package asmble.run.jvm import asmble.SpecTestUnit +import asmble.annotation.WasmModule +import asmble.io.AstToBinary +import asmble.io.ByteWriter import org.junit.runner.RunWith import org.junit.runners.Parameterized +import java.io.ByteArrayOutputStream +import kotlin.test.assertEquals @RunWith(Parameterized::class) class RunTest(unit: SpecTestUnit) : TestRunner(unit) { + + override val builder get() = ModuleBuilder.Compiled( + packageName = unit.packageName, + logger = this, + adjustContext = { it.copy(eagerFailLargeMemOffset = false) }, + // Include the binary data so we can check it later + includeBinaryInCompiledClass = true, + defaultMaxMemPages = unit.defaultMaxMemPages + ) + + override fun run() = super.run().also { scriptContext -> + // Check annotations + scriptContext.modules.forEach { mod -> + mod as Module.Compiled + val expectedBinaryString = ByteArrayOutputStream().also { + ByteWriter.OutputStream(it).also { AstToBinary.fromModule(it, mod.mod) } + }.toByteArray().toString(Charsets.ISO_8859_1) + val actualBinaryString = + mod.cls.getDeclaredAnnotation(WasmModule::class.java)?.binary ?: error("No annotation") + assertEquals(expectedBinaryString, actualBinaryString) + } + } + companion object { @JvmStatic @Parameterized.Parameters(name = "{0}") fun data() = SpecTestUnit.allUnits diff --git a/compiler/src/test/kotlin/asmble/run/jvm/TestRunner.kt b/compiler/src/test/kotlin/asmble/run/jvm/TestRunner.kt index c35bcec..3aebe75 100644 --- a/compiler/src/test/kotlin/asmble/run/jvm/TestRunner.kt +++ b/compiler/src/test/kotlin/asmble/run/jvm/TestRunner.kt @@ -2,10 +2,9 @@ package asmble.run.jvm import asmble.BaseTestUnit import asmble.TestBase -import asmble.annotation.WasmModule -import asmble.io.AstToBinary +import asmble.ast.Node +import asmble.ast.Script import asmble.io.AstToSExpr -import asmble.io.ByteWriter import asmble.io.SExprToStr import org.junit.Assume import org.junit.Test @@ -27,7 +26,9 @@ abstract class TestRunner(val unit: T) : TestBase() { } else if (ex != null) throw ex } - private fun run() { + abstract val builder: ModuleBuilder<*> + + open fun run(): ScriptContext { debug { "AST SExpr: " + unit.ast } debug { "AST Str: " + SExprToStr.fromSExpr(*unit.ast.toTypedArray()) } debug { "AST: " + unit.script } @@ -35,12 +36,8 @@ abstract class TestRunner(val unit: T) : TestBase() { val out = ByteArrayOutputStream() var scriptContext = ScriptContext( - packageName = unit.packageName, logger = this, - adjustContext = { it.copy(eagerFailLargeMemOffset = false) }, - defaultMaxMemPages = unit.defaultMaxMemPages, - // Include the binary data so we can check it later - includeBinaryInCompiledClass = true + builder = builder ).withHarnessRegistered(PrintWriter(OutputStreamWriter(out, Charsets.UTF_8), true)) // This will fail assertions as necessary @@ -48,7 +45,7 @@ abstract class TestRunner(val unit: T) : TestBase() { try { scriptContext.runCommand(cmd) } catch (t: Throwable) { - val warningReason = unit.warningInsteadOfErrReason(t) ?: throw t + val warningReason = warningInsteadOfErrReason(t) ?: throw t warn { "Unexpected error on ${unit.name}, but is a warning. Reason: $warningReason. Orig err: $t" } scriptContext } @@ -60,14 +57,47 @@ abstract class TestRunner(val unit: T) : TestBase() { assertEquals(it.trimEnd(), out.toByteArray().toString(Charsets.UTF_8).trimEnd()) } - // Also check the annotations - scriptContext.modules.forEach { mod -> - val expectedBinaryString = ByteArrayOutputStream().also { - ByteWriter.OutputStream(it).also { AstToBinary.fromModule(it, mod.mod) } - }.toByteArray().toString(Charsets.ISO_8859_1) - val actualBinaryString = - mod.cls.getDeclaredAnnotation(WasmModule::class.java)?.binary ?: error("No annotation") - assertEquals(expectedBinaryString, actualBinaryString) + return scriptContext + } + + // TODO: move this into the script context for specific assertions so the rest can continue running + open fun warningInsteadOfErrReason(t: Throwable): String? = when (unit.name) { + "binary" -> { + val expectedFailure = ((t as? ScriptAssertionError)?.assertion as? Script.Cmd.Assertion.Malformed)?.failure + // TODO: Pending answer to https://github.com/WebAssembly/spec/pull/882#issuecomment-426349365 + if (expectedFailure == "integer too large") "Binary test changed" else null + } + // NaN bit patterns can be off + "float_literals", "float_exprs", "float_misc", "f32_bitwise" -> + if (isNanMismatch(t)) "NaN JVM bit patterns can be off" else null + // We don't hold table capacity right now + // TODO: Figure out how we want to store/retrieve table capacity. Right now + // a table is an array, so there is only size not capacity. Since we want to + // stay w/ the stdlib of the JVM, the best option may be to store the capacity + // as a separate int value and query it or pass it around via import as + // necessary. I guess I could use a vector, but it's not worth it just for + // capacity since you lose speed. + "imports" -> { + val isTableMaxErr = t is ScriptAssertionError && (t.assertion as? Script.Cmd.Assertion.Unlinkable).let { + it != null && it.failure == "incompatible import type" && + it.module.imports.singleOrNull()?.kind is Node.Import.Kind.Table + } + if (isTableMaxErr) "Table max capacities are not validated" else null } + else -> null + } + + private fun isNanMismatch(t: Throwable) = t is ScriptAssertionError && ( + t.assertion is Script.Cmd.Assertion.ReturnNan || + (t.assertion is Script.Cmd.Assertion.Return && (t.assertion as Script.Cmd.Assertion.Return).let { + it.exprs.any { it.any(this::insnIsNanConst) } || + ((it.action as? Script.Cmd.Action.Invoke)?.string?.contains("nan") ?: false) + }) + ) + + private fun insnIsNanConst(i: Node.Instr) = when (i) { + is Node.Instr.F32Const -> i.value.isNaN() + is Node.Instr.F64Const -> i.value.isNaN() + else -> false } } \ No newline at end of file diff --git a/compiler/src/test/kotlin/asmble/run/jvm/interpret/InterpretTest.kt b/compiler/src/test/kotlin/asmble/run/jvm/interpret/InterpretTest.kt new file mode 100644 index 0000000..078deae --- /dev/null +++ b/compiler/src/test/kotlin/asmble/run/jvm/interpret/InterpretTest.kt @@ -0,0 +1,60 @@ +package asmble.run.jvm.interpret + +import asmble.SpecTestUnit +import asmble.ast.Script +import asmble.run.jvm.ScriptAssertionError +import asmble.run.jvm.TestRunner +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +@RunWith(Parameterized::class) +class InterpretTest(unit: SpecTestUnit) : TestRunner(unit) { + + override val builder get() = RunModule.Builder( + logger = logger, + defaultMaxMemPages = unit.defaultMaxMemPages + ) + + // Some things require compilation of code, something the interpreter doesn't do until execution time + override fun warningInsteadOfErrReason(t: Throwable) = super.warningInsteadOfErrReason(t) ?: run { + // Interpreter doesn't eagerly validate + if ((t as? ScriptAssertionError)?.assertion is Script.Cmd.Assertion.Invalid) + return "Interpreter doesn't eagerly validate" + // Other units specifically... + when (unit.name) { + "if" -> { + // Couple of tests expect mismatching label to be caught at compilation without execution + val compileErr = ((t as? ScriptAssertionError)?. + assertion as? Script.Cmd.Assertion.Malformed)?.failure == "mismatching label" + if (compileErr) "Interpreter doesn't check unexecuted code" else null + } + "imports", "linking" -> (t as? ScriptAssertionError)?.assertion?.let { assertion -> + when (assertion) { + // Couple of tests expect imports to be checked without attempted resolution + is Script.Cmd.Assertion.Unlinkable -> when (assertion.failure) { + "unknown import", "incompatible import type" -> + "Interpreter doesn't check unexecuted code" + else -> null + } + // There is a test that expects none of the table elems to be set if there is any module link + // failure previously. The problem is the interpreter runs until it can't and is non-atomic + is Script.Cmd.Assertion.Trap -> + if (assertion.failure == "uninitialized") "Interpreter doesn't initialize atomically" + else null + // Same here, expects previous data to have never been placed + is Script.Cmd.Assertion.Return -> (assertion.action as? Script.Cmd.Action.Invoke)?.let { invoke -> + if (invoke.name == "Mm" && invoke.string == "load") "Interpreter doesn't initialize atomically" + else null + } + else -> null + } + } + else -> null + } + } + + companion object { + @JvmStatic @Parameterized.Parameters(name = "{0}") + fun data() = SpecTestUnit.allUnits + } +}