Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support wasm globals #650

Merged
merged 1 commit into from
Jul 21, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/ast_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer
void visitSetLocal(SetLocal *curr) {
localsWritten.insert(curr->index);
}
void visitGetGlobal(GetGlobal *curr) { readsMemory = true; } // TODO: global-specific
void visitSetGlobal(SetGlobal *curr) { writesMemory = true; } // stuff?
void visitLoad(Load *curr) { readsMemory = true; }
void visitStore(Store *curr) { writesMemory = true; }
void visitReturn(Return *curr) { branches = true; }
Expand Down Expand Up @@ -277,6 +279,12 @@ struct ExpressionManipulator {
Expression* visitSetLocal(SetLocal *curr) {
return builder.makeSetLocal(curr->index, copy(curr->value));
}
Expression* visitGetGlobal(GetGlobal *curr) {
return builder.makeGetGlobal(curr->index, curr->type);
}
Expression* visitSetGlobal(SetGlobal *curr) {
return builder.makeSetGlobal(curr->index, copy(curr->value));
}
Expression* visitLoad(Load *curr) {
return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type);
}
Expand Down Expand Up @@ -476,6 +484,15 @@ struct ExpressionAnalyzer {
PUSH(SetLocal, value);
break;
}
case Expression::Id::GetGlobalId: {
CHECK(GetGlobal, index);
break;
}
case Expression::Id::SetGlobalId: {
CHECK(SetGlobal, index);
PUSH(SetGlobal, value);
break;
}
case Expression::Id::LoadId: {
CHECK(Load, bytes);
CHECK(Load, signed_);
Expand Down Expand Up @@ -678,6 +695,15 @@ struct ExpressionAnalyzer {
PUSH(SetLocal, value);
break;
}
case Expression::Id::GetGlobalId: {
HASH(GetGlobal, index);
break;
}
case Expression::Id::SetGlobalId: {
HASH(SetGlobal, index);
PUSH(SetGlobal, value);
break;
}
case Expression::Id::LoadId: {
HASH(Load, bytes);
HASH(Load, signed_);
Expand Down
6 changes: 6 additions & 0 deletions src/passes/Precompute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ class StandaloneExpressionRunner : public ExpressionRunner<StandaloneExpressionR
Flow visitSetLocal(SetLocal *curr) {
return Flow(NONSTANDALONE);
}
Flow visitGetGlobal(GetGlobal *curr) {
return Flow(NONSTANDALONE);
}
Flow visitSetGlobal(SetGlobal *curr) {
return Flow(NONSTANDALONE);
}
Flow visitLoad(Load *curr) {
return Flow(NONSTANDALONE);
}
Expand Down
27 changes: 27 additions & 0 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
bool fullAST = false; // whether to not elide nodes in output when possible
// (like implicit blocks)

Module* currModule = nullptr;
Function* currFunction = nullptr;

PrintSExpression(std::ostream& o) : o(o) {
Expand Down Expand Up @@ -78,6 +79,10 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
return name;
}

Name printableGlobal(Index index) {
return currModule->getGlobal(index)->name;
}

std::ostream& printName(Name name) {
// we need to quote names if they have tricky chars
if (strpbrk(name.str, "()")) {
Expand Down Expand Up @@ -239,6 +244,15 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
printFullLine(curr->value);
decIndent();
}
void visitGetGlobal(GetGlobal *curr) {
printOpening(o, "get_global ") << printableGlobal(curr->index) << ')';
}
void visitSetGlobal(SetGlobal *curr) {
printOpening(o, "set_global ") << printableGlobal(curr->index);
incIndent();
printFullLine(curr->value);
decIndent();
}
void visitLoad(Load *curr) {
o << '(';
prepareColor(o) << printWasmType(curr->type) << ".load";
Expand Down Expand Up @@ -519,6 +533,12 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
printText(o, curr->name.str) << ' ';
printName(curr->value) << ')';
}
void visitGlobal(Global *curr) {
printOpening(o, "global ");
printName(curr->name) << ' ' << printWasmType(curr->type);
printFullLine(curr->init);
o << ')';
}
void visitFunction(Function *curr) {
currFunction = curr;
printOpening(o, "func ", true);
Expand Down Expand Up @@ -563,6 +583,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
o << ')';
}
void visitModule(Module *curr) {
currModule = curr;
printOpening(o, "module", true);
incIndent();
doIndent(o, indent);
Expand Down Expand Up @@ -621,6 +642,11 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
visitExport(child.get());
o << maybeNewLine;
}
for (auto& child : curr->globals) {
doIndent(o, indent);
visitGlobal(child.get());
o << maybeNewLine;
}
if (curr->table.names.size() > 0) {
doIndent(o, indent);
visitTable(&curr->table);
Expand All @@ -633,6 +659,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
}
decIndent();
o << maybeNewLine;
currModule = nullptr;
}
};

Expand Down
13 changes: 13 additions & 0 deletions src/wasm-builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,19 @@ class Builder {
ret->type = value->type;
return ret;
}
GetGlobal* makeGetGlobal(Index index, WasmType type) {
auto* ret = allocator.alloc<GetGlobal>();
ret->index = index;
ret->type = type;
return ret;
}
SetGlobal* makeSetGlobal(Index index, Expression* value) {
auto* ret = allocator.alloc<SetGlobal>();
ret->index = index;
ret->value = value;
ret->type = value->type;
return ret;
}
Load* makeLoad(unsigned bytes, bool signed_, uint32_t offset, unsigned align, Expression *ptr, WasmType type) {
auto* ret = allocator.alloc<Load>();
ret->bytes = bytes; ret->signed_ = signed_; ret->offset = offset; ret->align = align; ret->ptr = ptr;
Expand Down
42 changes: 42 additions & 0 deletions src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,22 @@ class ExpressionRunner : public Visitor<SubType, Flow> {
}
};

// Execute an expression in global init
class GlobalInitRunner : public ExpressionRunner<GlobalInitRunner> {
public:
Flow visitLoop(Loop* curr) { WASM_UNREACHABLE(); }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably result in some kind of validation error, rather than WASM_UNREACHABLE (since otherwise well-formed input could reach this code?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is reached when we execute the init statement for the global, which is when we instantiate the module, which is after we fully parse and validate it. That means we can assume the expression is valid, in which case none of these unreachables would be hit. Assuming we don't have a bug in our validation ;)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it; but I don't see code for that in wasm-validator.h?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look at https://github.com/WebAssembly/binaryen/pull/650/files/0dc286f78cd3d5d1f5fd1cb21c16485cd9f2fd1f#diff-33a3856e71772f6165f55c85cb076496R264 , it's actually a stronger check than needed - it only accepts Const (because that's all that is tested currently, and I'm not sure what else will be allowed).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, got it.

Flow visitCall(Call* curr) { WASM_UNREACHABLE(); }
Flow visitCallImport(CallImport* curr) { WASM_UNREACHABLE(); }
Flow visitCallIndirect(CallIndirect* curr) { WASM_UNREACHABLE(); }
Flow visitGetLocal(GetLocal *curr) { WASM_UNREACHABLE(); }
Flow visitSetLocal(SetLocal *curr) { WASM_UNREACHABLE(); }
Flow visitGetGlobal(GetGlobal *curr) { WASM_UNREACHABLE(); }
Flow visitSetGlobal(SetGlobal *curr) { WASM_UNREACHABLE(); }
Flow visitLoad(Load *curr) { WASM_UNREACHABLE(); }
Flow visitStore(Store *curr) { WASM_UNREACHABLE(); }
Flow visitHost(Host *curr) { WASM_UNREACHABLE(); }
};

//
// An instance of a WebAssembly module, which can execute it via AST interpretation.
//
Expand Down Expand Up @@ -519,8 +535,14 @@ class ModuleInstance {

Module& wasm;

// Values of globals
std::vector<Literal> globals;

ModuleInstance(Module& wasm, ExternalInterface* externalInterface) : wasm(wasm), externalInterface(externalInterface) {
memorySize = wasm.memory.initial;
for (Index i = 0; i < wasm.globals.size(); i++) {
globals.push_back(GlobalInitRunner().visit(wasm.globals[i]->init).value);
}
externalInterface->init(wasm);
if (wasm.start.is()) {
LiteralList arguments;
Expand Down Expand Up @@ -676,6 +698,26 @@ class ModuleInstance {
scope.locals[index] = flow.value;
return flow;
}

Flow visitGetGlobal(GetGlobal *curr) {
NOTE_ENTER("GetGlobal");
auto index = curr->index;
NOTE_EVAL1(index);
NOTE_EVAL1(instance.globals[index]);
return instance.globals[index];
}
Flow visitSetGlobal(SetGlobal *curr) {
NOTE_ENTER("SetGlobal");
auto index = curr->index;
Flow flow = visit(curr->value);
if (flow.breaking()) return flow;
NOTE_EVAL1(index);
NOTE_EVAL1(flow.value);
assert(flow.value.type == curr->type);
instance.globals[index] = flow.value;
return flow;
}

Flow visitLoad(Load *curr) {
NOTE_ENTER("Load");
Flow flow = visit(curr->ptr);
Expand Down
59 changes: 56 additions & 3 deletions src/wasm-s-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,12 @@ class SExpressionWasmBuilder {
std::vector<Name> functionNames;
int functionCounter;
int importCounter;
int globalCounter;
std::map<Name, WasmType> functionTypes; // we need to know function return types before we parse their contents

public:
// Assumes control of and modifies the input.
SExpressionWasmBuilder(Module& wasm, Element& module) : wasm(wasm), allocator(wasm.allocator), importCounter(0) {
SExpressionWasmBuilder(Module& wasm, Element& module) : wasm(wasm), allocator(wasm.allocator), importCounter(0), globalCounter(0) {
assert(module[0]->str() == MODULE);
if (module.size() > 1 && module[1]->isStr()) {
// these s-expressions contain a binary module, actually
Expand Down Expand Up @@ -340,6 +341,7 @@ class SExpressionWasmBuilder {
if (id == MEMORY) return parseMemory(curr);
if (id == EXPORT) return parseExport(curr);
if (id == IMPORT) return; // already done
if (id == GLOBAL) return parseGlobal(curr);
if (id == TABLE) return parseTable(curr);
if (id == TYPE) return; // already done
std::cerr << "bad module element " << id.str << '\n';
Expand Down Expand Up @@ -703,7 +705,10 @@ class SExpressionWasmBuilder {
abort_on(str);
}
case 'g': {
if (str[1] == 'e') return makeGetLocal(s);
if (str[1] == 'e') {
if (str[4] == 'l') return makeGetLocal(s);
if (str[4] == 'g') return makeGetGlobal(s);
}
if (str[1] == 'r') return makeHost(s, HostOp::GrowMemory);
abort_on(str);
}
Expand All @@ -728,7 +733,10 @@ class SExpressionWasmBuilder {
abort_on(str);
}
case 's': {
if (str[1] == 'e' && str[2] == 't') return makeSetLocal(s);
if (str[1] == 'e' && str[2] == 't') {
if (str[4] == 'l') return makeSetLocal(s);
if (str[4] == 'g') return makeSetGlobal(s);
}
if (str[1] == 'e' && str[2] == 'l') return makeSelect(s);
abort_on(str);
}
Expand Down Expand Up @@ -844,6 +852,7 @@ class SExpressionWasmBuilder {
}

Index getLocalIndex(Element& s) {
if (!currFunction) throw ParseException("local access in non-function scope", s.line, s.col);
if (s.dollared()) {
auto ret = s.str();
if (currFunction->localIndices.count(ret) == 0) throw ParseException("bad local name", s.line, s.col);
Expand All @@ -870,6 +879,35 @@ class SExpressionWasmBuilder {
return ret;
}

Index getGlobalIndex(Element& s) {
if (s.dollared()) {
auto name = s.str();
for (Index i = 0; i < wasm.globals.size(); i++) {
if (wasm.globals[i]->name == name) return i;
}
throw ParseException("bad global name", s.line, s.col);
}
// this is a numeric index
Index ret = atoi(s.c_str());
if (!wasm.checkGlobal(ret)) throw ParseException("bad global index", s.line, s.col);
return ret;
}

Expression* makeGetGlobal(Element& s) {
auto ret = allocator.alloc<GetGlobal>();
ret->index = getGlobalIndex(*s[1]);
ret->type = wasm.getGlobal(ret->index)->type;
return ret;
}

Expression* makeSetGlobal(Element& s) {
auto ret = allocator.alloc<SetGlobal>();
ret->index = getGlobalIndex(*s[1]);
ret->value = parseExpression(s[2]);
ret->type = wasm.getGlobal(ret->index)->type;
return ret;
}

Expression* makeBlock(Element& s) {
// special-case Block, because Block nesting (in their first element) can be incredibly deep
auto curr = allocator.alloc<Block>();
Expand Down Expand Up @@ -1315,6 +1353,21 @@ class SExpressionWasmBuilder {
wasm.addImport(im.release());
}

void parseGlobal(Element& s) {
std::unique_ptr<Global> global = make_unique<Global>();
size_t i = 1;
if (s.size() == 4) {
global->name = s[i++]->str();
} else {
global->name = Name::fromInt(globalCounter);
}
globalCounter++;
global->type = stringToWasmType(s[i++]->str());
global->init = parseExpression(s[i++]);
assert(i == s.size());
wasm.addGlobal(global.release());
}

void parseTable(Element& s) {
for (size_t i = 1; i < s.size(); i++) {
wasm.table.names.push_back(getFunctionName(*s[i]));
Expand Down
Loading