@@ -4472,7 +4472,7 @@ def from_pretrained(
                 A torch tensor parallel degree. If not provided would default to world size.
             device_mesh (`torch.distributed.DeviceMesh`, *optional*):
                 A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now.
-                If provided, it has to contain dimension named `"tp"` which will be used for tensor parallelism
+                If provided and more than one-dimensional, it has to contain a dimension named `"tp"`, which will be used for tensor parallelism.
             offload_folder (`str` or `os.PathLike`, *optional*):
                 If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
             offload_state_dict (`bool`, *optional*):
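
For context, the updated docstring means an n-dimensional mesh is acceptable as long as one of its dimensions is named `"tp"`. Below is a minimal sketch of such a call, assuming four ranks launched with `torchrun` and split 2 x 2 into data-parallel and tensor-parallel groups; the mesh shape, the `"dp"` dimension name, and the model id are illustrative assumptions, not part of this diff.

```python
# Illustrative sketch only: run with `torchrun --nproc-per-node 4 this_script.py`
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

# 2-d mesh; the "tp" dimension name is what the updated docstring requires.
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",  # hypothetical checkpoint id
    tp_plan="auto",
    device_mesh=mesh,  # the "tp" sub-mesh is selected internally (see the next hunk)
)
```
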
@@ -4617,10 +4617,15 @@ def from_pretrained(
         if device_mesh is None:
             tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
         else:
-            # TODO: make device_mesh support multiple dimensions
             if device_mesh.ndim > 1:
-                raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
-            device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
+                if "tp" not in device_mesh.mesh_dim_names:
+                    raise ValueError(
+                        "When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
+                        "Please provide a valid `device_mesh`."
+                    )
+                device_mesh = device_mesh["tp"]
+                tp_size = device_mesh.size()
+            device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
 
         if tp_size is None:
             tp_size = torch.distributed.get_world_size()
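
Read together, the new branch selects the `"tp"` sub-mesh and derives the tensor-parallel degree from it, otherwise falling back to the world size. A standalone sketch of that combined logic follows; the helper name `pick_tp_submesh` is hypothetical, not a transformers API.

```python
import os

import torch
from torch.distributed.device_mesh import DeviceMesh


def pick_tp_submesh(device_mesh: DeviceMesh, tp_size: int | None = None):
    # Hypothetical helper mirroring the branch added above, not part of transformers.
    if device_mesh.ndim > 1:
        if "tp" not in device_mesh.mesh_dim_names:
            raise ValueError("an n-d device_mesh must contain a 'tp' dimension")
        device_mesh = device_mesh["tp"]  # 1-d sub-mesh holding this rank's TP group
        tp_size = device_mesh.size()     # TP degree comes from the sub-mesh
    # Each process places its shard on its own local device.
    device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
    if tp_size is None:                  # 1-d mesh: TP degree defaults to the world size
        tp_size = torch.distributed.get_world_size()
    return device_mesh, tp_size, device_map
```
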