
Commit ba40125
Add Activation Checkpointing Pass
Work in progress: copied the latest code over from `whc/hack_aot` and tweaked the way it gets hooked up a bit; not tested yet. We still need to discuss whether the AC pass should be popped back off inductor's passes earlier, or kept installed until `__exit__` of AutoParallel; the two options are sketched below.
1 parent badffa7 commit ba40125
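
To make the open question concrete, the two options look roughly like this (a hypothetical simplification of the api.py changes below, not code from this commit):

    # Option A: pop the pass back off right after the graph is built, in __enter__
    prev = torch._inductor.config.joint_custom_post_pass
    torch._inductor.config.joint_custom_post_pass = ac_joint_pass
    try:
        self.build_model_graph()
    finally:
        torch._inductor.config.joint_custom_post_pass = prev

    # Option B (what this commit does): leave the pass installed for the whole
    # lifetime of the context manager and restore the previous value only in
    # __exit__, so it also applies to any compilation triggered inside the
    # `with` block.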

File tree

2 files changed (+306, -1 lines):
  autoparallel/activation_checkpointing.py
  autoparallel/api.py
autoparallel/activation_checkpointing.py

Lines changed: 288 additions & 0 deletions
@@ -0,0 +1,288 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from collections import defaultdict
from dataclasses import dataclass

import torch
from torch._functorch.partitioners import _has_tag_is_backward, _size_of
from torch.utils._ordered_set import OrderedSet
from torch.utils.checkpoint import CheckpointPolicy

logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# Reimplement torch._functorch.partitioners.must_recompute to only check for
# the MUST_RECOMPUTE tag, and not PREFER_RECOMPUTE. For now the partitioner
# makes no distinction between the two, and I think this is a bug.
def must_recompute(node: torch.fx.Node) -> bool:
    return node.meta.get("recompute", None) is CheckpointPolicy.MUST_RECOMPUTE


def is_graph_input(node: torch.fx.Node) -> bool:
    return node.op == "placeholder"


def is_wait_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.wait_tensor.default
    )


def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
    )


def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
    if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
        # TODO: this needs to be improved; the commented-out assertion below
        # fires in the autoparallel "2D" case, where the input to the
        # all_gather is itself a wait (maybe just 2D FSDP).
        # ag_node = node.args[0]
        # assert is_graph_input(ag_node.args[0]) or (
        #     ag_node.args[0].op == "call_function"
        #     and ag_node.args[0].target == torch.ops.prims.convert_element_type.default
        #     and is_graph_input(ag_node.args[0].args[0])
        # ), (
        #     "Assume all_gather_into_tensor input is either graph input "
        #     + f"or dtype conversion of graph input, but got {ag_node.args[0]}"
        # )
        return True
    return False


# mypy: ignore-errors


def force_recompute_fsdp_all_gather(graph: torch.fx.Graph) -> None:
    """
    Force recomputation of all_gather nodes from simple FSDP in the graph.

    This pass should be added in torch._inductor.config.joint_custom_post_pass
    """

    def force_recompute_node(node):
        node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
        if "ac_graph_id" not in node.meta:
            # ac_graph_id is used in the partitioner to decide whether two
            # nodes which have AC applied come from different AC regions.
            # This is needed because nodes on the boundary of two AC regions
            # are marked as MUST_SAVE. In our case we just add a large value
            # of ac_graph_id so that all nodes we tag for recomputation do
            # indeed get recomputed and are not influenced by other nodes in
            # the graph with nearby ac_graph_id values.
            node.meta["ac_graph_id"] = 100000

    # Make all-gather nodes (and related nodes) recomputable, to circumvent
    # https://github.com/pytorch/pytorch/issues/136433
    for node in graph.nodes:
        if is_wait_tensor_from_fsdp(node):
            ag_node = node.args[0]
            force_recompute_node(ag_node)  # all_gather
            force_recompute_node(node)  # wait_tensor
            # Force-recompute the slice that comes after the wait
            for user in node.users:
                if (
                    user.op == "call_function"
                    and user.target == torch.ops.aten.slice.Tensor
                ):
                    force_recompute_node(user)
            # Force-recompute potential dtype casts feeding the all_gather
            if (
                ag_node.all_input_nodes[0].op == "call_function"
                and ag_node.args[0].target
                == torch.ops.prims.convert_element_type.default
            ):
                force_recompute_node(ag_node.all_input_nodes[0])
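

# For orientation, the graph pattern targeted above (as produced by tracing
# simple FSDP) looks roughly like this -- an illustrative sketch, not code:
#
#   param (placeholder)
#     -> [optional] torch.ops.prims.convert_element_type.default    (dtype cast)
#     -> torch.ops._c10d_functional.all_gather_into_tensor.default  (all_gather)
#     -> torch.ops._c10d_functional.wait_tensor.default             (wait)
#     -> [optional] torch.ops.aten.slice.Tensor                     (unpad)
#
# Every node in the chain is tagged MUST_RECOMPUTE, so the gathered
# (unsharded) parameter is re-materialized in the backward pass instead of
# being saved as an activation.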


def mark_nodes_as_must_save_to_stage_recomputation(
    joint_graph: torch.fx.Graph,
    stage_size_in_GiB: float = -1,
) -> None:
    """
    Marks specific nodes as "must save" to optimize memory usage during recomputation.

    With aggressive recomputation strategies, we often encounter situations where long
    chains of forward nodes must be recomputed before executing backward-pass nodes,
    causing high peak memory usage. This function breaks these recomputation chains
    into smaller stages by periodically saving intermediate nodes, keeping peak
    memory usage bounded.

    Args:
        joint_graph: The joint graph containing both forward and backward nodes
        stage_size_in_GiB: Target memory size per stage in GiB (-1 to disable staging)
    """

    if stage_size_in_GiB < 0:
        return

    INT_INF = int(1e9)

    def get_required_fwd_nodes(
        joint_graph: torch.fx.Graph,
    ) -> OrderedSet[torch.fx.Node]:
        """
        Return the set of nodes that are required in the forward graph.

        NOTE: this does something similar to classify_nodes() in
        _functorch/partitioners.py, where nodes are classified into three
        types -- fwd, bwd, and unclaimed; both bwd and unclaimed nodes have
        partitioner_tag equal to "is_backward".
        """
        required_fwd_nodes: OrderedSet[torch.fx.Node] = OrderedSet()
        for node in joint_graph.nodes:
            if node.op == "placeholder" and "tangents" in node.target:
                continue
            if node.op == "output":
                continue
            if _has_tag_is_backward(node):
                continue
            required_fwd_nodes.add(node)
        return required_fwd_nodes

    def get_node_distance_to_bwd(
        joint_graph: torch.fx.Graph,
        required_fwd_nodes: OrderedSet[torch.fx.Node],
    ) -> dict[torch.fx.Node, int]:
        """
        Compute and return the distance of every node to the closest backward node.
        If a node is not an ancestor of a backward node, its distance is INT_INF.

        NOTE: this is adapted from
        https://github.com/pytorch/pytorch/blob/3196a3aca0f16792820158cfd451cb977f99ac7e/torch/_functorch/partitioners.py#L2089-L2097
        """
        dist_from_bw = {}
        for node in reversed(joint_graph.nodes):
            if node.op == "output":
                dist_from_bw[node] = INT_INF
            elif node not in required_fwd_nodes:
                dist_from_bw[node] = 0
            else:
                dist_from_bw[node] = INT_INF
                for user in node.users:
                    dist_from_bw[node] = min(dist_from_bw[node], dist_from_bw[user] + 1)
        return dist_from_bw

    def get_all_recomputable_forward_nodes(
        joint_graph: torch.fx.Graph,
    ) -> OrderedSet[torch.fx.Node]:
        """
        Return the set of all forward nodes that are recomputable.
        """
        required_fwd_nodes = get_required_fwd_nodes(joint_graph)
        dist_from_bw = get_node_distance_to_bwd(joint_graph, required_fwd_nodes)
        fwd_recomputable_nodes: OrderedSet[torch.fx.Node] = OrderedSet()
        for node in joint_graph.nodes:
            if (
                node in required_fwd_nodes
                and dist_from_bw[node] < INT_INF
                and node.op != "placeholder"
            ):
                fwd_recomputable_nodes.add(node)
        return fwd_recomputable_nodes

    def mark_nodes_as_must_save(must_save_nodes: list[torch.fx.Node]) -> None:
        """
        Given a list of nodes, mark them as must-save.
        """
        logger.info("mark_nodes_as_must_save: %s", must_save_nodes)
        for node in must_save_nodes:
            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE

    fwd_recomputable_nodes = get_all_recomputable_forward_nodes(joint_graph)

    # Initialize all nodes as 'prefer recompute' and then adjust only the
    # must-save ones below
    for node in fwd_recomputable_nodes:
        if node.meta.get("recompute", None) is not None:
            # do not mess with all-gather nodes that have already been marked
            # for recomputation!
            continue
        node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
        # add an arbitrarily large graph id. I'm assuming 100000 here, which
        # should be fine and is the same value we add for the all-gather nodes
        node.meta["ac_graph_id"] = 100000

    # get the mapping between node name and node
    name_to_node_mapping = {}
    for node in fwd_recomputable_nodes:
        name_to_node_mapping[node.name] = node

    # populate node_to_predecessors, accounting for must_recompute nodes. In
    # particular, if a node is marked as must-recompute, then for its users,
    # their predecessors should be updated to instead be the predecessors of
    # the must-recompute node.
    node_to_predecessors = defaultdict(OrderedSet)
    for node in fwd_recomputable_nodes:
        node_to_predecessors[node] = OrderedSet(
            [pred for pred in node.all_input_nodes if pred in fwd_recomputable_nodes]
        )
    for node in fwd_recomputable_nodes:
        if not must_recompute(node):
            continue
        for user in node.users:
            if user in fwd_recomputable_nodes:
                node_to_predecessors[user].remove(node)
                node_to_predecessors[user].update(node_to_predecessors[node])

    # populate node_to_last_usage:
    # if A is last used by B, then A is in node_to_last_usage[B]
    node_to_last_usage = defaultdict(OrderedSet)
    last_used_by = {}
    for node in fwd_recomputable_nodes:
        last_used_by[node] = node
        for pred in node_to_predecessors[node]:
            last_used_by[pred] = node
    for producer, consumer in last_used_by.items():
        node_to_last_usage[consumer].add(producer)

    # loop through the nodes in forward-graph order and track, for each node,
    # which nodes' outputs are still in memory right after its execution.
    @dataclass
    class NodeCutScore:
        tot_mem: float
        alive_node_names: OrderedSet[str]

    alive_nodes = OrderedSet()
    node2score = {}
    for node in fwd_recomputable_nodes:
        if not must_recompute(node):
            alive_nodes.add(node)
        for a in node_to_last_usage[node]:
            # use discard rather than remove: must-recompute nodes appear in
            # node_to_last_usage but are never added to alive_nodes
            alive_nodes.discard(a)
        tot_mem = sum(_size_of(node) for node in alive_nodes)
        node2score[node] = NodeCutScore(
            tot_mem, OrderedSet([n.name for n in alive_nodes])
        )
262+
# divide the graph into stages with roughly equal memory usage.
263+
stages = defaultdict(OrderedSet)
264+
cum_mem_so_far = 0
265+
curr_stage_idx = 0
266+
target_mem = stage_size_in_GiB * 2**30
267+
for node in fwd_recomputable_nodes:
268+
stages[curr_stage_idx].add(node)
269+
if not must_recompute(node):
270+
cum_mem_so_far += _size_of(node)
271+
if cum_mem_so_far >= target_mem:
272+
curr_stage_idx += 1
273+
cum_mem_so_far = 0
274+
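
    # Worked example (illustrative numbers): with stage_size_in_GiB = 2.0,
    # target_mem is 2 GiB. If successive node outputs are 0.9, 0.8 and 0.5
    # GiB, the cumulative sizes are 0.9, 1.7 and 2.2 GiB; the third node
    # crosses the 2 GiB target, so it closes stage 0 and the next node
    # starts stage 1 with the counter reset to 0.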

    # loop through each stage and pick the best node to cut on, and save
    # the nodes that will be marked as must-save.
    nodes_to_save = OrderedSet()
    for stage in stages.values():
        best_node = min(stage, key=lambda x: node2score[x].tot_mem)
        nodes_to_save.update(node2score[best_node].alive_node_names)
    mark_nodes_as_must_save([name_to_node_mapping[n] for n in nodes_to_save])


def ac_joint_pass(graph: torch.fx.Graph, ac_stage_size_in_GiB: float = 2.0):
    force_recompute_fsdp_all_gather(graph)
    mark_nodes_as_must_save_to_stage_recomputation(
        graph, stage_size_in_GiB=ac_stage_size_in_GiB
    )
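
Since ac_joint_pass only mutates node.meta in place, it can also be pointed at a joint graph directly. A minimal sketch, assuming `joint_gm` is a joint forward-backward torch.fx.GraphModule produced by AOTAutograd:

    from autoparallel.activation_checkpointing import ac_joint_pass

    ac_joint_pass(joint_gm.graph, ac_stage_size_in_GiB=2.0)
    # Afterwards each tagged node carries node.meta["recompute"] set to one of
    # CheckpointPolicy.MUST_RECOMPUTE, PREFER_RECOMPUTE or MUST_SAVE, which the
    # min-cut partitioner consumes when splitting the joint graph into forward
    # and backward.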

autoparallel/api.py

Lines changed: 18 additions & 1 deletion
@@ -6,6 +6,7 @@
 import copy
 import itertools
 from contextlib import ExitStack
+from functools import partial
 from types import MethodType
 from typing import Optional

@@ -24,6 +25,7 @@
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind

+from .activation_checkpointing import ac_joint_pass
 from .apply_sharding import apply_sharding_to_model
 from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
 from .init_weights import hook_params_setters

@@ -163,6 +165,8 @@ def __init__(
         input_fn,
         mesh: DeviceMesh,
         mp_policy: Optional[MixedPrecisionPolicy] = None,
+        enable_ac: bool = True,
+        ac_stage_size_in_GiB: float = 2.0,
     ):
         self.stack = ExitStack()
         self.fake_mode = (

@@ -187,14 +191,23 @@ def __init__(
         self.model = move_to_fake(model, self.fake_mode, device)
         self.input_fn = input_fn
         self.mesh = mesh
+        self.enable_ac = enable_ac
+        self.ac_stage_size_in_GiB = ac_stage_size_in_GiB

         # NB: rest of the construction happens in __enter__
-
         self.active = False

     def __enter__(self):
         assert self.active is False

+        if self.enable_ac:
+            self.orig_inductor_joint_custom_post_pass = (
+                torch._inductor.config.joint_custom_post_pass
+            )
+            torch._inductor.config.joint_custom_post_pass = partial(
+                ac_joint_pass, ac_stage_size_in_GiB=self.ac_stage_size_in_GiB
+            )
+
         self.build_model_graph()

         rescale_grad_comm_cost_for_mp = 1.0

@@ -225,6 +238,10 @@ def __enter__(self):

     def __exit__(self, exc_type, exc_val, exc_tb):
         self.active = None
+        if self.enable_ac:
+            torch._inductor.config.joint_custom_post_pass = (
+                self.orig_inductor_joint_custom_post_pass
+            )
         return self.stack.__exit__(exc_type, exc_val, exc_tb)

     def _assert_entered(self):
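
For context, a hypothetical end-to-end sketch of the new knobs (the constructor arguments before input_fn and the body of the `with` block are assumptions; only enable_ac and ac_stage_size_in_GiB come from this diff):

    from torch.distributed.device_mesh import init_device_mesh
    from autoparallel.api import AutoParallel

    mesh = init_device_mesh("cuda", (8,))
    with AutoParallel(
        model,                     # assumed: the module to parallelize
        input_fn,                  # callable producing example inputs
        mesh,
        enable_ac=True,            # installs ac_joint_pass on __enter__
        ac_stage_size_in_GiB=2.0,  # ~2 GiB of saved activations per stage
    ) as autop:
        ...                        # sharding/compilation happens here
    # on __exit__, the previous joint_custom_post_pass is restored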
