Skip to content

Commit fc86eba

Browse files
committed
fix merge conflix
1 parent eae71ad commit fc86eba

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

test/test_ops.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1776,12 +1776,16 @@ def check_cow_input(
17761776
if isinstance(r, torch.Tensor) and r.requires_grad
17771777
]
17781778

1779-
all_results_strided = all(is_strided_tensor(result) for result in results)
1779+
all_results_strided = all(
1780+
is_strided_tensor(result) for result in results
1781+
)
17801782

17811783
# Only test backward if the results are strided tensors
17821784
if all_results_strided:
17831785
output_grads_raw = [
1784-
torch.ones(r.shape, device=r.device, dtype=r.dtype) for r in results]
1786+
torch.ones(r.shape, device=r.device, dtype=r.dtype)
1787+
for r in results
1788+
]
17851789
output_grads_copy = []
17861790
output_grads = []
17871791

@@ -1795,28 +1799,30 @@ def check_cow_input(
17951799
leaf_tensors,
17961800
output_grads,
17971801
allow_unused=True,
1798-
retain_graph=True)
1802+
retain_graph=True,
1803+
)
17991804

18001805
# Check that COW inputs remain COW after the backward op is executed
18011806
for idx, arg in enumerate(args):
18021807
check_cow_input(
18031808
arg,
18041809
args_copy[idx],
18051810
idx,
1806-
backward_or_forward='backward',
1811+
backward_or_forward="backward",
18071812
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
1808-
allow_list=op.allow_cow_input_materialize_backward)
1813+
allow_list=op.allow_cow_input_materialize_backward,
1814+
)
18091815

18101816
# Check that COW inputs remain COW after the backward op is executed
18111817
for idx, output_grad in enumerate(output_grads):
18121818
check_cow_input(
18131819
output_grad,
18141820
output_grads_copy[idx],
1815-
f'output grad {idx}',
1816-
backward_or_forward='backward',
1821+
f"output grad {idx}",
1822+
backward_or_forward="backward",
18171823
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
1818-
allow_list=op.allow_cow_input_materialize_backward)
1819-
1824+
allow_list=op.allow_cow_input_materialize_backward,
1825+
)
18201826

18211827
@ops(op_db, allowed_dtypes=(torch.float,))
18221828
def test_view_replay(self, device, dtype, op):

0 commit comments

Comments
 (0)