@@ -430,6 +430,7 @@ def transform_module(
     world_size: int,
     batch_size: int,
     ctx: ContextManager,
+    benchmark_unsharded_module: bool = False,
 ) -> torch.nn.Module:
     def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
         eager_module(inputs[0])
@@ -441,52 +442,61 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
 
     set_propogate_device(True)
 
-    topology: Topology = Topology(world_size=world_size, compute_device=device.type)
-    planner = EmbeddingShardingPlanner(
-        topology=topology,
-        batch_size=batch_size,
-        enumerator=EmbeddingEnumerator(
+    sharded_module = None
+
+    if not benchmark_unsharded_module:
+        topology: Topology = Topology(world_size=world_size, compute_device=device.type)
+        planner = EmbeddingShardingPlanner(
             topology=topology,
             batch_size=batch_size,
-            estimator=[
-                EmbeddingPerfEstimator(topology=topology),
-                EmbeddingStorageEstimator(topology=topology),
-            ],
-        ),
-    )
-
-    # Don't want to modify the module outright
-    # Since module is on cpu, won't cause cuda oom.
-    copied_module = copy.deepcopy(module)
-    # pyre-ignore [6]
-    plan = planner.plan(copied_module, [sharder])
-
-    if isinstance(ctx, MultiProcessContext):
-        sharded_module = DistributedModelParallel(
-            copied_module,
-            # pyre-ignore[6]
-            env=ShardingEnv.from_process_group(ctx.pg),
-            plan=plan,
-            # pyre-ignore[6]
-            sharders=[sharder],
-            device=ctx.device,
+            enumerator=EmbeddingEnumerator(
+                topology=topology,
+                batch_size=batch_size,
+                estimator=[
+                    EmbeddingPerfEstimator(topology=topology),
+                    EmbeddingStorageEstimator(topology=topology),
+                ],
+            ),
         )
-    else:
-        env = ShardingEnv.from_local(world_size=topology.world_size, rank=0)
 
-        sharded_module = _shard_modules(
-            module=copied_module,
-            # pyre-ignore [6]
-            sharders=[sharder],
-            device=device,
-            plan=plan,
-            env=env,
-        )
+        # Don't want to modify the module outright
+        # Since module is on cpu, won't cause cuda oom.
+        copied_module = copy.deepcopy(module)
+        # pyre-ignore [6]
+        plan = planner.plan(copied_module, [sharder])
+
+        if isinstance(ctx, MultiProcessContext):
+            sharded_module = DistributedModelParallel(
+                copied_module,
+                # pyre-ignore[6]
+                env=ShardingEnv.from_process_group(ctx.pg),
+                plan=plan,
+                # pyre-ignore[6]
+                sharders=[sharder],
+                device=ctx.device,
+            )
+        else:
+            env = ShardingEnv.from_local(world_size=topology.world_size, rank=0)
+
+            sharded_module = _shard_modules(
+                module=copied_module,
+                # pyre-ignore [6]
+                sharders=[sharder],
+                device=device,
+                plan=plan,
+                env=env,
+            )
 
     if compile_mode == CompileMode.FX_SCRIPT:
-        return fx_script_module(sharded_module)
+        return fx_script_module(
+            # pyre-ignore [6]
+            sharded_module
+            if not benchmark_unsharded_module
+            else module
+        )
     else:
-        return sharded_module
+        # pyre-ignore [7]
+        return sharded_module if not benchmark_unsharded_module else module
 
 
 def benchmark(
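For context, a minimal sketch of how `transform_module` might be called with the new flag. Only `world_size`, `batch_size`, `ctx`, and `benchmark_unsharded_module` are confirmed by this hunk; the remaining keyword names (`module`, `device`, `inputs`, `sharder`, `sharding_type`, `compile_mode`) are inferred from the function body and the call sites elsewhere in this diff, so treat them as assumptions.

```python
# Hedged sketch, not the exact signature: keyword names outside this diff are assumed.
import contextlib

import torch

eager_module = transform_module(
    module=wrapped_module,              # module under test (name assumed from the call site below)
    device=torch.device("cuda"),
    inputs=warmup_inputs,               # assumed: inputs[0] feeds fx_script_module
    sharder=sharder,
    sharding_type=sharding_type,
    compile_mode=CompileMode.FX_SCRIPT, # fx-scripts the *unsharded* module
    world_size=2,
    batch_size=batch_size,
    ctx=contextlib.nullcontext(),       # ctx is only consulted on the sharded path
    benchmark_unsharded_module=True,    # skip planner/sharding, return the original module
)
```

With the flag left at its default (`False`), the planner and sharding path above runs unchanged.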
@@ -504,6 +514,7 @@ def benchmark(
     rank: int,
     enable_logging: bool = True,
     device_type: str = "cuda",
+    benchmark_unsharded_module: bool = False,
 ) -> BenchmarkResult:
     max_mem_allocated: List[int] = []
     if enable_logging:
@@ -667,6 +678,7 @@ def init_module_and_run_benchmark(
     rank: int = -1,
     queue: Optional[mp.Queue] = None,
     pooling_configs: Optional[List[int]] = None,
+    benchmark_unsharded_module: bool = False,
 ) -> BenchmarkResult:
     """
     There are a couple of caveats here as to why the module has to be initialized
@@ -724,9 +736,13 @@ def init_module_and_run_benchmark(
         batch_size=batch_size,
         # pyre-ignore[6]
         ctx=ctx,
+        benchmark_unsharded_module=benchmark_unsharded_module,
     )
 
-    name = benchmark_type_name(compile_mode, sharding_type)
+    if benchmark_unsharded_module:
+        name = "unsharded" + compile_mode.name
+    else:
+        name = benchmark_type_name(compile_mode, sharding_type)
 
     res = benchmark(
         name,
@@ -741,6 +757,7 @@ def init_module_and_run_benchmark(
         benchmark_func_kwargs=benchmark_func_kwargs,
         rank=rank,
         device_type=device.type,
+        benchmark_unsharded_module=benchmark_unsharded_module,
     )
 
     if queue is not None:
@@ -825,6 +842,7 @@ def benchmark_module(
     world_size: int = 2,
     num_benchmarks: int = 5,
     output_dir: str = "",
+    benchmark_unsharded: bool = False,
     func_to_benchmark: Callable[..., None] = default_func_to_benchmark,
     benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
     pooling_configs: Optional[List[int]] = None,
@@ -896,13 +914,17 @@ def benchmark_module(
     ]
     prof_inputs = [rank_inputs[-prof_iters:] for rank_inputs in inputs]
 
-    for sharding_type in sharding_types:
+    for sharding_type in sharding_types if not benchmark_unsharded else ["Unsharded"]:
         for compile_mode in compile_modes:
-            # Test sharders should have a singular sharding_type
-            # pyre-ignore [16]
-            sharder._sharding_type = sharding_type.value
+            if not benchmark_unsharded:
+                # Test sharders should have a singular sharding_type
+                # pyre-ignore [16]
+                sharder._sharding_type = sharding_type.value
+                # pyre-ignore [6]
+                benchmark_type = benchmark_type_name(compile_mode, sharding_type)
+            else:
+                benchmark_type = "unsharded" + compile_mode.name
 
-            benchmark_type = benchmark_type_name(compile_mode, sharding_type)
             logging.info(
                 f"\n\n###### Running Benchmark Type: {benchmark_type} ######\n"
             )
@@ -933,6 +955,7 @@ def benchmark_module(
                     module=wrapped_module,
                     sharder=sharder,
                     device=torch.device(device_type),
+                    # pyre-ignore
                     sharding_type=sharding_type,
                     compile_mode=compile_mode,
                     world_size=world_size,
@@ -946,6 +969,7 @@ def benchmark_module(
                     func_to_benchmark=func_to_benchmark,
                     benchmark_func_kwargs=benchmark_func_kwargs,
                     pooling_configs=pooling_configs,
+                    benchmark_unsharded_module=benchmark_unsharded,
                 )
 
             gc.collect()
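Taken together, a hedged usage sketch of the new unsharded path through `benchmark_module`. Only the arguments visible in this diff are spelled out; the required model, table, and compile-mode arguments are elided and assumed to follow the existing signature in this file.

```python
# Hedged sketch, not the exact benchmark_module signature.
results = benchmark_module(
    module=wrapped_module,      # eager module; no planner/sharder pass runs for it
    sharder=sharder,            # still accepted, but the sharding_type override is skipped
    world_size=2,
    num_benchmarks=5,
    output_dir="",              # default shown in the signature above
    benchmark_unsharded=True,   # new flag: benchmark the unsharded module
)
```

With `benchmark_unsharded=True`, the sharding-type loop collapses to a single `"Unsharded"` pass, and each run is labeled `"unsharded" + compile_mode.name` (for example `unshardedEAGER`, assuming `CompileMode` has an `EAGER` member) instead of the usual `benchmark_type_name(compile_mode, sharding_type)`.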