 import torchrec
 import torchrec.pt2.checks
 from hypothesis import given, settings, strategies as st, Verbosity
+from torch._dynamo.testing import reduce_to_scalar_loss
 from torchrec.distributed.embedding import EmbeddingCollectionSharder
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.fbgemm_qcomm_codec import QCommsConfig
@@ ... @@
 from torchrec.pt2.utils import kjt_for_pt2_tracing
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor

+
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
@@ -139,17 +141,22 @@ def sharding_types(self, compute_device_type: str) -> List[str]:


 def _gen_model(test_model_type: _ModelType, mi: TestModelInfo) -> torch.nn.Module:
+    emb_dim: int = max(t.embedding_dim for t in mi.tables)
     if test_model_type == _ModelType.EBC:

         class M_ebc(torch.nn.Module):
             def __init__(self, ebc: EmbeddingBagCollection) -> None:
                 super().__init__()
                 self._ebc = ebc
+                self._linear = torch.nn.Linear(
+                    mi.num_float_features, emb_dim, device=mi.dense_device
+                )

-            def forward(self, x: KeyedJaggedTensor) -> torch.Tensor:
+            def forward(self, x: KeyedJaggedTensor, y: torch.Tensor) -> torch.Tensor:
                 kt: KeyedTensor = self._ebc(x)
                 v = kt.values()
-                return torch.sigmoid(torch.mean(v, dim=1))
+                y = self._linear(y)
+                return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

         return M_ebc(
             EmbeddingBagCollection(
@@ -164,10 +171,15 @@ class M_fpebc(torch.nn.Module):
             def __init__(self, fpebc: FeatureProcessedEmbeddingBagCollection) -> None:
                 super().__init__()
                 self._fpebc = fpebc
+                self._linear = torch.nn.Linear(
+                    mi.num_float_features, emb_dim, device=mi.dense_device
+                )

-            def forward(self, x: KeyedJaggedTensor) -> torch.Tensor:
+            def forward(self, x: KeyedJaggedTensor, y: torch.Tensor) -> torch.Tensor:
                 kt: KeyedTensor = self._fpebc(x)
-                return kt.values()
+                v = kt.values()
+                y = self._linear(y)
+                return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

         return M_fpebc(
             FeatureProcessedEmbeddingBagCollection(
@@ -187,9 +199,13 @@ def __init__(self, ec: EmbeddingCollection) -> None:
                 super().__init__()
                 self._ec = ec
+                self._linear = torch.nn.Linear(
+                    mi.num_float_features, emb_dim, device=mi.dense_device
+                )

-            def forward(self, x: KeyedJaggedTensor) -> List[JaggedTensor]:
+            def forward(
+                self, x: KeyedJaggedTensor, y: torch.Tensor
+            ) -> torch.Tensor:
                 d: Dict[str, JaggedTensor] = self._ec(x)
-                return list(d.values())
+                v = torch.stack(d.values(), dim=0).sum(dim=0)
+                y = self._linear(y)
+                return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

         return M_ec(
             EmbeddingCollection(
@@ -307,6 +323,7 @@ def _test_compile_rank_fn(
             # pyre-ignore
             sharders=sharders,
             device=device,
+            init_data_parallel=False,
         )

         if input_type == _InputType.VARIABLE_BATCH:
@@ -336,19 +353,27 @@ def _test_compile_rank_fn(
         local_model_input = local_model_inputs[0].to(device)

         kjt = local_model_input.idlist_features
+        ff = local_model_input.float_features
+        ff.requires_grad = True
         kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)

+        compile_input_ff = ff.clone().detach()
+
         torchrec.distributed.comm_ops.set_use_sync_collectives(True)
         torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)

         dmp.train(True)

-        eager_out = dmp(kjt_ft)
+        eager_out = dmp(kjt_ft, ff)
+
+        eager_loss = reduce_to_scalar_loss(eager_out)
+        eager_loss.backward()

         if torch_compile_backend is None:
             return

         ##### COMPILE #####
+        run_compile_backward: bool = torch_compile_backend in ["aot_eager", "inductor"]
         with dynamo_skipfiles_allow("torchrec"):
             torch._dynamo.config.capture_scalar_outputs = True
             torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -357,8 +382,14 @@ def _test_compile_rank_fn(
                 backend=torch_compile_backend,
                 fullgraph=True,
             )
-            compile_out = opt_fn(kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb))
-            torch.testing.assert_close(eager_out, compile_out)
+            compile_out = opt_fn(
+                kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb), compile_input_ff
+            )
+            torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
+            if run_compile_backward:
+                loss = reduce_to_scalar_loss(compile_out)
+                loss.backward()
+
             ##### COMPILE END #####

             ##### NUMERIC CHECK #####
@@ -368,9 +399,20 @@ def _test_compile_rank_fn(
                 local_model_input = local_model_inputs[1 + i].to(device)
                 kjt = local_model_input.idlist_features
                 kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)
-                eager_out_i = dmp(kjt_ft)
-                compile_out_i = opt_fn(kjt_ft)
-                torch.testing.assert_close(eager_out_i, compile_out_i)
+                ff = local_model_input.float_features
+                ff.requires_grad = True
+                eager_out_i = dmp(kjt_ft, ff)
+                eager_loss_i = reduce_to_scalar_loss(eager_out_i)
+                eager_loss_i.backward()
+
+                compile_input_ff = ff.detach().clone()
+                compile_out_i = opt_fn(kjt_ft, ff)
+                torch.testing.assert_close(
+                    eager_out_i, compile_out_i, atol=1e-3, rtol=1e-3
+                )
+                if run_compile_backward:
+                    loss_i = torch._dynamo.testing.reduce_to_scalar_loss(compile_out_i)
+                    loss_i.backward()
             ##### NUMERIC CHECK END #####


@@ -396,14 +438,14 @@ def disable_cuda_tf32(self) -> bool:
                     ShardingType.TABLE_WISE.value,
                     _InputType.SINGLE_BATCH,
                     _ConvertToVariableBatch.TRUE,
-                    "eager",
+                    "inductor",
                 ),
                 (
                     _ModelType.EBC,
                     ShardingType.COLUMN_WISE.value,
                     _InputType.SINGLE_BATCH,
                     _ConvertToVariableBatch.TRUE,
-                    "eager",
+                    "inductor",
                 ),
                 (
                     _ModelType.EBC,
@@ -412,6 +454,13 @@ def disable_cuda_tf32(self) -> bool:
                     _ConvertToVariableBatch.FALSE,
                     "eager",
                 ),
+                (
+                    _ModelType.EBC,
+                    ShardingType.COLUMN_WISE.value,
+                    _InputType.SINGLE_BATCH,
+                    _ConvertToVariableBatch.FALSE,
+                    "eager",
+                ),
             ]
         ),
     )
@@ -424,7 +473,7 @@ def test_compile_multiprocess(
             str,
             _InputType,
             _ConvertToVariableBatch,
-            str,
+            Optional[str],
         ],
     ) -> None:
         model_type, sharding_type, input_type, tovb, compile_backend = (
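
# Note: a minimal, self-contained sketch of the eager-vs-compile check pattern that the
# updated test follows (forward and backward in eager mode, then the same input through
# torch.compile with a numeric comparison and a compiled backward). A toy nn.Linear
# stands in for the sharded TorchRec model; only the structure of the check mirrors the
# test above.
import torch
from torch._dynamo.testing import reduce_to_scalar_loss

model = torch.nn.Linear(8, 4)

ff = torch.randn(16, 8)
ff.requires_grad = True
compile_input_ff = ff.clone().detach()  # separate copy of the dense input for the compiled run

# Eager forward and backward.
eager_out = model(ff)
reduce_to_scalar_loss(eager_out).backward()

# Compiled forward, numeric comparison against eager, then compiled backward.
opt_fn = torch.compile(model, backend="aot_eager", fullgraph=True)
compile_out = opt_fn(compile_input_ff)
torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
reduce_to_scalar_loss(compile_out).backward()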