@@ -1776,12 +1776,16 @@ def check_cow_input(
1776
1776
if isinstance (r , torch .Tensor ) and r .requires_grad
1777
1777
]
1778
1778
1779
- all_results_strided = all (is_strided_tensor (result ) for result in results )
1779
+ all_results_strided = all (
1780
+ is_strided_tensor (result ) for result in results
1781
+ )
1780
1782
1781
1783
# Only test backward if the results are strided tensors
1782
1784
if all_results_strided :
1783
1785
output_grads_raw = [
1784
- torch .ones (r .shape , device = r .device , dtype = r .dtype ) for r in results ]
1786
+ torch .ones (r .shape , device = r .device , dtype = r .dtype )
1787
+ for r in results
1788
+ ]
1785
1789
output_grads_copy = []
1786
1790
output_grads = []
1787
1791
@@ -1795,28 +1799,30 @@ def check_cow_input(
1795
1799
leaf_tensors ,
1796
1800
output_grads ,
1797
1801
allow_unused = True ,
1798
- retain_graph = True )
1802
+ retain_graph = True ,
1803
+ )
1799
1804
1800
1805
# Check that COW inputs remain COW after the backward op is executed
1801
1806
for idx , arg in enumerate (args ):
1802
1807
check_cow_input (
1803
1808
arg ,
1804
1809
args_copy [idx ],
1805
1810
idx ,
1806
- backward_or_forward = ' backward' ,
1811
+ backward_or_forward = " backward" ,
1807
1812
supports_cow_input_no_materialize = op .supports_cow_input_no_materialize_backward ,
1808
- allow_list = op .allow_cow_input_materialize_backward )
1813
+ allow_list = op .allow_cow_input_materialize_backward ,
1814
+ )
1809
1815
1810
1816
# Check that COW inputs remain COW after the backward op is executed
1811
1817
for idx , output_grad in enumerate (output_grads ):
1812
1818
check_cow_input (
1813
1819
output_grad ,
1814
1820
output_grads_copy [idx ],
1815
- f' output grad { idx } ' ,
1816
- backward_or_forward = ' backward' ,
1821
+ f" output grad { idx } " ,
1822
+ backward_or_forward = " backward" ,
1817
1823
supports_cow_input_no_materialize = op .supports_cow_input_no_materialize_backward ,
1818
- allow_list = op .allow_cow_input_materialize_backward )
1819
-
1824
+ allow_list = op .allow_cow_input_materialize_backward ,
1825
+ )
1820
1826
1821
1827
@ops (op_db , allowed_dtypes = (torch .float ,))
1822
1828
def test_view_replay (self , device , dtype , op ):
0 commit comments