Add V2 sharding support and improve partition spec handling for multichip training #2
Conversation
Force-pushed from b94f982 to f7ece6c
Hey Sungjoon, could you add more details to the PR about why this is needed and an overview of the high-level changes? Also, are you sure this is needed "for GSPMD support"? The purpose of the V2 shardings is to enable Shardy support within torch-xla. Torch-xla should already natively support GSPMD without needing any changes.
Hey Het, thanks for the feedback — and sorry for the lack of detail earlier. I'll update the commit message to include more context. These changes are required to support multi-chip training for real models from the torch-xla side. Specifically:

- We use MpLoader for parallel input loading, which internally calls xla_spec and _XLAC.XlaShardingSpec. That code currently creates an OpSharding using the V1 format, so I added V2 support alongside it.
- I also updated how we handle os.environ.get('CONVERT_SHLO_TO_SHARDY', False), since the previous logic treated values like "0" or "false" as truthy, which was causing unexpected behavior.

And you're right — GSPMD itself is already supported in torch-xla. What I meant was GSPMD-based training, and I'll clarify that in the updated message.
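For context, a minimal sketch of the kind of env-var parsing fix described above (not the actual PR code); the helper name `_env_flag` and the exact set of accepted values are assumptions:

```python
import os

def _env_flag(name: str, default: bool = False) -> bool:
    # Treat an unset variable as the default, and treat "0", "false", "no",
    # "" etc. as False instead of relying on the truthiness of the raw string.
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

convert_to_shardy = _env_flag("CONVERT_SHLO_TO_SHARDY")
```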
Force-pushed from c394d60 to b695da2
initial_dims.append(1)

# 2. Start with the initial_dims.
dims = list(initial_dims)
I think you can just rename `initial_dims` to `dims` and remove L179, since you're not using `initial_dims` after this line.
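A hypothetical before/after sketch of the suggested simplification; the mesh_shape and partition_spec values are just examples, and how `initial_dims` is built is assumed here, so only the rename/removal pattern is the point:

```python
mesh_shape = (2, 4)
partition_spec = (0, None)

# Before: build initial_dims, then copy it into dims one line later.
initial_dims = []
for axis in partition_spec:
    initial_dims.append(mesh_shape[axis] if axis is not None else 1)
dims = list(initial_dims)

# After: build dims directly and drop the unused intermediate list.
dims = [mesh_shape[axis] if axis is not None else 1 for axis in partition_spec]
```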
…-chip training

These changes are required to support multi-chip training for real models on the torch-xla side.

- Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
- Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY - previous logic treated values like "0" or "false" as truthy.
- Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec.

The new logic now correctly handles cases that were previously unsupported:

case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None)
  -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
case 2: mesh_shape=(2,1,1,1), partition_spec=(0,)
  -> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
case 3: mesh_shape=(2,4), partition_spec=(0,None)
  -> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]
Force-pushed from b695da2 to d5b1b3b
…chip training (#2)

* Add V2 sharding support and improve partition spec handling for multi-chip training

These changes are required to support multi-chip training for real models on the torch-xla side.

- Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
- Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY - previous logic treated values like "0" or "false" as truthy.
- Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec.

The new logic now correctly handles cases that were previously unsupported:

case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None)
  -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
case 2: mesh_shape=(2,1,1,1), partition_spec=(0,)
  -> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
case 3: mesh_shape=(2,4), partition_spec=(0,None)
  -> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]

* Fix formatting according to Torch-XLA style guide

---------

Co-authored-by: Het Shah <hshah@tenstorrent.com>
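For readers of this commit message, a hedged sketch of how dims, reshape_dims, and transpose_perm could be derived from mesh_shape and partition_spec so that the three cases above fall out; the helper name is hypothetical and this is not the PR's actual implementation:

```python
def _v2_tile_assignment_parts(mesh_shape, partition_spec):
    # One tile count per tensor dimension: the size of the mesh axis that
    # dimension is sharded over, or 1 if it is replicated (None).
    dims = [mesh_shape[axis] if axis is not None else 1
            for axis in partition_spec]

    # Mesh axes not named in partition_spec stay replicated; if they span
    # more than one device, append a trailing replication dimension.
    used = [axis for axis in partition_spec if axis is not None]
    unused = [i for i in range(len(mesh_shape)) if i not in used]
    replicated = 1
    for i in unused:
        replicated *= mesh_shape[i]
    if replicated > 1:
        dims.append(replicated)

    # reshape_dims is the mesh shape itself; transpose_perm lists the used
    # mesh axes first, then the unused ones (identity in all three cases).
    reshape_dims = list(mesh_shape)
    transpose_perm = used + unused
    return dims, reshape_dims, transpose_perm

# Case 3 from the commit message:
print(_v2_tile_assignment_parts((2, 4), (0, None)))
# -> ([2, 1, 4], [2, 4], [0, 1])
```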