Enable ruff format in pre-commit #2142

Draft · wants to merge 5 commits into main
8 changes: 2 additions & 6 deletions .pre-commit-config.yaml
@@ -43,12 +43,8 @@ repos:
hooks:
- id: ruff-check
args: ["--fix"]

- repo: https://github.com/psf/black
rev: 25.1.0
hooks:
- id: black
name: Black code
- id: ruff-format
types_or: [python]
exclude: "examples"

- repo: https://github.com/executablebooks/mdformat
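Net effect of the hunk above: the separate psf/black hook (rev 25.1.0) is dropped and formatting is handled by a ruff-format hook added next to the existing ruff-check hook, limited to Python files and excluding the examples directory. A rough way to reproduce the same checks locally, assuming the ruff CLI is installed (a hedged sketch, not how pre-commit itself invokes the hooks):

import subprocess

# Mirror the two hooks by hand: lint with autofix, then format.
# check=False so a non-zero exit status does not raise an exception.
subprocess.run(["ruff", "check", "--fix", "."], check=False)
subprocess.run(["ruff", "format", "."], check=False)
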
1 change: 0 additions & 1 deletion thunder/__init__.py
@@ -553,7 +553,6 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com

if requires_grad:
if delay_trace_split:

from thunder.transforms.autodiff import grad_transform_on_trace

computation_trc = grad_transform_on_trace(computation_trc)
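The only change in this file is the deleted blank line that sat directly under `if delay_trace_split:` before the import; the same pattern (a blank line right after a block opener being removed) also appears in benchmark_litgpt.py below. A minimal sketch of the style difference, using hypothetical functions rather than code from this repository:

# Before: a blank line directly after the block opener was kept.
def before_style(x):

    return x + 1


# After ruff format: the leading blank line inside the block is removed.
def after_style(x):
    return x + 1


print(before_style(1), after_style(1))
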
22 changes: 11 additions & 11 deletions thunder/benchmarks/__init__.py
@@ -395,7 +395,7 @@ def _prettyprint_stats(
if rank_mem_info:
short_printout += "\n " + "*" * 20 + " Memory Usage " + "*" * 20
for rank, (memory_allocated, memory_reserved) in rank_mem_info.items():
short_printout += f"\n rank-{rank} - peak allocated memory {memory_allocated/1024/1024:.2f}MB, peak reserved: {memory_reserved/1024/1024:.2f}MB"
short_printout += f"\n rank-{rank} - peak allocated memory {memory_allocated / 1024 / 1024:.2f}MB, peak reserved: {memory_reserved / 1024 / 1024:.2f}MB"
short_printout += "\n"

print(short_printout)
@@ -415,7 +415,7 @@ def _prettyprint_stats(
if rank_mem_info:
short_printout += "\n " + "*" * 20 + " Memory Usage " + "*" * 20
for rank, (memory_allocated, memory_reserved) in rank_mem_info.items():
short_printout += f"\n rank-{rank} - peak allocated memory {memory_allocated/1024/1024:.2f}MB, peak reserved: {memory_reserved/1024/1024:.2f}MB"
short_printout += f"\n rank-{rank} - peak allocated memory {memory_allocated / 1024 / 1024:.2f}MB, peak reserved: {memory_reserved / 1024 / 1024:.2f}MB"
short_printout += "\n"
if median_benchmark_stat.has_extended_stats:
# NOTE At this point in the program extended statistics are available
@@ -617,9 +617,9 @@ def run_multiprocess_benchmark(
print(f"Running distributed benchmark {benchmark.name} with {world_size=}")
_print_benchmark_arguments(benchmark)

assert (
torch.distributed.is_available()
), "Trying to run a distributed benchmark, but torch.distributed is not available"
assert torch.distributed.is_available(), (
"Trying to run a distributed benchmark, but torch.distributed is not available"
)

# Ensures the benchmark is running on a single CUDA device (which is overridden later)
assert (
@@ -629,14 +629,14 @@

# Ensures the benchmark returns a module (because ddp is only supported on modules)
benchmark_fn = benchmark.fn()
assert isinstance(
benchmark_fn, torch.nn.Module
), "Distributed benchmarking currently only supports module benchmarks"
assert isinstance(benchmark_fn, torch.nn.Module), (
"Distributed benchmarking currently only supports module benchmarks"
)

# Validates world size
assert (
world_size <= torch.cuda.device_count()
), f"Requested world size of {world_size} is greater than the number of available cuda devices {torch.cuda.device_count()}"
assert world_size <= torch.cuda.device_count(), (
f"Requested world size of {world_size} is greater than the number of available cuda devices {torch.cuda.device_count()}"
)

FILE_SCHEMA: str = "file://"
if sys.platform == "win32":
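The rewrites in this file show the two patterns that account for most of the diff: long assert statements now keep the condition on the assert line and wrap only the message in parentheses (previously the condition was parenthesized), and binary operators inside f-string replacement fields gain surrounding spaces. A minimal sketch of the resulting style with made-up values, not code from this repository:

world_size = 2  # made-up value for illustration
device_count = 8  # made-up value for illustration
memory_allocated = 3_145_728  # bytes, made-up value for illustration

# Long assert: the condition stays inline and only the message is parenthesized.
assert world_size <= device_count, (
    f"Requested world size of {world_size} is greater than the number of available cuda devices {device_count}"
)

# f-string replacement fields: spaces are added around the division operators.
print(f"peak allocated memory {memory_allocated / 1024 / 1024:.2f}MB")
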
35 changes: 16 additions & 19 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -111,7 +111,6 @@ def update_if_not_divisible(attr_name, divisor):


def swap_linear_layers_for_te(model: torch.nn.Module, device: Any, swap_layernorm: bool = True) -> None:

def parameters_cnt(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters())

@@ -298,7 +297,6 @@ def __init__(
self.use_sdpa = False

if use_torchao_fp8_linear:

if not torchao_available:
raise ValueError("`torchao` is not available")
if self.distributed_mode not in ("none", "fsdp2"):
@@ -313,28 +311,28 @@ def __init__(

# Clarify benchmark assumptions
if self.sharding_size is not None:
assert (
"thunder" not in self.compile
), "Hybrid Sharding (FSDP/DP) using --sharding_size is not yet supported for Thunder. Coming soon."
assert "thunder" not in self.compile, (
"Hybrid Sharding (FSDP/DP) using --sharding_size is not yet supported for Thunder. Coming soon."
)

assert self.shard_mode in [
"hybrid_zero2",
"hybrid_zero3",
], "Sharding Size is only used with Hybrid FSDP/DP style parallelism."

assert (
world_size % self.sharding_size == 0
), f"World size {world_size} is not divisible by the sharding size {self.sharding_size}"
assert world_size % self.sharding_size == 0, (
f"World size {world_size} is not divisible by the sharding size {self.sharding_size}"
)

if self.bucketing_mode is not None and self.distributed_mode not in FSDP_MODES:
warnings.warn(
f"--bucketing_mode {self.bucketing_mode} will be ignored as "
f" it is only used for FSDP style parallelism but running {self.distributed_mode}"
)

assert not (
"thunder" in self.compile and self.bucketing_mode == "size"
), "'size' bucketing mode is not supported for Thunder. Please use 'none' or 'block'."
assert not ("thunder" in self.compile and self.bucketing_mode == "size"), (
"'size' bucketing mode is not supported for Thunder. Please use 'none' or 'block'."
)

if self.fsdp_bucket_params is not None:
if self.distributed_mode not in FSDP_MODES:
@@ -361,15 +359,15 @@ def __init__(
self.global_batch_size = (
self.micro_batch_size * world_size if world_size is not None else self.micro_batch_size
)
assert (
self.global_batch_size % self.micro_batch_size == 0
), f"Global Batch Size {self.global_batch_size} should be a multiple of Micro Batch Size {self.micro_batch_size}."
assert self.global_batch_size % self.micro_batch_size == 0, (
f"Global Batch Size {self.global_batch_size} should be a multiple of Micro Batch Size {self.micro_batch_size}."
)
self.gradient_accumulation_steps = int(self.global_batch_size / self.micro_batch_size)
if world_size:
self.gradient_accumulation_steps = int(self.gradient_accumulation_steps / world_size)
assert (
self.global_batch_size % self.micro_batch_size * world_size == 0
), f"Global Batch Size {self.global_batch_size} should be a multiple Micro Batch Size {self.micro_batch_size} * World Size {world_size}."
assert self.global_batch_size % self.micro_batch_size * world_size == 0, (
f"Global Batch Size {self.global_batch_size} should be a multiple Micro Batch Size {self.micro_batch_size} * World Size {world_size}."
)

self.skip_data_sync = skip_data_sync

@@ -628,7 +626,6 @@ def setup_compile(self, model):
executors.insert(0, torch_compile_ex)

if "transformerengine_v2" in self.compile:

from thunder.executors.transformer_engine_v2ex import (
transformer_engine_v2_ex,
TransformerEngineTransformV2,
@@ -913,7 +910,7 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
print(f"Sharding Mode: {benchmark.shard_mode}\nBucketing: {benchmark.bucketing_mode}")
if benchmark.sharding_size is not None:
print(
f"Sharding Size: {benchmark.sharding_size}\nReplicate DP Groups: {int(world_size/benchmark.sharding_size)}"
f"Sharding Size: {benchmark.sharding_size}\nReplicate DP Groups: {int(world_size / benchmark.sharding_size)}"
)
if benchmark.bucketing_mode == "size":
print(f"Bucketing Number Params: {benchmark.fsdp_bucket_params}")
2 changes: 1 addition & 1 deletion thunder/benchmarks/test_benchmark_litgpt.py
@@ -132,7 +132,7 @@ def run_benchmark(self, kwargs):
]
subprocess_cmd.extend(command_list)

print(f'Running {" ".join(subprocess_cmd)!r}')
print(f"Running {' '.join(subprocess_cmd)!r}")
proc_output = subprocess.run(subprocess_cmd, capture_output=True, text=True)

self.perf_metrics_dict = {}
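The single change here is a quote flip inside an f-string: the outer quotes are normalized to double quotes and the literal inside the replacement field switches to single quotes so the two no longer clash. A small runnable sketch with a made-up command, not the benchmark's real invocation:

subprocess_cmd = ["python", "--version"]  # made-up command for illustration

# Before: outer single quotes so the inner double quotes were legal.
# print(f'Running {" ".join(subprocess_cmd)!r}')

# After ruff format: outer double quotes, inner literal switched to single quotes.
print(f"Running {' '.join(subprocess_cmd)!r}")
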
6 changes: 3 additions & 3 deletions thunder/common.py
@@ -239,9 +239,9 @@ def __init__(

# Validates executors list
for ex in self.executors_list:
assert isinstance(
ex, Executor
), f"Expected all elements of the executors list to be executors, but found {ex}"
assert isinstance(ex, Executor), (
f"Expected all elements of the executors list to be executors, but found {ex}"
)

# Resolves language context (defaulting to the torch language)
self.langctx = langctx if langctx is not None else resolve_language(Languages.TORCH)
18 changes: 9 additions & 9 deletions thunder/core/functionalization.py
@@ -221,13 +221,13 @@ def replace_args_with_alias_map(
}:
bsyms.append(bsym.from_bsym_swap_proxies(swap_map_for_aliases, skip_output=True))
if len(replaced_args_map) == 1:
bsyms[-1].header = (
f"[alias tensor args] `{list(replaced_args_map.keys())[0]}` is replaced by `{list(replaced_args_map.values())[0]}`"
)
bsyms[
-1
].header = f"[alias tensor args] `{list(replaced_args_map.keys())[0]}` is replaced by `{list(replaced_args_map.values())[0]}`"
else:
bsyms[-1].header = (
f"[alias tensor args] {list(replaced_args_map.keys())} are replaced by {list(replaced_args_map.values())}, respectively"
)
bsyms[
-1
].header = f"[alias tensor args] {list(replaced_args_map.keys())} are replaced by {list(replaced_args_map.values())}, respectively"
else:
bsyms.append(bsym)
no_implicit_alias_trace = from_trace(computation_trace)
@@ -298,9 +298,9 @@ def canonicalize_bsym_args(
new_bsym = new_bsym.from_bsym_swap_proxies(swap_map, skip_inputs=True, skip_subsymbols=True)
bsyms.append(new_bsym)
if cur_orig_to_view_swap_map:
bsyms[-1].header = (
f"Replace {[unvariableify(k) for k in cur_orig_to_view_swap_map]} with {[list(cur_orig_to_view_swap_map.values())]}"
)
bsyms[
-1
].header = f"Replace {[unvariableify(k) for k in cur_orig_to_view_swap_map]} with {[list(cur_orig_to_view_swap_map.values())]}"

intermediate_trace = from_trace(computation_trace)
intermediate_trace.bound_symbols = bsyms
Expand Down