
Commit 0f323fb

isururanawaka authored and facebook-github-bot committed

Resharding API Host Memory Offloading and BenchmarkReshardingHandler (#3291)
Summary:
Pull Request resolved: #3291

- Implements tensor offloading to host memory inside the resharding API
- Adds BenchmarkReshardingHandler, which generates random sharding plans and calls the DDP reshard API by selecting among them
- Adds a reset method to train_pipelines.py

Reviewed By: aporialiao

Differential Revision: D80366926

fbshipit-source-id: a137da2f36cbacf21f0c28ae83dfc6eabba29901
1 parent 3b437e6 commit 0f323fb

File tree

6 files changed: +431 -133 lines
Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
import random
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torchrec.distributed.embeddingbag import EmbeddingBagCollection

from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta

from torchrec.distributed.sharding_plan import (
    column_wise,
    construct_module_sharding_plan,
    table_wise,
)

from torchrec.distributed.test_utils.test_sharding import generate_rank_placements
from torchrec.distributed.types import EmbeddingModuleShardingPlan

logger: logging.Logger = logging.getLogger(__name__)


class ReshardingHandler:
    """
    Handles the resharding of a training module by generating and applying different sharding plans.
    """

    def __init__(self, train_module: nn.Module, num_plans: int) -> None:
        """
        Initializes the ReshardingHandler with a training module and the number of sharding plans to generate.

        Args:
            train_module (nn.Module): The training module to be resharded.
            num_plans (int): The number of sharding plans to generate.
        """
        self._train_module = train_module
        if not hasattr(train_module, "_module"):
            raise RuntimeError("Incorrect train module")

        if not hasattr(train_module._module, "plan"):
            raise RuntimeError("sharding plan cannot be found")

        # Pyre-ignore
        plan = train_module._module.plan.plan
        key = "main_module.sparse_arch.embedding_bag_collection"
        module = (
            # Pyre-ignore
            train_module._module.module.main_module.sparse_arch.embedding_bag_collection
        )
        self._resharding_plans: List[EmbeddingModuleShardingPlan] = []
        world_size = dist.get_world_size()

        # TODO: change this logic when a proper planner is integrated
        if key in plan:
            ebc = plan[key]
            num_tables = len(ebc)
            ranks_per_tables = [1 for _ in range(num_tables)]
            ranks_per_tables_for_CW = []
            for index, table_config in enumerate(module._embedding_bag_configs):
                # CW sharding: pick a shard count that evenly divides the embedding dim
                valid_candidates = [
                    i
                    for i in range(1, world_size + 1)
                    if table_config.embedding_dim % i == 0
                ]
                rng = random.Random(index)
                ranks_per_tables_for_CW.append(rng.choice(valid_candidates))

            for plan_idx in range(num_plans):
                new_ranks = generate_rank_placements(
                    world_size, num_tables, ranks_per_tables, plan_idx
                )
                new_ranks_cw = generate_rank_placements(
                    world_size, num_tables, ranks_per_tables_for_CW, plan_idx
                )
                new_per_param_sharding = {}
                for table_idx, (table_id, param) in enumerate(ebc.items()):
                    if param.sharding_type == "column_wise":
                        cw_gen = column_wise(
                            ranks=new_ranks_cw[table_idx],
                            compute_kernel=param.compute_kernel,
                        )
                        new_per_param_sharding[table_id] = cw_gen
                    else:
                        tw_gen = table_wise(
                            rank=new_ranks[table_idx][0],
                            compute_kernel=param.compute_kernel,
                        )
                        new_per_param_sharding[table_id] = tw_gen

                lightweight_ebc = EmbeddingBagCollection(
                    tables=module._embedding_bag_configs,
                    # Use meta device to avoid actual memory allocation
                    device=torch.device("meta"),
                )

                meta_device = torch.device("meta")
                new_plan = construct_module_sharding_plan(
                    lightweight_ebc,
                    per_param_sharding=new_per_param_sharding,  # Pyre-ignore
                    local_size=world_size,
                    world_size=world_size,
                    # Pyre-ignore
                    device_type=meta_device,
                )
                self._resharding_plans.append(new_plan)
        else:
            raise RuntimeError(f"Plan does not have key: {key}")

    def step(self, batch_no: int) -> float:
        """
        Executes a step in the training process by selecting and applying a sharding plan.

        Args:
            batch_no (int): The current batch number.

        Returns:
            float: The data volume of the sharding plan delta.
        """
        # Pyre-ignore
        plan = self._train_module._module.plan.plan
        key = "main_module.sparse_arch.embedding_bag_collection"

        # Use the current step as a seed so that all ranks get the same random number,
        # while it still changes on each call.
        rng = random.Random(batch_no)
        index = rng.randint(0, len(self._resharding_plans) - 1)
        logger.info(f"Selected resharding plan index {index} for step {batch_no}")
        # Get the selected plan
        selected_plan = self._resharding_plans[index]

        # Fix the device mismatch by updating the placement device in the sharding spec.
        # This is necessary because the plan was created with the meta device but needs to be applied on CUDA.
        for _, param_sharding in selected_plan.items():
            sharding_spec = param_sharding.sharding_spec
            if sharding_spec is not None:
                # pyre-ignore
                for shard in sharding_spec.shards:
                    placement = shard.placement
                    rank: Optional[int] = placement.rank()
                    assert rank is not None
                    current_device = (
                        torch.cuda.current_device()
                        if rank == torch.distributed.get_rank()
                        else rank % torch.cuda.device_count()
                    )
                    shard.placement = torch.distributed._remote_device(
                        f"rank:{rank}/cuda:{current_device}"
                    )

        data_volume, delta_plan = output_sharding_plan_delta(
            plan[key], selected_plan, True
        )
        # Pyre-ignore
        self._train_module.module.reshard(
            sharded_module_fqn="main_module.sparse_arch.embedding_bag_collection",
            changed_shard_to_params=delta_plan,
        )
        return data_volume

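For orientation, here is a minimal sketch of how this handler might be driven from a benchmark loop. Only ReshardingHandler(train_module, num_plans) and handler.step(batch_no) come from the file above; the run_resharding_benchmark wrapper, the train_step callable, and the batches iterable are hypothetical names introduced for illustration.

# Hypothetical driver loop (sketch); only the ReshardingHandler API is taken from the file above.
from typing import Any, Callable, Iterable


def run_resharding_benchmark(
    train_module: Any,
    train_step: Callable[[Any], None],  # one forward/backward/optimizer step (assumed)
    batches: Iterable[Any],
    num_plans: int = 5,
) -> float:
    # Pre-generate a pool of candidate sharding plans; the handler seeds its RNGs
    # from table and plan indices, so every rank builds an identical pool.
    handler = ReshardingHandler(train_module, num_plans)

    total_data_volume = 0.0
    for batch_no, batch in enumerate(batches):
        train_step(batch)
        # All ranks seed with batch_no, so they pick the same plan and reshard together.
        total_data_volume += handler.step(batch_no)
    return total_data_volume

Because both the plan pool and the per-step selection are seeded deterministically, every rank agrees on which plan to apply without any extra communication.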
torchrec/distributed/embeddingbag.py

Lines changed: 35 additions & 19 deletions
@@ -57,10 +57,11 @@
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
 from torchrec.distributed.sharding.dynamic_sharding import (
     get_largest_dims_from_sharding_plan_updates,
+    move_sharded_tensors_to_cpu,
     shards_all_to_all,
     update_module_sharding_plan,
     update_optimizer_state_post_resharding,
-    update_state_dict_post_resharding,
+    update_state_post_resharding,
 )
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
 from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
@@ -1377,6 +1378,25 @@ def _init_mean_pooling_callback(
             device=self._device,
         )
 
+    def _purge_lookups(self) -> None:
+        # Purge old lookups
+        for lookup in self._lookups:
+            # Call purge method if available (for TBE modules)
+            if hasattr(lookup, "purge") and callable(lookup.purge):
+                # Pyre-ignore
+                lookup.purge()
+
+            # For DDP modules, get the underlying module
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+                if hasattr(lookup, "purge") and callable(lookup.purge):
+                    lookup.purge()
+
+        # Clear the lookups list
+        self._lookups.clear()
+        # Force garbage collection to free memory
+        torch.cuda.empty_cache()
+
     def _create_lookups(
         self,
     ) -> None:
@@ -1723,12 +1743,13 @@ def update_shards(
             env (ShardingEnv): The sharding environment for the module.
             device (Optional[torch.device]): The device to place the updated module on.
         """
-
         if env.output_dtensor:
             raise RuntimeError("We do not yet support DTensor for resharding yet")
             return
 
         current_state = self.state_dict()
+        current_state = move_sharded_tensors_to_cpu(current_state)
+        # TODO: improve, checking one would be enough
         has_local_optimizer = len(self._optim._optims) > 0 and all(
             len(i) > 0 for i in self._optim.state_dict()["state"].values()
         )
@@ -1740,22 +1761,18 @@
 
         has_optimizer = self._is_optimizer_enabled(has_local_optimizer, env, device)
 
-        # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
-        # TODO: Ensure lookup tensors are actually being deleted
-        for _, lookup in enumerate(self._lookups):
-            # pyre-ignore
-            lookup.purge()
-
-        # Deleting all lookups
-        self._lookups.clear()
+        # TODO: make sure this is clearing all lookups
+        self._purge_lookups()
 
         # Get max dim size to enable padding for all_to_all
         max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
             changed_sharding_params
         )
         old_optimizer_state = self._optim.state_dict() if has_local_optimizer else None
+        if old_optimizer_state is not None:
+            move_sharded_tensors_to_cpu(old_optimizer_state)
 
-        local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
+        local_shard_names_by_src_rank, local_output_tensor_cpu = shards_all_to_all(
             module=self,
             state_dict=current_state,
             device=device,  # pyre-ignore
@@ -1832,22 +1849,21 @@ def update_shards(
         if has_optimizer:
             optimizer_state = update_optimizer_state_post_resharding(
                 old_opt_state=old_optimizer_state,  # pyre-ignore
-                new_opt_state=copy.deepcopy(self._optim.state_dict()),
+                new_opt_state=self._optim.state_dict(),
                 ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-                output_tensor=local_output_tensor,
+                output_tensor=local_output_tensor_cpu,
                 max_dim_0=max_dim_0,
                 extend_shard_name=self.extend_shard_name,
             )
             self._optim.load_state_dict(optimizer_state)
 
-        current_state = update_state_dict_post_resharding(
-            state_dict=current_state,
+        new_state = self.state_dict()
+        current_state = update_state_post_resharding(
+            old_state=current_state,
+            new_state=new_state,
             ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-            output_tensor=local_output_tensor,
-            new_sharding_params=changed_sharding_params,
-            curr_rank=dist.get_rank(),
+            output_tensor=local_output_tensor_cpu,
             extend_shard_name=self.extend_shard_name,
-            max_dim_0=max_dim_0,
             has_optimizer=has_optimizer,
         )
 
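The host-memory offloading path runs the module state dict and the optimizer state through move_sharded_tensors_to_cpu before shards_all_to_all, and the all-to-all output is likewise kept on host (local_output_tensor_cpu), so GPU memory is not held by both the old and the new shard layout at once. The helper itself lives in torchrec.distributed.sharding.dynamic_sharding and is not part of this diff; the following is only a rough sketch of the idea, assuming a state dict whose values are ShardedTensors or nested dicts, and not the actual implementation.

# Illustrative sketch only; the real move_sharded_tensors_to_cpu in
# torchrec.distributed.sharding.dynamic_sharding may differ in detail.
from typing import Any, Dict

from torch.distributed._shard.sharded_tensor import ShardedTensor


def offload_sharded_state_to_cpu(state: Dict[str, Any]) -> Dict[str, Any]:
    """Rebind the local shards of every ShardedTensor in `state` to host memory."""
    for value in state.values():
        if isinstance(value, ShardedTensor):
            for shard in value.local_shards():
                # Replace the GPU tensor backing this shard with a CPU copy;
                # once the old tensor is released, its device memory can be reclaimed.
                shard.tensor = shard.tensor.to("cpu")
        elif isinstance(value, dict):
            # Recurse into nested containers such as optimizer state.
            offload_sharded_state_to_cpu(value)
    return state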