1717import torch
1818from hypothesis import given , settings , strategies as st , Verbosity
1919from torch import nn , optim
20+ from torch ._dynamo .testing import reduce_to_scalar_loss
2021from torchrec .distributed import DistributedModelParallel
2122from torchrec .distributed .embedding_types import EmbeddingComputeKernel
2223from torchrec .distributed .embeddingbag import EmbeddingBagCollectionSharder
4546 PrefetchTrainPipelineSparseDist ,
4647 StagedTrainPipeline ,
4748 TrainPipelineBase ,
49+ TrainPipelinePT2 ,
4850 TrainPipelineSemiSync ,
4951 TrainPipelineSparseDist ,
5052)
6365 ShardingPlan ,
6466 ShardingType ,
6567)
66- from torchrec .modules .embedding_configs import DataType
68+ from torchrec .modules .embedding_configs import DataType , EmbeddingBagConfig
69+ from torchrec .modules .embedding_modules import EmbeddingBagCollection
6770
6871from torchrec .optim .keyed import KeyedOptimizerWrapper
6972from torchrec .optim .optimizers import in_backward_optimizer_filter
73+ from torchrec .pt2 .utils import kjt_for_pt2_tracing
74+ from torchrec .sparse .jagged_tensor import JaggedTensor , KeyedJaggedTensor , KeyedTensor
7075from torchrec .streamable import Pipelineable
7176
7277
@@ -93,6 +98,7 @@ def __init__(self) -> None:
9398 super ().__init__ ()
9499 self .model = nn .Linear (10 , 1 )
95100 self .loss_fn = nn .BCEWithLogitsLoss ()
101+ self ._dummy_setting : str = "dummy"
96102
97103 def forward (
98104 self , model_input : ModelInputSimple
@@ -156,6 +162,153 @@ def test_equal_to_non_pipelined(self) -> None:
156162 self .assertTrue (torch .isclose (pred_gpu .cpu (), pred ))
157163
158164
165+ class TrainPipelinePT2Test (unittest .TestCase ):
166+ def setUp (self ) -> None :
167+ self .device = torch .device ("cuda:0" )
168+ torch .backends .cudnn .allow_tf32 = False
169+ torch .backends .cuda .matmul .allow_tf32 = False
170+
171+ def gen_eb_conf_list (self , is_weighted : bool = False ) -> List [EmbeddingBagConfig ]:
172+ weighted_prefix = "weighted_" if is_weighted else ""
173+
174+ return [
175+ EmbeddingBagConfig (
176+ num_embeddings = 256 ,
177+ embedding_dim = 12 ,
178+ name = weighted_prefix + "table_0" ,
179+ feature_names = [weighted_prefix + "f0" ],
180+ ),
181+ EmbeddingBagConfig (
182+ num_embeddings = 256 ,
183+ embedding_dim = 12 ,
184+ name = weighted_prefix + "table_1" ,
185+ feature_names = [weighted_prefix + "f1" ],
186+ ),
187+ ]
188+
189+ def gen_model (
190+ self , device : torch .device , ebc_list : List [EmbeddingBagConfig ]
191+ ) -> nn .Module :
192+ class M_ebc (torch .nn .Module ):
193+ def __init__ (self , vle : EmbeddingBagCollection ) -> None :
194+ super ().__init__ ()
195+ self .model = vle
196+
197+ def forward (self , x : KeyedJaggedTensor ) -> List [JaggedTensor ]:
198+ kt : KeyedTensor = self .model (x )
199+ return list (kt .to_dict ().values ())
200+
201+ return M_ebc (
202+ EmbeddingBagCollection (
203+ device = device ,
204+ tables = ebc_list ,
205+ )
206+ )
207+
208+ # pyre-fixme[56]: Pyre was not able to infer the type of argument
209+ @unittest .skipIf (
210+ not torch .cuda .is_available (),
211+ "Not enough GPUs, this test requires at least one GPU" ,
212+ )
213+ def test_equal_to_non_pipelined (self ) -> None :
214+ model_cpu = TestModule ()
215+ model_gpu = TestModule ().to (self .device )
216+ model_gpu .load_state_dict (model_cpu .state_dict ())
217+ optimizer_cpu = optim .SGD (model_cpu .model .parameters (), lr = 0.01 )
218+ optimizer_gpu = optim .SGD (model_gpu .model .parameters (), lr = 0.01 )
219+ data = [
220+ ModelInputSimple (
221+ float_features = torch .rand ((10 ,)),
222+ label = torch .randint (2 , (1 ,), dtype = torch .float32 ),
223+ )
224+ for b in range (5 )
225+ ]
226+ dataloader = iter (data )
227+ pipeline = TrainPipelinePT2 (model_gpu , optimizer_gpu , self .device )
228+
229+ for batch in data [:- 1 ]:
230+ optimizer_cpu .zero_grad ()
231+ loss , pred = model_cpu (batch )
232+ loss .backward ()
233+ optimizer_cpu .step ()
234+
235+ pred_gpu = pipeline .progress (dataloader )
236+
237+ self .assertEqual (pred_gpu .device , self .device )
238+ self .assertTrue (torch .isclose (pred_gpu .cpu (), pred ))
239+
240+ # pyre-fixme[56]: Pyre was not able to infer the type of argument
241+ @unittest .skipIf (
242+ not torch .cuda .is_available (),
243+ "Not enough GPUs, this test requires at least one GPU" ,
244+ )
245+ def test_pre_compile_fn (self ) -> None :
246+ model_cpu = TestModule ()
247+ model_gpu = TestModule ().to (self .device )
248+ model_gpu .load_state_dict (model_cpu .state_dict ())
249+ optimizer_gpu = optim .SGD (model_gpu .model .parameters (), lr = 0.01 )
250+ data = [
251+ ModelInputSimple (
252+ float_features = torch .rand ((10 ,)),
253+ label = torch .randint (2 , (1 ,), dtype = torch .float32 ),
254+ )
255+ for b in range (5 )
256+ ]
257+
258+ def pre_compile_fn (model : nn .Module ) -> None :
259+ model ._dummy_setting = "dummy modified"
260+
261+ dataloader = iter (data )
262+ pipeline = TrainPipelinePT2 (
263+ model_gpu , optimizer_gpu , self .device , pre_compile_fn = pre_compile_fn
264+ )
265+ self .assertEqual (model_gpu ._dummy_setting , "dummy" )
266+ for _ in range (len (data )):
267+ pipeline .progress (dataloader )
268+ self .assertEqual (model_gpu ._dummy_setting , "dummy modified" )
269+
270+ # pyre-fixme[56]: Pyre was not able to infer the type of argument
271+ @unittest .skipIf (
272+ not torch .cuda .is_available (),
273+ "Not enough GPUs, this test requires at least one GPU" ,
274+ )
275+ def test_equal_to_non_pipelined_with_input_transformer (self ) -> None :
276+ cpu = torch .device ("cpu:0" )
277+ eb_conf_list = self .gen_eb_conf_list ()
278+ eb_conf_list_weighted = self .gen_eb_conf_list (is_weighted = True )
279+
280+ model_cpu = self .gen_model (cpu , eb_conf_list )
281+ model_gpu = self .gen_model (self .device , eb_conf_list ).to (self .device )
282+
283+ _ , local_model_inputs = ModelInput .generate (
284+ batch_size = 10 ,
285+ world_size = 4 ,
286+ num_float_features = 8 ,
287+ tables = eb_conf_list ,
288+ weighted_tables = eb_conf_list_weighted ,
289+ variable_batch_size = False ,
290+ )
291+
292+ model_gpu .load_state_dict (model_cpu .state_dict ())
293+ optimizer_cpu = optim .SGD (model_cpu .model .parameters (), lr = 0.01 )
294+ optimizer_gpu = optim .SGD (model_gpu .model .parameters (), lr = 0.01 )
295+
296+ data = [i .idlist_features for i in local_model_inputs ]
297+ dataloader = iter (data )
298+ pipeline = TrainPipelinePT2 (
299+ model_gpu , optimizer_gpu , self .device , input_transformer = kjt_for_pt2_tracing
300+ )
301+
302+ for batch in data [:- 1 ]:
303+ optimizer_cpu .zero_grad ()
304+ loss , pred = model_cpu (batch )
305+ loss = reduce_to_scalar_loss (loss )
306+ pred_gpu = pipeline .progress (dataloader )
307+
308+ self .assertEqual (pred_gpu .device , self .device )
309+ torch .testing .assert_close (pred_gpu .cpu (), pred )
310+
311+
159312class TrainPipelineSparseDistTest (TrainPipelineSparseDistTestBase ):
160313 # pyre-fixme[56]: Pyre was not able to infer the type of argument
161314 @unittest .skipIf (
0 commit comments