@@ -154,3 +154,56 @@ def test_real_hf_model(self):
154154 assert torch .allclose (out1 , out2 , rtol = 1e-5 )
155155 for g1 , g2 in zip (grads1 , grads2 ):
156156 assert torch .allclose (g1 , g2 , rtol = 1e-5 )
157+
@require_torch_accelerator
def test_tensor_deduplication(self):
    """Check that tensors sharing one storage are deduplicated when offloading.

    Runs a forward pass through a small model whose ``forward`` consumes the
    linear output through two views, inspects the offload context's
    bookkeeping (``tensor_id``, ``tracker``, ``storage_to_tensor_id``), then
    runs backward and confirms that a gradient reaches the input.
    """

    class ModelWithViews(nn.Module):
        """Single linear layer whose output is read through two views."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(100, 100)

        def forward(self, x):
            out = self.linear(x)
            # ``view`` and ``transpose`` return tensors that alias ``out``'s storage.
            view1 = out.view(-1)
            view2 = out.transpose(0, 1)
            return view1.sum() + view2.sum()

    model = ModelWithViews().to(torch_device)
    # NOTE(review): min_offload_size=1 presumably lowers the size threshold so
    # that even these small activations are offloaded -- confirm against
    # OffloadActivations.
    offload_ctx = OffloadActivations(min_offload_size=1)
    offload_ctx.update_model_params(model)

    x = torch.randn(10, 100, device=torch_device, requires_grad=True)
    with offload_ctx:
        loss = model(x)

    total_tensor_ids = offload_ctx.tensor_id
    assert total_tensor_ids > 0, "Should have created tensor IDs"

    # modified=True means offloaded to CPU, modified=False means kept on GPU (deduplicated)
    # Read the tracker once and derive both counts from the same snapshot.
    modified_flags = [modified for _, modified, _, _, _ in offload_ctx.tracker.values()]
    deduplicated_count = sum(1 for flag in modified_flags if not flag)
    offloaded_count = sum(1 for flag in modified_flags if flag)

    assert offloaded_count > 0, "Should have offloaded at least one tensor"
    assert deduplicated_count > 0, "Should have deduplicated at least one tensor (view)"

    unique_storages_offloaded = len(offload_ctx.storage_to_tensor_id)
    assert unique_storages_offloaded < total_tensor_ids, (
        f"Deduplication should result in fewer storages ({unique_storages_offloaded}) "
        f"than total tensors ({total_tensor_ids})"
    )

    loss.backward()

    # Without this check the test only proves that backward does not raise;
    # ``x`` is a leaf created with requires_grad=True, so a successful backward
    # through the restored activations must leave a gradient on it.
    assert x.grad is not None, "Backward through offloaded activations should populate the input gradient"
198+
@require_torch_accelerator
def test_parameter_filtering(self):
    """Verify that ``update_model_params`` records the model's parameter storages.

    Builds a two-layer sequential model, registers it with a default
    ``OffloadActivations`` context, and compares ``param_storages`` against
    the storage data pointers gathered directly from ``model.parameters()``.
    """
    layers = [nn.Linear(10, 20), nn.Linear(20, 10)]
    model = nn.Sequential(*layers).to(torch_device)

    offload_ctx = OffloadActivations()
    offload_ctx.update_model_params(model)

    tracked_storages = offload_ctx.param_storages
    assert len(tracked_storages) > 0, "Should have tracked parameter storages"

    # Collect the data pointer of each parameter's underlying storage.
    expected_ptrs = set()
    for param in model.parameters():
        expected_ptrs.add(param.data.untyped_storage().data_ptr())

    assert tracked_storages == expected_ptrs, "Tracked storages should match parameter storages"
0 commit comments