Skip to content

Commit e2ab435

Browse files
[Activation-checkpointing] add tensor dedup and param offloading (#4247)
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
1 parent 46a53cd commit e2ab435

File tree

2 files changed

+307
-61
lines changed

2 files changed

+307
-61
lines changed

tests/test_activation_offloading.py

Lines changed: 53 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -154,3 +154,56 @@ def test_real_hf_model(self):
154154
assert torch.allclose(out1, out2, rtol=1e-5)
155155
for g1, g2 in zip(grads1, grads2):
156156
assert torch.allclose(g1, g2, rtol=1e-5)
157+
158+
@require_torch_accelerator
def test_tensor_deduplication(self):
    """Deduplication should offload each underlying storage only once, even
    when several tensors (a flat view and a transpose) share that storage."""

    class ViewSharingNet(nn.Module):
        # Emits two views of the same linear output so the offloader sees
        # multiple tensors backed by a single storage.
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(100, 100)

        def forward(self, inp):
            hidden = self.linear(inp)
            flat = hidden.view(-1)
            swapped = hidden.transpose(0, 1)
            return flat.sum() + swapped.sum()

    net = ViewSharingNet().to(torch_device)
    ctx = OffloadActivations(min_offload_size=1)
    ctx.update_model_params(net)

    batch = torch.randn(10, 100, device=torch_device, requires_grad=True)
    with ctx:
        loss = net(batch)

    total_ids = ctx.tensor_id
    assert total_ids > 0, "Should have created tensor IDs"

    # modified=True means offloaded to CPU, modified=False means kept on GPU (deduplicated)
    moved_to_cpu = 0
    kept_on_gpu = 0
    for _, modified, _, _, _ in ctx.tracker.values():
        if modified:
            moved_to_cpu += 1
        else:
            kept_on_gpu += 1

    assert moved_to_cpu > 0, "Should have offloaded at least one tensor"
    assert kept_on_gpu > 0, "Should have deduplicated at least one tensor (view)"

    storage_count = len(ctx.storage_to_tensor_id)
    assert storage_count < total_ids, (
        f"Deduplication should result in fewer storages ({storage_count}) "
        f"than total tensors ({total_ids})"
    )

    # Backward must still succeed with a mix of offloaded and deduplicated tensors.
    loss.backward()
198+
199+
@require_torch_accelerator
def test_parameter_filtering(self):
    """Model parameter storages must be tracked so activations offloading
    skips (filters out) the parameters themselves."""
    net = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 10)).to(torch_device)
    ctx = OffloadActivations()
    ctx.update_model_params(net)

    assert len(ctx.param_storages) > 0, "Should have tracked parameter storages"

    # The tracked set must be exactly the data pointers of every parameter's storage.
    expected_ptrs = set()
    for param in net.parameters():
        expected_ptrs.add(param.data.untyped_storage().data_ptr())
    assert ctx.param_storages == expected_ptrs, "Tracked storages should match parameter storages"

0 commit comments

Comments
 (0)