[BE]: Optimize min/max/sum comprehensions C419 (pytorch#123960)
Automatic fixes that replace certain list comprehensions with generator expressions where appropriate, so they are consumed immediately rather than materialized as an intermediate list. This is preview functionality in ruff for rule C419 and was applied automatically.

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: pytorch#123960
Approved by: https://github.com/malfet
Skylion007 authored and pytorchmergebot committed Apr 12, 2024
1 parent 961eb39 commit 1d6c597
Showing 42 changed files with 89 additions and 110 deletions.
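For readers unfamiliar with the rule, here is a minimal sketch of the transformation C419 applies. The toy module below is illustrative only and does not appear in this diff:

import torch

# A stand-in module, just to have parameters to iterate over.
model = torch.nn.Linear(4, 2)

# Before: the list comprehension builds the whole list in memory,
# only for sum() to iterate it once and throw it away.
total = sum([p.numel() for p in model.parameters()])

# After (what the C419 fix emits): dropping the brackets passes a
# generator expression that sum() consumes lazily, one element at a
# time. The result is identical.
assert total == sum(p.numel() for p in model.parameters())  # 4*2 + 2 = 10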
6 changes: 2 additions & 4 deletions benchmarks/gpt_fast/benchmark.py
@@ -194,10 +194,8 @@ def run_experiment(
 
     torch.manual_seed(1234)
     model_size = sum(
-        [
-            p.numel() * p.dtype.itemsize
-            for p in itertools.chain(model.parameters(), model.buffers())
-        ]
+        p.numel() * p.dtype.itemsize
+        for p in itertools.chain(model.parameters(), model.buffers())
     )
 
     aggregate_metrics = {"tokens_per_sec": []}
2 changes: 1 addition & 1 deletion benchmarks/tensorexpr/rnn_eltwise.py
@@ -57,7 +57,7 @@ def memory_workload(self):
         def memsize(t):
             return t.numel() * t.element_size()
 
-        input_size = sum([memsize(t) for t in self.inputs])
+        input_size = sum(memsize(t) for t in self.inputs)
         output_size = 2 * memsize(self.cx)
         io_size = input_size + output_size
         return {"sol": io_size, "algorithmic": io_size}
2 changes: 1 addition & 1 deletion scripts/compile_tests/failures_histogram.py
@@ -97,7 +97,7 @@ def failures_histogram(eager_dir, dynamo_dir, verbose=False, format_issues=False
         else "(num_failed_tests, error_msg, sample_test)"
     )
     print(header)
-    sum_counts = sum([r[0] for r in result])
+    sum_counts = sum(r[0] for r in result)
     for row in result:
         if format_issues:
             print(as_issue(*row))
4 changes: 2 additions & 2 deletions test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -530,7 +530,7 @@ def test_sharding(self):
             params.append(torch.rand(size, 1))
         o = ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR)
         self.assertEqual(
-            sum([x.numel() for x in o.optim.param_groups[0]["params"]]),
+            sum(x.numel() for x in o.optim.param_groups[0]["params"]),
             sum(sizes),
         )
 
@@ -567,7 +567,7 @@ def all_trainable():
         # all partitions have the same elements
         self.assertEqual(len(o.param_groups), 2)
         self.assertEqual(
-            sum([x.numel() for g in o.optim.param_groups for x in g["params"]]),
+            sum(x.numel() for g in o.optim.param_groups for x in g["params"]),
             sum(sizes),
         )
         self.assertEqual(len(o.optim.param_groups), 2)
2 changes: 1 addition & 1 deletion test/distributed/pipeline/sync/test_transparency.py
@@ -15,7 +15,7 @@
 
 def test_simple_linears(setup_rpc):
     def sum_grad(parameters):
-        return sum([p.grad.sum() for p in parameters if p.grad is not None])
+        return sum(p.grad.sum() for p in parameters if p.grad is not None)
 
     def zero_grad(parameters):
         for p in parameters:
4 changes: 2 additions & 2 deletions test/functorch/discover_coverage.py
@@ -254,11 +254,11 @@ def get_num_usages(opname):
 
     # get all operators that are not in the denylist
    all_ops = get_top_ops(999999, 999999)
-    total_op_usages = sum([get_num_usages(op) for op in all_ops])
+    total_op_usages = sum(get_num_usages(op) for op in all_ops)
 
    # get subset of all operators
    subset_ops = get_top_ops(torch_threshold, nn_fn_threshold)
-    subset_op_usages = sum([get_num_usages(op) for op in subset_ops])
+    subset_op_usages = sum(get_num_usages(op) for op in subset_ops)
    return subset_op_usages / total_op_usages
 
 
2 changes: 1 addition & 1 deletion test/functorch/test_ops.py
@@ -467,7 +467,7 @@ def abs_if_complex(t):
             # Reduce into single value for grad
             if isinstance(result, torch.Tensor):
                 return abs_if_complex(result.sum())
-            result = sum([abs_if_complex(res.sum()) for res in result])
+            result = sum(abs_if_complex(res.sum()) for res in result)
             return result
 
         result = grad(wrapped_fn, diff_argnums)(*args, **kwargs)
2 changes: 1 addition & 1 deletion test/inductor/test_fx_fusion.py
@@ -30,7 +30,7 @@ def parent_pass(module: torch.fx.GraphModule, input: Any) -> torch.fx.GraphModul
 
 def count_call(module: torch.fx.GraphModule, op: str, target_op: Any) -> int:
     return sum(
-        [1 if (n.op == op and n.target == target_op) else 0 for n in module.graph.nodes]
+        1 if (n.op == op and n.target == target_op) else 0 for n in module.graph.nodes
     )
 
 
4 changes: 2 additions & 2 deletions test/nn/test_pruning.py
@@ -892,11 +892,11 @@ def test_rnn_pruning(self):
 
         # Pruning one of them causes one of the weights to become a tensor
         prune.l1_unstructured(l, "weight_ih_l0", 0.5)
-        assert sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]) == 3
+        assert sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights) == 3
 
         # Removing the pruning reparametrization restores the Parameter
         prune.remove(l, "weight_ih_l0")
-        assert sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]) == 4
+        assert sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights) == 4
 
         # Make sure that, upon removal of the reparametrization, the
         # `._parameters` and `.named_parameters` contain the right params.
2 changes: 1 addition & 1 deletion test/onnx/test_fx_op_consistency.py
@@ -1443,7 +1443,7 @@ def skip_torchlib_forward_compatibility(
     ),
     skip(
         "linalg.multi_dot",
-        matcher=lambda sample: sum([torch.numel(input) for input in sample.input]) == 0,
+        matcher=lambda sample: sum(torch.numel(input) for input in sample.input) == 0,
         reason="fixme: Undefined",
     ),
     skip(
5 changes: 2 additions & 3 deletions test/profiler/test_profiler.py
@@ -3016,10 +3016,9 @@ def test_utils_compute_self_time(self):
         for event_key, event_metrics in metrics.items():
             self.assertEqual(
                 event_metrics.self_time_ns,
-                event_key.event.duration_time_ns - sum([
+                event_key.event.duration_time_ns - sum(
                     child.duration_time_ns
-                    for child in event_key.event.children
-                ]))
+                    for child in event_key.event.children))
 
     def test_utils_intervals_overlap(self):
         event = _utils.EventKey(MockProfilerEvent("Event 1", 1, 5, 5))
6 changes: 3 additions & 3 deletions test/quantization/fx/test_model_report_fx.py
@@ -1346,7 +1346,7 @@ def test_input_weight_equalization_determine_points(self):
         # assert that each of the desired modules have the observers inserted
         for fqn, module in prepared_for_callibrate_model.named_modules():
             # check if module is a supported module
-            is_in_include_list = sum([isinstance(module, x) for x in mods_to_check]) > 0
+            is_in_include_list = sum(isinstance(module, x) for x in mods_to_check) > 0
 
             if is_in_include_list:
                 # make sure it has the observer attribute
@@ -1563,7 +1563,7 @@ def test_outlier_detection_determine_points(self):
         obs_name_to_find = InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME
 
         number_of_obs_found = sum(
-            [1 if obs_name_to_find in str(node.target) else 0 for node in prepared_for_callibrate_model.graph.nodes]
+            1 if obs_name_to_find in str(node.target) else 0 for node in prepared_for_callibrate_model.graph.nodes
         )
         self.assertEqual(number_of_obs_found, correct_number_of_obs_inserted)
 
@@ -1753,7 +1753,7 @@ def test_multiple_run_consistent_spike_outlier_report_gen(self):
         assert sum(counts_info) >= 2
 
         # half of the recorded max values should be what we set
-        matched_max = sum([val == 3.28e8 for val in module_dict[OutlierDetector.MAX_VALS_KEY]])
+        matched_max = sum(val == 3.28e8 for val in module_dict[OutlierDetector.MAX_VALS_KEY])
         self.assertEqual(matched_max, param_size / 2)
 
 
2 changes: 1 addition & 1 deletion test/test_bundled_inputs.py
@@ -435,7 +435,7 @@ def {}(self, value: Optional[List[Tensor]]):
         # two args which have InflatableArg with fmt_fn
         # 1 * 2 * 2 = 4
         self.assertEqual(
-            sum([method.startswith("_inflate_helper") for method in methods]), 4
+            sum(method.startswith("_inflate_helper") for method in methods), 4
         )
 
 
2 changes: 1 addition & 1 deletion test/test_flop_counter.py
@@ -21,7 +21,7 @@ def FlopCounterMode(*args, **kwargs):
     return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)
 
 def get_total_flops(mode):
-    return str(sum([v for _, v in mode.flop_counts["Global"].items()]))
+    return str(sum(v for _, v in mode.flop_counts["Global"].items()))
 
 def T(*shape, requires_grad=False):
     return torch.randn(*shape, requires_grad=requires_grad)
4 changes: 1 addition & 3 deletions test/test_foreach.py
@@ -334,9 +334,7 @@ def clone(arg):
                     [rhs_arg, tensors], is_cuda=False, expect_fastpath=False
                 )
             ).mean().backward()
-            sum(
-                [ref.func(ref_rhs_arg, t) for t in ref_tensors]
-            ).mean().backward()
+            sum(ref.func(ref_rhs_arg, t) for t in ref_tensors).mean().backward()
             self.assertEqual(
                 [t.grad for t in tensors], [t.grad for t in ref_tensors]
             )
6 changes: 3 additions & 3 deletions test/test_fx.py
@@ -3593,17 +3593,17 @@ def verify_pytree(f, inp):
             self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out)
 
             assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
-            assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args
+            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args
 
             nf = symbolic_trace(nf)
             self.assertEqual(nf(val), orig_out)
             assert "tree_flatten_spec" not in nf.code
-            assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == 1
+            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == 1
 
             nf = symbolic_trace(nf, concrete_args={'x': inp})
             self.assertEqual(nf(val), orig_out)
             assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
-            assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args
+            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args
 
             pickled = pickle.dumps(nf)
             nf = pickle.loads(pickled)
4 changes: 2 additions & 2 deletions test/test_nn.py
@@ -1699,14 +1699,14 @@ def check_weight_norm(l, name, num_params):
 
             # Applying weight norm on one of them causes it to become a tensor
             l = torch.nn.utils.weight_norm(l, name=name)
             self.assertEqual(
-                sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]),
+                sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights),
                 num_params - 1,
             )
 
             # Removing the weight norm reparametrization restores the Parameter
             l = torch.nn.utils.remove_weight_norm(l, name=name)
             self.assertEqual(
-                sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]),
+                sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights),
                 num_params,
             )
 
8 changes: 4 additions & 4 deletions test/test_optim.py
@@ -348,7 +348,7 @@ def test_rosenbrock_sparse(self, device, dtype, optim_info, with_lrsched):
 
         solution = torch.tensor([1, 1])
         with torch.no_grad():
-            initial_dist = sum([param.dist(solution) for param in params])
+            initial_dist = sum(param.dist(solution) for param in params)
 
         def get_grad(param, sparse_grad, w):
             grad = drosenbrock(param)
@@ -410,13 +410,13 @@ def eval(params, sparse_grad, w):
 
         if not kwargs.get("maximize", False):
             self.assertLessEqual(
-                sum([param.dist(solution) for param in params]),
+                sum(param.dist(solution) for param in params),
                 initial_dist
             )
         else:
             self.assertGreaterEqual(
-                sum([rosenbrock(param) for param in params]),
-                sum([rosenbrock(param_t) for param_t in params_t]),
+                sum(rosenbrock(param) for param in params),
+                sum(rosenbrock(param_t) for param_t in params_t),
             )
 
 
8 changes: 4 additions & 4 deletions test/test_transformers.py
@@ -2223,10 +2223,10 @@ def test_mem_eff_attention_non_contig_mask_bug(self, device):
         attn_mask_strides = (14, 14, 14, 1)
 
         # Calculate the number of elements needed for each tensor
-        query_num_elements = max([size * stride for size, stride in zip(query_size, query_strides)])
-        key_num_elements = max([size * stride for size, stride in zip(key_size, key_strides)])
-        value_num_elements = max([size * stride for size, stride in zip(value_size, value_strides)])
-        attention_mask_num_elements = max([size * stride for size, stride in zip(attention_mask_size, attn_mask_strides)])
+        query_num_elements = max(size * stride for size, stride in zip(query_size, query_strides))
+        key_num_elements = max(size * stride for size, stride in zip(key_size, key_strides))
+        value_num_elements = max(size * stride for size, stride in zip(value_size, value_strides))
+        attention_mask_num_elements = max(size * stride for size, stride in zip(attention_mask_size, attn_mask_strides))
 
         # Create the tensors with the specified sizes and strides
         query = torch.randn(query_num_elements, device=device).as_strided(query_size, query_strides)
4 changes: 2 additions & 2 deletions torch/_dynamo/bytecode_analysis.py
@@ -244,8 +244,8 @@ def stacksize_analysis(instructions) -> Union[int, float]:
             stack_size = stack_sizes[inst]
             print(stack_size.low, stack_size.high, inst)
 
-    low = min([x.low for x in stack_sizes.values()])
-    high = max([x.high for x in stack_sizes.values()])
+    low = min(x.low for x in stack_sizes.values())
+    high = max(x.high for x in stack_sizes.values())
 
     assert fixed_point.value, "failed to reach fixed point"
     assert low >= 0
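A sanity check on min/max rewrites like the two above: a standalone sketch, not part of the diff, showing that both call forms agree, including the empty-iterable failure mode.

# min()/max() accept a generator expression exactly like a list argument.
sizes = [3, 1, 4]
assert max([s * 2 for s in sizes]) == max(s * 2 for s in sizes) == 8
assert min([s * 2 for s in sizes]) == min(s * 2 for s in sizes) == 2

# The failure mode is preserved as well: without a default= argument,
# both forms raise ValueError on an empty iterable.
try:
    max(s * 2 for s in [])
except ValueError as exc:
    print(exc)  # e.g. "max() arg is an empty sequence"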
4 changes: 2 additions & 2 deletions torch/_dynamo/testing.py
@@ -105,7 +105,7 @@ def reduce_to_scalar_loss(out):
         # Mean does not work on integer tensors
         return out.sum() / out.numel()
     elif isinstance(out, (list, tuple)):
-        return sum([reduce_to_scalar_loss(x) for x in out]) / len(out)
+        return sum(reduce_to_scalar_loss(x) for x in out) / len(out)
     elif type(out).__name__ in (
         "MaskedLMOutput",
         "Seq2SeqLMOutput",
@@ -115,7 +115,7 @@ def reduce_to_scalar_loss(out):
     elif type(out).__name__ == "SquashedNormal":
         return out.mean.sum()
     elif isinstance(out, dict):
-        return sum([reduce_to_scalar_loss(value) for value in out.values()]) / len(
+        return sum(reduce_to_scalar_loss(value) for value in out.values()) / len(
             out.keys()
         )
     raise NotImplementedError("Don't know how to reduce", type(out))
2 changes: 1 addition & 1 deletion torch/_dynamo/utils.py
@@ -1594,7 +1594,7 @@ def graph_break_report():
 
     def recompilation_report():
         if len(gf):
-            max_recompiles = max([num_recompiles(code) for code in gf])
+            max_recompiles = max(num_recompiles(code) for code in gf)
             recomp_table = tabulate(
                 summarized_gf,
                 headers=["Function", "Recompiles", "Recompile Reasons"],
2 changes: 1 addition & 1 deletion torch/_functorch/partitioners.py
@@ -1296,7 +1296,7 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int:
         storages = {get_node_storage(node) for node in saved_values}
         print(
             "Theoretical Activations Stored: ",
-            sum([_size_of(i) for i in saved_values]) / 1e9,
+            sum(_size_of(i) for i in saved_values) / 1e9,
         )
         sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values])
         fw_module_nodes = {
6 changes: 2 additions & 4 deletions torch/_inductor/comms.py
@@ -215,10 +215,8 @@ def schedule_nodes(snodes):
         assert_no_comm_nodes(needed_by_next_comm_and_ready_compute_nodes)
 
         total_compute_runtime_cost = rolled_over_compute_cost + sum(
-            [
-                estimate_op_runtime(node)
-                for node in needed_by_next_comm_and_ready_compute_nodes
-            ]
+            estimate_op_runtime(node)
+            for node in needed_by_next_comm_and_ready_compute_nodes
         )
         prev_comm_runtime_cost = estimate_op_runtime(comm_nodes[idx - 1])
         schedule_nodes(tuple_sorted(needed_by_next_comm_and_ready_compute_nodes))
24 changes: 8 additions & 16 deletions torch/_inductor/lowering.py
@@ -3642,16 +3642,12 @@ def max_pool2d_with_indices_backward(
     new_size = list(x.get_size())
 
     h_window_size = max(
-        [
-            max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1)
-            for h in range(kernel_size[0] * 2)
-        ]
+        max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1)
+        for h in range(kernel_size[0] * 2)
     )
     w_window_size = max(
-        [
-            max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1)
-            for w in range(kernel_size[1] * 2)
-        ]
+        max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1)
+        for w in range(kernel_size[1] * 2)
     )
 
     window_size = h_window_size * w_window_size
@@ -4353,16 +4349,12 @@ def avg_pool2d_backward(
     dtype = x.get_dtype()
 
     h_window_size = max(
-        [
-            max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1)
-            for h in range(kernel_size[0] * 2)
-        ]
+        max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1)
+        for h in range(kernel_size[0] * 2)
    )
    w_window_size = max(
-        [
-            max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1)
-            for w in range(kernel_size[1] * 2)
-        ]
+        max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1)
+        for w in range(kernel_size[1] * 2)
    )
 
    window_size = h_window_size * w_window_size
6 changes: 3 additions & 3 deletions torch/_inductor/scheduler.py
@@ -538,7 +538,7 @@ def is_materialized(buf, snodes):
         node_bytes = 0
 
         for buf_name in reads | writes:
-            buf_accessed_elems = sum([node_numel for dep in buf_accesses[buf_name]])
+            buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name])
             buf: Union[ir.Buffer, ir.TensorBox]
             if buf_name in V.graph.name_to_buffer:
                 buf = V.graph.name_to_buffer[buf_name]
@@ -868,8 +868,8 @@ def __init__(self, scheduler: "Scheduler", snodes: List[SchedulerNode]):
             for dep in set.union(*[x.unmet_dependencies for x in snodes])
             if dep.name not in self.get_names()
         } - self.read_writes.writes
-        self.min_order = min([x.min_order for x in self.snodes])
-        self.max_order = max([x.max_order for x in self.snodes])
+        self.min_order = min(x.min_order for x in self.snodes)
+        self.max_order = max(x.max_order for x in self.snodes)
 
     @cache_on_self
     def get_name(self) -> str:
6 changes: 2 additions & 4 deletions torch/amp/grad_scaler.py
@@ -426,10 +426,8 @@ def step(
             found_inf = cast(
                 torch.Tensor,
                 sum(
-                    [
-                        t.to(scaler.device, non_blocking=True)
-                        for t in optimizer_state["found_inf_per_device"].values()
-                    ]
+                    t.to(scaler.device, non_blocking=True)
+                    for t in optimizer_state["found_inf_per_device"].values()
                 ),
             )
             optimizer.grad_scale = (  # type: ignore[attr-defined]
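The main payoff of these rewrites is that the intermediate list is never materialized. A rough standalone measurement sketch, not part of the diff (uses CPython's tracemalloc; exact numbers vary by platform and version):

import tracemalloc

def peak_bytes(fn):
    # Run fn and return the peak traced allocation in bytes.
    tracemalloc.start()
    fn()
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return peak

n = 1_000_000
# The list form holds all n intermediate values at once...
list_peak = peak_bytes(lambda: sum([i * i for i in range(n)]))
# ...while the generator form keeps one element alive at a time.
gen_peak = peak_bytes(lambda: sum(i * i for i in range(n)))
print(f"list peak: ~{list_peak / 1e6:.1f} MB, generator peak: ~{gen_peak / 1e6:.3f} MB")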