[BE] Enable ruff's UP rules and autoformat dynamo / functorch and refs (pytorch#105432)

Pull Request resolved: pytorch#105432
Approved by: https://github.com/ezyang
justinchuby authored and pytorchmergebot committed Jul 19, 2023
1 parent 88f1197 commit 8a68827
Showing 47 changed files with 188 additions and 242 deletions.
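
The hunks below are the kind of mechanical modernizations that ruff's pyupgrade-derived UP rules apply: zero-argument super() calls, f-strings in place of printf-style and str.format() formatting, open() calls without the redundant 'r' mode, generator expressions where a list was built only to be unpacked or iterated once, and removal of redundant parentheses. A few related cleanups ride along, such as replacing the deprecated unittest alias assertEquals with assertEqual.
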
2 changes: 1 addition & 1 deletion functorch/benchmarks/operator_authoring.py
@@ -113,7 +113,7 @@ def out_setup(n):
def test_backwards(make_args, nnc=nnc_add, aten=torch.add):
def backwards_setup(n):
args = make_args(n)
(grad_var,) = [a for a in args if a.requires_grad]
(grad_var,) = (a for a in args if a.requires_grad)
aten(*args).sum().backward()
correct = grad_var.grad.clone()
grad_var.grad.zero_()
22 changes: 10 additions & 12 deletions functorch/einops/rearrange.py
@@ -108,18 +108,16 @@ class dims."""

custom_rearrange_callable_name = "do_rearrange"
custom_rearrange_callable_code = (
(
f"def {custom_rearrange_callable_name}(tensor):\n"
f" {comma_separate(first_class_dims)} = dims({n_dims})\n"
+ (
"".join(f" {dim}.size = {length}\n" for (dim, length) in specified_lengths)
if specified_lengths else ""
)
+ f" tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
+ (
f" return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
if anon_dims else " return tensor\n"
)
f"def {custom_rearrange_callable_name}(tensor):\n"
f" {comma_separate(first_class_dims)} = dims({n_dims})\n"
+ (
"".join(f" {dim}.size = {length}\n" for (dim, length) in specified_lengths)
if specified_lengths else ""
)
+ f" tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
+ (
f" return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
if anon_dims else " return tensor\n"
)
)

2 changes: 1 addition & 1 deletion functorch/examples/compilation/linear_train.py
@@ -18,7 +18,7 @@ def bench(f, iters=100, warmup=10):
begin = time.time()
for _ in range(iters):
f()
print((time.time() - begin))
print(time.time() - begin)


class Foo(nn.Module):
8 changes: 4 additions & 4 deletions functorch/examples/maml_omniglot/support/omniglot_loaders.py
@@ -121,7 +121,7 @@ def find_classes(root_dir):
r = root.split('/')
lr = len(r)
retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
print("== Found %d items " % len(retour))
print(f"== Found {len(retour)} items ")
return retour


@@ -130,7 +130,7 @@ def index_classes(items):
for i in items:
if i[1] not in idx:
idx[i[1]] = len(idx)
print("== Found %d classes" % len(idx))
print(f"== Found {len(idx)} classes")
return idx


@@ -276,10 +276,10 @@ def load_data_cache(self, data_pack):
x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
y_qrys = np.array(y_qrys).astype(int).reshape(self.batchsz, querysz)

x_spts, y_spts, x_qrys, y_qrys = [
x_spts, y_spts, x_qrys, y_qrys = (
torch.from_numpy(z).to(self.device) for z in
[x_spts, y_spts, x_qrys, y_qrys]
]
)

data_cache.append([x_spts, y_spts, x_qrys, y_qrys])

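A minimal check of the formatting change above (illustrative, not part of the diff): the f-string produces exactly the text the printf-style expression produced.

    items = ["a", "b", "c"]
    # the old %-formatting and the new f-string yield identical strings
    assert "== Found %d items " % len(items) == f"== Found {len(items)} items "
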
10 changes: 5 additions & 5 deletions functorch/op_analysis/gen_data.py
@@ -23,7 +23,7 @@ def gen_data(special_op_lists, analysis_name):
composite_ops = get_ops_for_key('CompositeImplicitAutograd')
noncomposite_ops = all_ops - composite_ops

ops = yaml.load(open('../../aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
ops = yaml.load(open('../../aten/src/ATen/native/native_functions.yaml').read(), Loader=yaml.CLoader)

annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops')))}
from collections import defaultdict
@@ -132,19 +132,19 @@ def remove_prefix(input_string, prefix):


if True:
with open('run_ops.txt', 'r') as f:
with open('run_ops.txt') as f:
opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
with open('count_ops.txt', 'r') as f:
with open('count_ops.txt') as f:
opinfo_counts = [i.strip() for i in f.readlines()]
opinfo_counts = defaultdict(int, dict(zip(opinfo_ops, opinfo_counts)))

def count_fn(x):
return opinfo_counts[x['full_name']]

with open('run_decompositions.txt', 'r') as f:
with open('run_decompositions.txt') as f:
decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]

with open('public_api', 'r') as f:
with open('public_api') as f:
ref_api = [i.strip() for i in f.readlines()]

def has_ref_impl(x):
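
A minimal check of why dropping the explicit 'r' above is safe (illustrative, not part of the diff): text read mode is already the default for open().

    with open("example.txt", "w") as f:  # hypothetical scratch file, created only for this check
        f.write("ok")
    with open("example.txt") as a, open("example.txt", "r") as b:
        assert a.read() == b.read()  # 'r' is the default mode, so both reads are identical
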
4 changes: 2 additions & 2 deletions test/dynamo/test_autograd_function.py
@@ -207,7 +207,7 @@ def backward(ctx, grad_output):

class ModuleWithGradFunc(torch.nn.Module):
def __init__(self, func):
super(ModuleWithGradFunc, self).__init__()
super().__init__()
self.f = func.apply

def forward(self, x):
@@ -336,7 +336,7 @@ def backward(ctx, grad_output):

class MyMod(torch.nn.Module):
def __init__(self):
super(MyMod, self).__init__()
super().__init__()
self.gamma = torch.nn.Parameter(torch.rand([4, 128, 32, 32]))

def forward(self, x):
2 changes: 1 addition & 1 deletion test/dynamo/test_compile.py
@@ -11,7 +11,7 @@

class ToyModel(torch.nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.relu = torch.nn.ReLU()

2 changes: 1 addition & 1 deletion test/dynamo/test_logging.py
@@ -157,7 +157,7 @@ def throw(x):
def test_ddp_graphs(self, records):
class ToyModel(torch.nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
super().__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(1024, 1024),
torch.nn.Linear(1024, 1024),
34 changes: 17 additions & 17 deletions test/dynamo/test_misc.py
@@ -822,7 +822,7 @@ def fn(a, b):
v2 = torch.randn((10, 10))
correct = fn(v1, v2)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize((cnts))(fn)
opt_fn = torch._dynamo.optimize(cnts)(fn)
self.assertEqual(opt_fn(v1, v2), correct)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 3)
@@ -836,7 +836,7 @@ def fn(a, b):
v2 = torch.randn((10, 10))
correct = fn(v1, v2)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize((cnts))(fn)
opt_fn = torch._dynamo.optimize(cnts)(fn)
self.assertEqual(opt_fn(v1, v2), correct)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 2)
@@ -2201,7 +2201,7 @@ def fn():
def fn():
foo.bar(1, 2, 3)
{str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))}
l = [{str(' ').join('x' + str(i) + ',' for i in range(1 << 9))}]
l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}]
"""
locals = {}
exec(fn_str, {}, locals)
@@ -3086,7 +3086,7 @@ def foo(self, memo=None, prefix="", remove_duplicate=False):
memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
):
for pn, p in self.named_parameters():
fpn = "%s.%s" % (mn, pn) if mn else pn
fpn = f"{mn}.{pn}" if mn else pn
self.names.append(fpn)

# Test plain recurse
@@ -5031,11 +5031,11 @@ def test_compute_exception_table_nested(self):
(15, 16, 7),
(17, 17, 6),
]
self.assertEquals(len(tab), len(expected))
self.assertEqual(len(tab), len(expected))
for entry, exp in zip(tab, expected):
self.assertEquals(entry.start, exp[0] * 2)
self.assertEquals(entry.end, exp[1] * 2)
self.assertEquals(entry.target, exp[2] * 2)
self.assertEqual(entry.start, exp[0] * 2)
self.assertEqual(entry.end, exp[1] * 2)
self.assertEqual(entry.target, exp[2] * 2)

@skipIfNotPy311
def test_remove_dead_code_with_exn_table_entries(self):
@@ -5059,17 +5059,17 @@ def test_remove_dead_code_with_exn_table_entries(self):
)
bytecode_transformation.propagate_inst_exn_table_entries(insts)
insts = bytecode_analysis.remove_dead_code(insts)
self.assertEquals(len(insts), 5)
self.assertEqual(len(insts), 5)
self.assertNotIn(exn_start, insts)
self.assertNotIn(exn_end, insts)
self.assertIn(target2, insts)
self.assertIn(target3, insts)
bytecode_transformation.update_offsets(insts)
tab = bytecode_transformation.compute_exception_table(insts)
self.assertEquals(len(tab), 1)
self.assertEquals(tab[0].start, 2)
self.assertEquals(tab[0].end, 4)
self.assertEquals(tab[0].target, 6)
self.assertEqual(len(tab), 1)
self.assertEqual(tab[0].start, 2)
self.assertEqual(tab[0].end, 4)
self.assertEqual(tab[0].target, 6)

def test_unhandled_exception_in_dynamo(self):
# traceback.format_exc() approximates an unhandled exception
@@ -5756,7 +5756,7 @@ def guard(L):
def test_dynamo_compiling_fake_tensor_to_vararg_int(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
super().__init__()

def forward(self, x):
# use numpy int so it's wrapped as fake tensor in dynamo
@@ -5775,7 +5775,7 @@ def forward(self, x):
def test_scalar_tensor_is_equivalent_to_symint_argument(self):
class GumbelTopKSampler(torch.nn.Module):
def __init__(self, T, k):
super(GumbelTopKSampler, self).__init__()
super().__init__()
self.T = torch.nn.Parameter(
torch.tensor(T, dtype=torch.float32), requires_grad=False
)
@@ -5802,7 +5802,7 @@ def forward(self, logits):
def test_scalar_tensor_is_equivalent_to_symint_list_argument(self):
class Jitter(torch.nn.Module):
def __init__(self, jitter_val):
super(Jitter, self).__init__()
super().__init__()
self.jitter_val = jitter_val

def roll_tensor(self, input):
@@ -5987,7 +5987,7 @@ def _prepare_for_translation_validation(self):

# Z3 symbols.
[validator.add_var(s, int) for s in (s0, s1, s2)]
z0, z1, z2 = [validator.z3var(s) for s in (s0, s1, s2)]
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))

return (s0, s1, s2), (z0, z1, z2), validator

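One change in this file is not purely cosmetic: assertEquals is a deprecated unittest alias of assertEqual, so the supported spelling is used. A minimal illustration (not from the diff):

    import unittest

    class AliasCheck(unittest.TestCase):
        def test_sum(self):
            # assertEquals still works but emits a DeprecationWarning;
            # assertEqual is the supported spelling.
            self.assertEqual(1 + 1, 2)
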
12 changes: 6 additions & 6 deletions test/dynamo/test_modules.py
@@ -762,21 +762,21 @@ def forward(self, x):

class ConvCallSuperForwardDirectly(torch.nn.Conv1d):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super(ConvCallSuperForwardDirectly, self).__init__(
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
**kwargs,
)

def forward(self, inputs, mask=None):
outputs = super(ConvCallSuperForwardDirectly, self).forward(inputs)
outputs = super().forward(inputs)
return outputs


class ConvTransposeCallSuperForwardDirectly(torch.nn.ConvTranspose2d):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super(ConvTransposeCallSuperForwardDirectly, self).__init__(
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
@@ -785,7 +785,7 @@ def __init__(self, in_channels, out_channels, kernel_size, **kwargs):

def forward(self, x):
if x.numel() > 0:
return super(ConvTransposeCallSuperForwardDirectly, self).forward(x)
return super().forward(x)
output_shape = [
((i - 1) * d - 2 * p + (di * (k - 1) + 1) + op)
for i, p, di, k, d, op in zip(
@@ -923,7 +923,7 @@ def forward(self, x):
class SequentialWithDuplicatedModule(torch.nn.Module):
# Sequential module(self.layer) contains three duplicated ReLU module.
def __init__(self):
super(SequentialWithDuplicatedModule, self).__init__()
super().__init__()
self.relu = torch.nn.ReLU()
self.layer = torch.nn.Sequential(
torch.nn.Linear(10, 20),
@@ -940,7 +940,7 @@ def forward(self, x):

class SequentialWithDuplicatedModule2(torch.nn.Module):
def __init__(self):
super(SequentialWithDuplicatedModule2, self).__init__()
super().__init__()
self.relu = torch.nn.ReLU()
self.layer = torch.nn.Sequential(
collections.OrderedDict(
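
The super() rewrites in this file also touch forward(), not only __init__. A minimal check (not from the diff): zero-argument super() works in any method defined in the class body, via the implicit __class__ cell.

    class Base:
        def forward(self, x):
            return x + 1

    class Child(Base):
        def forward(self, x):
            # equivalent to super(Child, self).forward(x)
            return super().forward(x) * 2

    assert Child().forward(1) == 4
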
4 changes: 2 additions & 2 deletions test/dynamo/test_profiler.py
@@ -20,7 +20,7 @@ def inner_fn(x):
def outer_fn(x, y):
return inner_fn(x) * y

x, y = [torch.rand((2, 2)) for _ in range(2)]
x, y = (torch.rand((2, 2)) for _ in range(2))

with torch.profiler.profile(with_stack=False) as prof:
outer_fn(x, y)
@@ -40,7 +40,7 @@ def test_dynamo_timed_profiling_backend_compile(self):
def fn(x, y):
return x.sin() * y.cos()

x, y = [torch.rand((2, 2)) for _ in range(2)]
x, y = (torch.rand((2, 2)) for _ in range(2))

with torch.profiler.profile(with_stack=False) as prof:
torch._dynamo.optimize("aot_eager")(fn)(x, y)
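
A minimal check for the unpacking changes here and elsewhere in the commit (illustrative, not from the diff): tuple unpacking accepts any iterable, so a generator expression is a drop-in replacement for the temporary list it replaces.

    values = [1, 2, 3, 4]
    (from_list,) = [v for v in values if v > 3]
    (from_gen,) = (v for v in values if v > 3)
    assert from_list == from_gen == 4
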
2 changes: 1 addition & 1 deletion test/dynamo/test_repros.py
@@ -2632,7 +2632,7 @@ def test_error_return_without_exception_set(self):
# https://github.com/pytorch/pytorch/issues/93781
@torch.compile
def f():
_generator_type = type((_ for _ in ()))
_generator_type = type(_ for _ in ())

self.assertNoUnraisable(f)

2 changes: 1 addition & 1 deletion test/error_messages/storage.py
@@ -14,7 +14,7 @@ def check_error(desc, fn, *required_substrings):
for sub in required_substrings:
assert sub in error_message
return
raise AssertionError("given function ({}) didn't raise an error".format(desc))
raise AssertionError(f"given function ({desc}) didn't raise an error")

check_error(
'Wrong argument types',
6 changes: 3 additions & 3 deletions test/export/test_db.py
@@ -23,7 +23,7 @@ class ExampleTests(TestCase):
@parametrize(
"name,case",
filter_examples_by_support_level(SupportLevel.SUPPORTED).items(),
name_fn=lambda name, case: "case_{}".format(name),
name_fn=lambda name, case: f"case_{name}",
)
def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
model = case.model
@@ -51,7 +51,7 @@ def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
@parametrize(
"name,case",
filter_examples_by_support_level(SupportLevel.NOT_SUPPORTED_YET).items(),
name_fn=lambda name, case: "case_{}".format(name),
name_fn=lambda name, case: f"case_{name}",
)
def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None:
model = case.model
@@ -73,7 +73,7 @@ def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None:
).items()
for rewrite_case in get_rewrite_cases(case)
],
name_fn=lambda name, case: "case_{}_{}".format(name, case.name),
name_fn=lambda name, case: f"case_{name}_{case.name}",
)
def test_exportdb_not_supported_rewrite(
self, name: str, rewrite_case: ExportCase
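
For the name_fn changes in the export tests, a minimal check (not from the diff): an f-string inside a lambda is evaluated when the lambda is called, so it is interchangeable with str.format() here.

    name_fn_old = lambda name, case: "case_{}".format(name)
    name_fn_new = lambda name, case: f"case_{name}"
    assert name_fn_old("add", None) == name_fn_new("add", None) == "case_add"
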
2 changes: 1 addition & 1 deletion test/export/test_serialize.py
@@ -361,7 +361,7 @@ def f(x, y):
@parametrize(
"name,case",
get_filtered_export_db_tests(),
name_fn=lambda name, case: "case_{}".format(name),
name_fn=lambda name, case: f"case_{name}",
)
def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
model = case.model
@@ -501,7 +501,7 @@ def check_fc(existing_schemas):
args = parser.parse_args()
existing_schema_dict = {}
slist = []
with open(args.existing_schemas, "r") as f:
with open(args.existing_schemas) as f:
while True:
line = f.readline()
if not line: