1919// PassOptions structure; see more details there.
2020//
2121
22+ #include < ranges>
23+
2224#include " ir/effects.h"
2325#include " ir/module-utils.h"
2426#include " pass.h"
27+ #include " support/graph_traversal.h"
2528#include " support/strongly_connected_components.h"
2629#include " wasm.h"
2730
@@ -39,6 +42,9 @@ struct FuncInfo {
3942
4043 // Directly-called functions from this function.
4144 std::unordered_set<Name> calledFunctions;
45+
46+ // Types that are targets of indirect calls.
47+ std::unordered_set<HeapType> indirectCalledTypes;
4248};
4349
4450std::map<Function*, FuncInfo> analyzeFuncs (Module& module ,
@@ -83,11 +89,22 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8389 if (auto * call = curr->dynCast <Call>()) {
8490 // Note the direct call.
8591 funcInfo.calledFunctions .insert (call->target );
92+ } else if (effects.calls && options.closedWorld ) {
93+ HeapType type;
94+ if (auto * callRef = curr->dynCast <CallRef>()) {
95+ // call_ref on unreachable does not have a call effect,
96+ // so this must be a HeapType.
97+ type = callRef->target ->type .getHeapType ();
98+ } else if (auto * callIndirect = curr->dynCast <CallIndirect>()) {
99+ type = callIndirect->heapType ;
100+ } else {
101+ funcInfo.effects = UnknownEffects;
102+ return ;
103+ }
104+
105+ funcInfo.indirectCalledTypes .insert (type);
86106 } else if (effects.calls ) {
87- // This is an indirect call of some sort, so we must assume the
88- // worst. To do so, clear the effects, which indicates nothing
89- // is known (so anything is possible).
90- // TODO: We could group effects by function type etc.
107+ assert (!options.closedWorld );
91108 funcInfo.effects = UnknownEffects;
92109 } else {
93110 // No call here, but update throwing if we see it. (Only do so,
@@ -107,22 +124,84 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107124 return std::move (analysis.map );
108125}
109126
110- using CallGraph = std::unordered_map<Function*, std::unordered_set<Function*>>;
127+ using CallGraphNode = std::variant<Function*, HeapType>;
128+
129+ // Call graph for indirect and direct calls.
130+ //
131+ // key (caller) -> value (callee)
132+ // Function -> Function : direct call
133+ // Function -> HeapType : indirect call to the given HeapType
134+ // HeapType -> Function : The function `callee` has the type `caller`. The
135+ // HeapType may essentially 'call' any of its
136+ // potential implementations.
137+ // HeapType -> HeapType : `callee` is a subtype of `caller`. A call_ref
138+ // could target any subtype of the ref, so we need to
139+ // aggregate effects of subtypes of the target type.
140+ //
141+ // If we're running in an open world, we only include Function -> Function
142+ // edges, and don't compute effects for indirect calls, conservatively assuming
143+ // the worst.
144+ using CallGraph =
145+ std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>;
111146
112147CallGraph buildCallGraph (const Module& module ,
113- const std::map<Function*, FuncInfo>& funcInfos) {
148+ const std::map<Function*, FuncInfo>& funcInfos,
149+ bool closedWorld) {
114150 CallGraph callGraph;
115- for (const auto & [func, info] : funcInfos) {
116- if (info.calledFunctions .empty ()) {
117- continue ;
151+ if (!closedWorld) {
152+ for (const auto & [caller, callerInfo] : funcInfos) {
153+ auto & callees = callGraph[caller];
154+
155+ // Function -> Function
156+ for (Name calleeFunction : callerInfo.calledFunctions ) {
157+ callees.insert (module .getFunction (calleeFunction));
158+ }
159+ }
160+
161+ return callGraph;
162+ }
163+
164+ std::unordered_set<HeapType> allFunctionTypes;
165+ for (const auto & [caller, callerInfo] : funcInfos) {
166+ auto & callees = callGraph[caller];
167+
168+ // Function -> Function
169+ for (Name calleeFunction : callerInfo.calledFunctions ) {
170+ callees.insert (module .getFunction (calleeFunction));
118171 }
119172
120- auto & callees = callGraph[func];
121- for (Name callee : info.calledFunctions ) {
122- callees.insert (module .getFunction (callee));
173+ // Function -> Type
174+ allFunctionTypes.insert (caller->type .getHeapType ());
175+ for (HeapType calleeType : callerInfo.indirectCalledTypes ) {
176+ callees.insert (calleeType);
177+
178+ // Add the key to ensure the lookup doesn't fail for indirect calls to
179+ // uninhabited types.
180+ callGraph[calleeType];
123181 }
182+
183+ // Type -> Function
184+ callGraph[caller->type .getHeapType ()].insert (caller);
124185 }
125186
187+ // Type -> Type
188+ // Do a DFS up the type heirarchy for all function implementations.
189+ // We are essentially walking up each supertype chain and adding edges from
190+ // super -> subtype, but doing it via DFS to avoid repeated work.
191+ Graph superTypeGraph (allFunctionTypes.begin (),
192+ allFunctionTypes.end (),
193+ [&callGraph](auto && push, HeapType t) {
194+ // Not needed except that during lookup we expect the
195+ // key to exist.
196+ callGraph[t];
197+
198+ if (auto super = t.getDeclaredSuperType ()) {
199+ callGraph[*super].insert (t);
200+ push (*super);
201+ }
202+ });
203+ (void )superTypeGraph.traverseDepthFirst ();
204+
126205 return callGraph;
127206}
128207
@@ -152,63 +231,60 @@ void propagateEffects(const Module& module,
152231 const PassOptions& passOptions,
153232 std::map<Function*, FuncInfo>& funcInfos,
154233 const CallGraph& callGraph) {
234+ // We only care about Functions that are roots, not types.
235+ // A type would be a root if a function exists with that type, but nothing
236+ // calls that type indirectly.
237+ auto funcNodes = std::views::keys (callGraph) |
238+ std::views::filter ([](auto node) {
239+ return std::holds_alternative<Function*>(node);
240+ }) |
241+ std::views::common;
242+ using funcNodesType = decltype (funcNodes);
243+
155244 struct CallGraphSCCs
156- : SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs> {
245+ : SCCs<std::ranges::iterator_t <funcNodesType>, CallGraphSCCs> {
246+
157247 const std::map<Function*, FuncInfo>& funcInfos;
158- const std::unordered_map<Function*, std::unordered_set<Function*>>&
159- callGraph;
248+ const CallGraph& callGraph;
160249 const Module& module ;
161250
162- CallGraphSCCs (
163- const std::vector<Function*>& funcs,
164- const std::map<Function*, FuncInfo>& funcInfos,
165- const std::unordered_map<Function*, std::unordered_set<Function*>>&
166- callGraph,
167- const Module& module )
168- : SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs>(
169- funcs.begin(), funcs.end()),
251+ CallGraphSCCs (funcNodesType&& nodes,
252+ const std::map<Function*, FuncInfo>& funcInfos,
253+ const CallGraph& callGraph,
254+ const Module& module )
255+ : SCCs<std::ranges::iterator_t <funcNodesType>, CallGraphSCCs>(
256+ std::ranges::begin (nodes), std::ranges::end(nodes)),
170257 funcInfos (funcInfos), callGraph(callGraph), module(module ) {}
171258
172- void pushChildren (Function* f) {
173- auto callees = callGraph.find (f);
174- if (callees == callGraph.end ()) {
175- return ;
176- }
177-
178- for (auto * callee : callees->second ) {
259+ void pushChildren (CallGraphNode node) {
260+ for (CallGraphNode callee : callGraph.at (node)) {
179261 push (callee);
180262 }
181263 }
182264 };
183-
184- std::vector<Function*> allFuncs;
185- for (auto & [func, info] : funcInfos) {
186- allFuncs.push_back (func);
187- }
188- CallGraphSCCs sccs (allFuncs, funcInfos, callGraph, module );
265+ CallGraphSCCs sccs (std::move(funcNodes), funcInfos, callGraph, module);
189266
190267 std::vector<std::optional<EffectAnalyzer>> componentEffects;
191268 // Points to an index in componentEffects
192- std::unordered_map<Function* , Index> funcComponents ;
269+ std::unordered_map<CallGraphNode , Index> nodeComponents ;
193270
194271 for (auto ccIterator : sccs) {
195272 std::optional<EffectAnalyzer>& ccEffects =
196273 componentEffects.emplace_back (std::in_place, passOptions, module );
274+ std::vector<CallGraphNode> cc (ccIterator.begin (), ccIterator.end ());
197275
198- std::vector<Function*> ccFuncs (ccIterator.begin (), ccIterator.end ());
199-
200- for (Function* f : ccFuncs) {
201- funcComponents.emplace (f, componentEffects.size () - 1 );
276+ std::vector<Function*> ccFuncs;
277+ for (CallGraphNode node : cc) {
278+ nodeComponents.emplace (node, componentEffects.size () - 1 );
279+ if (auto ** func = std::get_if<Function*>(&node)) {
280+ ccFuncs.push_back (*func);
281+ }
202282 }
203283
204284 std::unordered_set<int > calleeSccs;
205- for (Function* caller : ccFuncs) {
206- auto callees = callGraph.find (caller);
207- if (callees == callGraph.end ()) {
208- continue ;
209- }
210- for (auto * callee : callees->second ) {
211- calleeSccs.insert (funcComponents.at (callee));
285+ for (CallGraphNode caller : cc) {
286+ for (CallGraphNode callee : callGraph.at (caller)) {
287+ calleeSccs.insert (nodeComponents.at (callee));
212288 }
213289 }
214290
@@ -219,11 +295,13 @@ void propagateEffects(const Module& module,
219295 }
220296
221297 // Add trap effects for potential cycles.
222- if (ccFuncs .size () > 1 ) {
298+ if (cc .size () > 1 ) {
223299 if (ccEffects != UnknownEffects) {
224300 ccEffects->trap = true ;
225301 }
226- } else {
302+ } else if (ccFuncs.size () == 1 ) {
303+ // It's possible for a CC to only contain 1 type, but that is not a
304+ // cycle in the call graph.
227305 auto * func = ccFuncs[0 ];
228306 if (funcInfos.at (func).calledFunctions .contains (func->name )) {
229307 if (ccEffects != UnknownEffects) {
@@ -267,7 +345,8 @@ struct GenerateGlobalEffects : public Pass {
267345 std::map<Function*, FuncInfo> funcInfos =
268346 analyzeFuncs (*module , getPassOptions ());
269347
270- auto callGraph = buildCallGraph (*module , funcInfos);
348+ auto callGraph =
349+ buildCallGraph (*module , funcInfos, getPassOptions ().closedWorld );
271350
272351 propagateEffects (*module , getPassOptions (), funcInfos, callGraph);
273352
0 commit comments