@@ -613,8 +613,7 @@ def input_creation_fn():
     @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_nested_fully_shard_backend_aot_eager(self):
-        # TODO: fix fwd_fullgraph=False case
-        for fwd_fullgraph in [True]:
+        for fwd_fullgraph in [True, False]:
             self._test_traceable_fsdp(
                 *self._create_nested_fully_shard_factory_fns(
                     fwd_fullgraph=fwd_fullgraph
@@ -626,8 +625,7 @@ def test_nested_fully_shard_backend_aot_eager(self):
     @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_nested_fully_shard_backend_aot_eager_decomp_partition(self):
-        # TODO: fix fwd_fullgraph=False case
-        for fwd_fullgraph in [True]:
+        for fwd_fullgraph in [True, False]:
             self._test_traceable_fsdp(
                 *self._create_nested_fully_shard_factory_fns(
                     fwd_fullgraph=fwd_fullgraph
@@ -732,7 +730,6 @@ def test_nested_fully_shard_backend_inductor_fullgraph_True(self):
         )
         file_check.run(bwd_code)

-    @unittest.skip("TODO: fix fwd_fullgraph=False case")
     @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_nested_fully_shard_backend_inductor_fullgraph_False(self):
@@ -813,9 +810,8 @@ def _sdpa_with_graph_break(*args, **kwargs):
     @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_transformer_backend_aot_eager(self):
-        # TODO: fix fwd_fullgraph=False case
         for fwd_fullgraph, all_requires_grad in itertools.product(
-            [True], [True, False]
+            [True, False], [True, False]
         ):
             with self._maybe_add_graph_break_to_sdpa(
                 fwd_fullgraph
@@ -833,9 +829,8 @@ def test_transformer_backend_aot_eager(self):
     # TODO: native_dropout has worse accuracy after decomp, need to figure out why
     @torch._inductor.config.patch(fallback_random=True)
     def test_transformer_backend_aot_eager_decomp_partition(self):
-        # TODO: fix fwd_fullgraph=False case
         for fwd_fullgraph, all_requires_grad in itertools.product(
-            [True], [True, False]
+            [True, False], [True, False]
         ):
             with self._maybe_add_graph_break_to_sdpa(fwd_fullgraph):
                 self._test_traceable_fsdp(
@@ -951,7 +946,6 @@ def test_transformer_backend_inductor_fullgraph_True(self):
         )
         file_check.run(bwd_code)

-    @unittest.skip("TODO: fix fwd_fullgraph=False case")
     @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     # TODO: native_dropout causes CUDA IMA error, need to figure out why