 import torch._inductor.config
 import torch.nn as nn
 
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.elastic.utils.distributed import get_free_port
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torchchat.model import Model, ModelArgs, ModelType
 
 from torchchat.model_config.model_config import resolve_model_config
@@ -464,77 +458,11 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
     return model
 
 
-def _maybe_init_distributed(
-    builder_args: BuilderArgs,
-) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-    """
-    Initialize distributed related setups if the user specified
-    using distributed inference. If not, this is a no-op.
-
-    Args:
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-    Returns:
-        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-            - The first element is an optional DeviceMesh object,
-              which describes the mesh topology of devices for the DTensor.
-            - The second element is an optional ParallelDims object,
-              which represents the parallel dimensions configuration.
-    """
-    if not builder_args.use_distributed:
-        return None, None
-    dist_config = "llama3_8B.toml"  # TODO - integrate with chat cmd line
-
-    world_mesh, parallel_dims = launch_distributed(dist_config)
-
-    assert (
-        world_mesh is not None and parallel_dims is not None
-    ), f"failed to launch distributed using {dist_config}"
-
-    return world_mesh, parallel_dims
-
-
-def _maybe_parallelize_model(
-    model: nn.Module,
-    builder_args: BuilderArgs,
-    world_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-) -> nn.Module:
-    """
-    Parallelize the module and load the distributed checkpoint into the model
-    if the user specifies using distributed inference. If not, this is a no-op.
-
-    Args:
-        model (:class:`nn.Module`):
-            Module to be parallelized.
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-        world_mesh (:class:`DeviceMesh`):
-            Object which describes the mesh topology
-            of devices for the DTensor.
-        parallel_dims (:class:`ParallelDims`):
-            Object which represents the parallel dimensions configuration.
-    Returns:
-        A :class:`nn.Module` object which is parallelized and checkpoint-loaded
-        if the user specifies using distributed inference.
-    """
-    if world_mesh is None:
-        return model
-    assert parallel_dims is not None
-    print("Applying model parallel to model ...")
-    parallelize_llama(model, world_mesh, parallel_dims)
-    return load_checkpoints_to_model(model, builder_args, world_mesh)
-
-
 def _load_model(builder_args: BuilderArgs) -> Model:
-    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
-    # elif builder_args.use_distributed:
-    #     model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-    # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
 
     if builder_args.dso_path or builder_args.aoti_package_path:
         # AOTI-compiled model will load its own weights.
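For reference, the commented-out call sites deleted in the hunk above show how the distributed load path was meant to compose before this removal. Below is a minimal sketch of that wiring, reconstructed from those comments; the function name _load_model_distributed is hypothetical, while _maybe_init_distributed, _init_model_on_meta_device, and _maybe_parallelize_model are the names that appear in the removed code.

def _load_model_distributed(builder_args: BuilderArgs) -> Model:
    # Hypothetical sketch; the body mirrors the commented-out call sites
    # deleted above, not code that remains in the tree after this change.
    # Returns (None, None) unless builder_args.use_distributed is set.
    world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
    # On the distributed path the model was built on the meta device first.
    model = _init_model_on_meta_device(builder_args)
    # No-op when world_mesh is None; otherwise applies parallelize_llama and
    # loads the distributed checkpoint via load_checkpoints_to_model.
    return _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)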
@@ -706,4 +634,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
         return "TikToken"
     if tokenizers:
         return "Tokenizers"
-    return "SentencePiece"
+    return "SentencePiece"