Skip to content

Commit 35ba23c

Browse files
Compute effects for indirect calls in GlobalEffects (#8609)
When running in --closed-world, compute effects for indirect calls by unioning the effects of all potential functions of that type. In --closed-world, we assume that all references originate in our module, so the only possible functions that we don't know about are imports. Previously [we gave up on effects analysis](https://github.com/WebAssembly/binaryen/blob/29b2d42e8a748fbe1095696d58a52b7bf83e2253/src/passes/GlobalEffects.cpp#L83-L87) for indirect calls. Yields a very small byte count reduction in calcworker (3799354 - 3799297 = 57 bytes). Also shows no significant difference in Binaryen runtime: (0.1346069 -> 0.13375045 = <1% improvement, probably within noise). We expect more benefits after we're able to share indirect call effects with other passes, since currently they're only seen one layer up for callers of functions that indirectly call functions (see the newly-added tests for examples). Followups: * Share effect information per type with other passes besides just via Function::effects (#8625) * Exclude functions that don't have an address (i.e. functions that aren't the target of ref.func) from effect analysis () * Compute effects more precisely for exact + nullable/non-nullable references Part of #8615.
1 parent f61c445 commit 35ba23c

8 files changed

Lines changed: 994 additions & 59 deletions

src/passes/GlobalEffects.cpp

Lines changed: 130 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@
1919
// PassOptions structure; see more details there.
2020
//
2121

22+
#include <ranges>
23+
2224
#include "ir/effects.h"
2325
#include "ir/module-utils.h"
2426
#include "pass.h"
27+
#include "support/graph_traversal.h"
2528
#include "support/strongly_connected_components.h"
2629
#include "wasm.h"
2730

@@ -39,6 +42,9 @@ struct FuncInfo {
3942

4043
// Directly-called functions from this function.
4144
std::unordered_set<Name> calledFunctions;
45+
46+
// Types that are targets of indirect calls.
47+
std::unordered_set<HeapType> indirectCalledTypes;
4248
};
4349

4450
std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
@@ -83,11 +89,22 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8389
if (auto* call = curr->dynCast<Call>()) {
8490
// Note the direct call.
8591
funcInfo.calledFunctions.insert(call->target);
92+
} else if (effects.calls && options.closedWorld) {
93+
HeapType type;
94+
if (auto* callRef = curr->dynCast<CallRef>()) {
95+
// call_ref on unreachable does not have a call effect,
96+
// so this must be a HeapType.
97+
type = callRef->target->type.getHeapType();
98+
} else if (auto* callIndirect = curr->dynCast<CallIndirect>()) {
99+
type = callIndirect->heapType;
100+
} else {
101+
funcInfo.effects = UnknownEffects;
102+
return;
103+
}
104+
105+
funcInfo.indirectCalledTypes.insert(type);
86106
} else if (effects.calls) {
87-
// This is an indirect call of some sort, so we must assume the
88-
// worst. To do so, clear the effects, which indicates nothing
89-
// is known (so anything is possible).
90-
// TODO: We could group effects by function type etc.
107+
assert(!options.closedWorld);
91108
funcInfo.effects = UnknownEffects;
92109
} else {
93110
// No call here, but update throwing if we see it. (Only do so,
@@ -107,22 +124,84 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107124
return std::move(analysis.map);
108125
}
109126

110-
using CallGraph = std::unordered_map<Function*, std::unordered_set<Function*>>;
127+
using CallGraphNode = std::variant<Function*, HeapType>;
128+
129+
// Call graph for indirect and direct calls.
130+
//
131+
// key (caller) -> value (callee)
132+
// Function -> Function : direct call
133+
// Function -> HeapType : indirect call to the given HeapType
134+
// HeapType -> Function : The function `callee` has the type `caller`. The
135+
// HeapType may essentially 'call' any of its
136+
// potential implementations.
137+
// HeapType -> HeapType : `callee` is a subtype of `caller`. A call_ref
138+
// could target any subtype of the ref, so we need to
139+
// aggregate effects of subtypes of the target type.
140+
//
141+
// If we're running in an open world, we only include Function -> Function
142+
// edges, and don't compute effects for indirect calls, conservatively assuming
143+
// the worst.
144+
using CallGraph =
145+
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>;
111146

112147
CallGraph buildCallGraph(const Module& module,
113-
const std::map<Function*, FuncInfo>& funcInfos) {
148+
const std::map<Function*, FuncInfo>& funcInfos,
149+
bool closedWorld) {
114150
CallGraph callGraph;
115-
for (const auto& [func, info] : funcInfos) {
116-
if (info.calledFunctions.empty()) {
117-
continue;
151+
if (!closedWorld) {
152+
for (const auto& [caller, callerInfo] : funcInfos) {
153+
auto& callees = callGraph[caller];
154+
155+
// Function -> Function
156+
for (Name calleeFunction : callerInfo.calledFunctions) {
157+
callees.insert(module.getFunction(calleeFunction));
158+
}
159+
}
160+
161+
return callGraph;
162+
}
163+
164+
std::unordered_set<HeapType> allFunctionTypes;
165+
for (const auto& [caller, callerInfo] : funcInfos) {
166+
auto& callees = callGraph[caller];
167+
168+
// Function -> Function
169+
for (Name calleeFunction : callerInfo.calledFunctions) {
170+
callees.insert(module.getFunction(calleeFunction));
118171
}
119172

120-
auto& callees = callGraph[func];
121-
for (Name callee : info.calledFunctions) {
122-
callees.insert(module.getFunction(callee));
173+
// Function -> Type
174+
allFunctionTypes.insert(caller->type.getHeapType());
175+
for (HeapType calleeType : callerInfo.indirectCalledTypes) {
176+
callees.insert(calleeType);
177+
178+
// Add the key to ensure the lookup doesn't fail for indirect calls to
179+
// uninhabited types.
180+
callGraph[calleeType];
123181
}
182+
183+
// Type -> Function
184+
callGraph[caller->type.getHeapType()].insert(caller);
124185
}
125186

187+
// Type -> Type
188+
// Do a DFS up the type heirarchy for all function implementations.
189+
// We are essentially walking up each supertype chain and adding edges from
190+
// super -> subtype, but doing it via DFS to avoid repeated work.
191+
Graph superTypeGraph(allFunctionTypes.begin(),
192+
allFunctionTypes.end(),
193+
[&callGraph](auto&& push, HeapType t) {
194+
// Not needed except that during lookup we expect the
195+
// key to exist.
196+
callGraph[t];
197+
198+
if (auto super = t.getDeclaredSuperType()) {
199+
callGraph[*super].insert(t);
200+
push(*super);
201+
}
202+
});
203+
(void)superTypeGraph.traverseDepthFirst();
204+
126205
return callGraph;
127206
}
128207

@@ -152,63 +231,60 @@ void propagateEffects(const Module& module,
152231
const PassOptions& passOptions,
153232
std::map<Function*, FuncInfo>& funcInfos,
154233
const CallGraph& callGraph) {
234+
// We only care about Functions that are roots, not types.
235+
// A type would be a root if a function exists with that type, but no-one
236+
// indirect calls the type.
237+
auto funcNodes = std::views::keys(callGraph) |
238+
std::views::filter([](auto node) {
239+
return std::holds_alternative<Function*>(node);
240+
}) |
241+
std::views::common;
242+
using funcNodesType = decltype(funcNodes);
243+
155244
struct CallGraphSCCs
156-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs> {
245+
: SCCs<std::ranges::iterator_t<funcNodesType>, CallGraphSCCs> {
246+
157247
const std::map<Function*, FuncInfo>& funcInfos;
158-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
159-
callGraph;
248+
const CallGraph& callGraph;
160249
const Module& module;
161250

162-
CallGraphSCCs(
163-
const std::vector<Function*>& funcs,
164-
const std::map<Function*, FuncInfo>& funcInfos,
165-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
166-
callGraph,
167-
const Module& module)
168-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs>(
169-
funcs.begin(), funcs.end()),
251+
CallGraphSCCs(funcNodesType&& nodes,
252+
const std::map<Function*, FuncInfo>& funcInfos,
253+
const CallGraph& callGraph,
254+
const Module& module)
255+
: SCCs<std::ranges::iterator_t<funcNodesType>, CallGraphSCCs>(
256+
std::ranges::begin(nodes), std::ranges::end(nodes)),
170257
funcInfos(funcInfos), callGraph(callGraph), module(module) {}
171258

172-
void pushChildren(Function* f) {
173-
auto callees = callGraph.find(f);
174-
if (callees == callGraph.end()) {
175-
return;
176-
}
177-
178-
for (auto* callee : callees->second) {
259+
void pushChildren(CallGraphNode node) {
260+
for (CallGraphNode callee : callGraph.at(node)) {
179261
push(callee);
180262
}
181263
}
182264
};
183-
184-
std::vector<Function*> allFuncs;
185-
for (auto& [func, info] : funcInfos) {
186-
allFuncs.push_back(func);
187-
}
188-
CallGraphSCCs sccs(allFuncs, funcInfos, callGraph, module);
265+
CallGraphSCCs sccs(std::move(funcNodes), funcInfos, callGraph, module);
189266

190267
std::vector<std::optional<EffectAnalyzer>> componentEffects;
191268
// Points to an index in componentEffects
192-
std::unordered_map<Function*, Index> funcComponents;
269+
std::unordered_map<CallGraphNode, Index> nodeComponents;
193270

194271
for (auto ccIterator : sccs) {
195272
std::optional<EffectAnalyzer>& ccEffects =
196273
componentEffects.emplace_back(std::in_place, passOptions, module);
274+
std::vector<CallGraphNode> cc(ccIterator.begin(), ccIterator.end());
197275

198-
std::vector<Function*> ccFuncs(ccIterator.begin(), ccIterator.end());
199-
200-
for (Function* f : ccFuncs) {
201-
funcComponents.emplace(f, componentEffects.size() - 1);
276+
std::vector<Function*> ccFuncs;
277+
for (CallGraphNode node : cc) {
278+
nodeComponents.emplace(node, componentEffects.size() - 1);
279+
if (auto** func = std::get_if<Function*>(&node)) {
280+
ccFuncs.push_back(*func);
281+
}
202282
}
203283

204284
std::unordered_set<int> calleeSccs;
205-
for (Function* caller : ccFuncs) {
206-
auto callees = callGraph.find(caller);
207-
if (callees == callGraph.end()) {
208-
continue;
209-
}
210-
for (auto* callee : callees->second) {
211-
calleeSccs.insert(funcComponents.at(callee));
285+
for (CallGraphNode caller : cc) {
286+
for (CallGraphNode callee : callGraph.at(caller)) {
287+
calleeSccs.insert(nodeComponents.at(callee));
212288
}
213289
}
214290

@@ -219,11 +295,13 @@ void propagateEffects(const Module& module,
219295
}
220296

221297
// Add trap effects for potential cycles.
222-
if (ccFuncs.size() > 1) {
298+
if (cc.size() > 1) {
223299
if (ccEffects != UnknownEffects) {
224300
ccEffects->trap = true;
225301
}
226-
} else {
302+
} else if (ccFuncs.size() == 1) {
303+
// It's possible for a CC to only contain 1 type, but that is not a
304+
// cycle in the call graph.
227305
auto* func = ccFuncs[0];
228306
if (funcInfos.at(func).calledFunctions.contains(func->name)) {
229307
if (ccEffects != UnknownEffects) {
@@ -267,7 +345,8 @@ struct GenerateGlobalEffects : public Pass {
267345
std::map<Function*, FuncInfo> funcInfos =
268346
analyzeFuncs(*module, getPassOptions());
269347

270-
auto callGraph = buildCallGraph(*module, funcInfos);
348+
auto callGraph =
349+
buildCallGraph(*module, funcInfos, getPassOptions().closedWorld);
271350

272351
propagateEffects(*module, getPassOptions(), funcInfos, callGraph);
273352

src/support/graph_traversal.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright 2026 WebAssembly Community Group participants
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <concepts>
18+
#include <functional>
19+
#include <iterator>
20+
#include <unordered_set>
21+
22+
namespace wasm {
23+
24+
// SuccessorFunction should be an invocable that takes a 'push' function (which
25+
// is an invocable that takes a `const T&`), and a `const T&`. i.e.
26+
// SuccessorFunction should call `push` for each neighbor of the T that it's
27+
// called with.
28+
// TODO: We don't have a good way to write this with concepts today.
29+
// Something like this should do it, but we hit an ICE on dwarf symbols in debug
30+
// builds: requires requires(const SuccessorFunction& successors, const T& t) {
31+
// successors([](const T&) { }, t); }
32+
template<typename T, typename SuccessorFunction> class Graph {
33+
public:
34+
template<std::input_iterator It, std::sentinel_for<It> Sen>
35+
requires std::convertible_to<std::iter_reference_t<It>, T>
36+
Graph(It rootsBegin, Sen rootsEnd, SuccessorFunction successors)
37+
: roots(rootsBegin, rootsEnd), successors(std::move(successors)) {}
38+
39+
// Traverse the graph depth-first, calling `successors` exactly once for each
40+
// node (unless the node appears multiple times in `roots`). Return the set of
41+
// nodes visited.
42+
std::unordered_set<T> traverseDepthFirst() const {
43+
std::vector<T> stack(roots.begin(), roots.end());
44+
std::unordered_set<T> visited(roots.begin(), roots.end());
45+
46+
auto maybePush = [&](const T& t) {
47+
auto [_, inserted] = visited.insert(t);
48+
if (inserted) {
49+
stack.push_back(t);
50+
}
51+
};
52+
53+
while (!stack.empty()) {
54+
auto curr = std::move(stack.back());
55+
stack.pop_back();
56+
57+
successors(maybePush, curr);
58+
}
59+
60+
return visited;
61+
}
62+
63+
private:
64+
std::vector<T> roots;
65+
SuccessorFunction successors;
66+
};
67+
68+
template<std::input_iterator It,
69+
std::sentinel_for<It> Sen,
70+
typename SuccessorFunction>
71+
Graph(It, Sen, SuccessorFunction)
72+
-> Graph<std::iter_value_t<It>, std::decay_t<SuccessorFunction>>;
73+
74+
} // namespace wasm

test/gtest/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ set(unittest_SOURCES
1212
dataflow.cpp
1313
dfa_minimization.cpp
1414
disjoint_sets.cpp
15+
graph.cpp
1516
leaves.cpp
1617
glbs.cpp
1718
interpreter.cpp

0 commit comments

Comments
 (0)