Skip to content

[release/2.6] [SWDEV-531526] [SWDEV-527340] Allocation of buffers ordered before compute #2276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 27, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion torch/_inductor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2247,21 +2247,43 @@ def topological_sort_schedule(
name_to_node: Dict[str, BaseSchedulerNode] = dict()
result: List[BaseSchedulerNode] = []

def has_mutations(node: BaseSchedulerNode) -> bool:
return any(buf.get_mutations() for buf in node.get_outputs())

def visit(n: BaseSchedulerNode) -> None:
if n not in seen:
seen.add(n)

# Visit regular dependencies
for dep in sorted(n.unmet_dependencies, key=lambda d: d.name):
# We only care about doing toposort within `nodes`
if dep.name not in name_to_node:
continue
visit(name_to_node[dep.name])

# Visit mutation dependencies
for buf in n.get_outputs():
for mutation in buf.get_mutations():
if mutation in name_to_node and name_to_node[mutation] != n:
visit(name_to_node[mutation])

result.append(n)

# Build name to node mapping
for node in nodes:
for name in node.get_buffer_names():
name_to_node[name] = node

# Visit non-mutation nodes first
for node in nodes:
if not has_mutations(node):
visit(node)

# Then visit mutation nodes
for node in nodes:
visit(node)
if has_mutations(node):
visit(node)

return result

def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> List[BaseSchedulerNode]:
Expand Down