
Commit 0f58543

mreso and Jack-Khuu authored
Remove last references to use_distributed argument (#1353)
Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>
1 parent 2fcc37c commit 0f58543

2 files changed: +3 −87 lines changed

torchchat/cli/builder.py

Lines changed: 1 addition & 73 deletions
@@ -16,12 +16,6 @@
 import torch._inductor.config
 import torch.nn as nn

-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.elastic.utils.distributed import get_free_port
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torchchat.model import Model, ModelArgs, ModelType

 from torchchat.model_config.model_config import resolve_model_config
@@ -464,77 +458,11 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
     return model


-def _maybe_init_distributed(
-    builder_args: BuilderArgs,
-) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-    """
-    Initialize distributed related setups if the user specified
-    using distributed inference. If not, this is a no-op.
-
-    Args:
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-    Returns:
-        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-            - The first element is an optional DeviceMesh object,
-              which which describes the mesh topology of devices for the DTensor.
-            - The second element is an optional ParallelDims object,
-              which represents the parallel dimensions configuration.
-    """
-    if not builder_args.use_distributed:
-        return None, None
-    dist_config = "llama3_8B.toml"  # TODO - integrate with chat cmd line
-
-    world_mesh, parallel_dims = launch_distributed(dist_config)
-
-    assert (
-        world_mesh is not None and parallel_dims is not None
-    ), f"failed to launch distributed using {dist_config}"
-
-    return world_mesh, parallel_dims
-
-
-def _maybe_parallelize_model(
-    model: nn.Module,
-    builder_args: BuilderArgs,
-    world_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-) -> nn.Module:
-    """
-    We parallelize the module and load the distributed checkpoint to the model
-    if the user specifies using distributed inference. If not, this is a no-op.
-
-    Args:
-        model (:class:`nn.Module`):
-            Module to be parallelized.
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-        world_mesh (:class:`DeviceMesh`):
-            Object which describes the mesh topology
-            of devices for the DTensor.
-        parallel_dims (:class:`ParallelDims`):
-            Object which represents the parallel dimensions configuration.
-    Returns:
-        A :class:`nn.Module` object which is parallelized and checkpoint loaded
-        if the user specifies using distributed inference.
-    """
-    if world_mesh is None:
-        return model
-    assert parallel_dims is not None
-    print("Applying model parallel to model ...")
-    parallelize_llama(model, world_mesh, parallel_dims)
-    return load_checkpoints_to_model(model, builder_args, world_mesh)
-
-
 def _load_model(builder_args: BuilderArgs) -> Model:
-    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
-    # elif builder_args.use_distributed:
-    #     model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-    # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)

     if builder_args.dso_path or builder_args.aoti_package_path:
         # AOTI-compoiled model will load its own weights.
@@ -706,4 +634,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
         return "TikToken"
     if tokenizers:
         return "Tokenizers"
-    return "SentencePiece"
+    return "SentencePiece"

torchchat/generate.py

Lines changed: 2 additions & 14 deletions
@@ -915,13 +915,6 @@ def chat(
             ]
         )
         if generator_args.compile:
-            if (
-                self.is_speculative and self.builder_args.use_distributed
-            ):  # and ("cuda" in builder_args.device):
-                torch._inductor.config.triton.cudagraph_trees = (
-                    False  # Bug with cudagraph trees in this case
-                )
-
             if self.builder_args.device == "cpu":
                 if generator_args.max_autotune:
                     kwargs = {"mode": "max-autotune"}
@@ -1091,9 +1084,7 @@ def callback(x, *, done_generating=False):

            torch._inductor.config.profiler_mark_wrapper_call = True
            torch._inductor.config.cpp.enable_kernel_profile = True
-            if (i != generator_args.num_samples - 1 or not self.profile) or (
-                self.builder_args.use_distributed and self.rank != 0
-            ):
+            if i != generator_args.num_samples - 1 or not self.profile:
                import contextlib

                prof = contextlib.nullcontext()
@@ -1136,10 +1127,7 @@ def callback(x, *, done_generating=False):
                    print(prof.key_averages().table(sort_by="self_cpu_time_total"))
                else:
                    print(prof.key_averages().table(sort_by="self_cuda_time_total"))
-                if self.builder_args.use_distributed:
-                    prof.export_chrome_trace(f"{self.profile}_rank_{self.rank}.json")
-                else:
-                    prof.export_chrome_trace(f"{self.profile}.json")
+                prof.export_chrome_trace(f"{self.profile}.json")

            if start_pos >= max_seq_length:
                print(
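
The generate.py change drops the distributed cudagraph workaround and the per-rank trace export, leaving a single profiling path. Here is a runnable sketch of that simplified flow, assuming illustrative names (num_samples, profile_path) rather than torchchat's actual variables:

# Sketch of the per-sample profiling flow after this change.
# num_samples and profile_path are illustrative; the real logic lives
# in torchchat/generate.py.
import contextlib

import torch

num_samples = 3
profile_path = "trace"  # a falsy value disables profiling entirely

for i in range(num_samples):
    # Profile only the last sample, and only if profiling was requested;
    # the old use_distributed / rank-0 special case is gone.
    if i != num_samples - 1 or not profile_path:
        prof = contextlib.nullcontext()
    else:
        prof = torch.profiler.profile()

    with prof:
        torch.randn(256, 256) @ torch.randn(256, 256)  # stand-in for generation work

    if hasattr(prof, "export_chrome_trace"):
        print(prof.key_averages().table(sort_by="self_cpu_time_total"))
        # One trace file per run instead of one per rank.
        prof.export_chrome_trace(f"{profile_path}.json")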
