
Conversation


sshonTT commented Jul 23, 2025

These changes are required to support multi-chip training for real models on the torch-xla side.

  • Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
  • Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY: the previous logic treated values like "0" or "false" as truthy.
  • Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec (see the sketch after the example cases below).

The new logic now correctly handles cases that were previously unsupported:

case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None)
-> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

case 2: mesh_shape=(2,1,1,1), partition_spec=(0,)
-> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

case 3: mesh_shape=(2,4), partition_spec=(0,None)
-> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]
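
For reference, here is a minimal sketch of the rule these three cases illustrate. It is not the PR's actual implementation; the function name and the handling of edge cases are assumptions made for illustration:

```python
def compute_v2_sharding_fields(mesh_shape, partition_spec):
    """Sketch of deriving dims/reshape_dims/transpose_perm for a V2 OpSharding."""
    # One tile dimension per entry of partition_spec: the size of the mesh
    # axis the tensor dimension is sharded over, or 1 if it is not sharded.
    dims = [mesh_shape[axis] if axis is not None else 1
            for axis in partition_spec]

    # Mesh axes that partition_spec never references are replicated; if they
    # cover more than one device, they become a trailing replicated dimension.
    used = [axis for axis in partition_spec if axis is not None]
    unused = [i for i in range(len(mesh_shape)) if i not in used]
    replicated = 1
    for i in unused:
        replicated *= mesh_shape[i]
    if replicated > 1:
        dims.append(replicated)

    # reshape_dims is the device mesh shape itself; transpose_perm lists the
    # sharded mesh axes first (in partition_spec order), then the rest.
    reshape_dims = list(mesh_shape)
    transpose_perm = used + unused
    return dims, reshape_dims, transpose_perm


# Reproduces the cases above, e.g. case 3:
print(compute_v2_sharding_fields((2, 4), (0, None)))
# -> ([2, 1, 4], [2, 4], [0, 1])
```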

sshonTT requested a review from hshahTT on July 23, 2025 22:02
sshonTT force-pushed the sshon/xla-sharding-spec branch 2 times, most recently from b94f982 to f7ece6c on July 24, 2025 14:22

hshahTT commented Jul 24, 2025

Hey Sungjoon, could you add more details to the PR about why this is needed and an overview of the high-level changes? Also, are you sure this is needed "for GSPMD support"? The purpose of the V2 shardings is to enable Shardy support within torch-xla. Torch-xla should already natively support GSPMD without needing any changes.


sshonTT commented Jul 24, 2025

> Hey Sungjoon, could you add more details to the PR about why this is needed and an overview of the high-level changes? Also, are you sure this is needed "for GSPMD support"? The purpose of the V2 shardings is to enable Shardy support within torch-xla. Torch-xla should already natively support GSPMD without needing any changes.

Hey Het, thanks for the feedback — and sorry for the lack of detail earlier. I’ll update the commit message to include more context.

These changes are required to support multi-chip training for real models from the torch-xla side. Specifically:

We use MpLoader for parallel input loading, which internally calls xla_spec and _XLAC.XlaShardingSpec. That code currently creates an OpSharding using the V1 format, so I added V2 support alongside it.
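
For context, this is roughly the input-loading path involved, written against the standard upstream torch_xla SPMD API; the exact module paths, and whether this precise call chain exercises the new V2 code in this fork, are assumptions:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

# Toy dataset so the example is self-contained.
train_loader = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=8)

# 1-D data-parallel mesh over all available devices.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

# ShardingSpec goes through _XLAC.XlaShardingSpec under the hood; the loader
# shards each input batch along the batch dimension while transferring it.
input_spec = xs.ShardingSpec(mesh, (0, None))
loader = pl.MpDeviceLoader(train_loader, xm.xla_device(),
                           input_sharding=input_spec)
```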

I also updated how we handle os.environ.get('CONVERT_SHLO_TO_SHARDY', False), since the previous logic treated values like "0" or "false" as truthy, which was causing unexpected behavior.
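
Roughly, the intent is something like the following (a sketch of the parsing behavior, not the exact code in the PR):

```python
import os

def convert_shlo_to_shardy_enabled():
    # Treat unset, "0", "false", etc. as disabled; previously any non-empty
    # string (including "0" and "false") was effectively truthy.
    value = os.environ.get('CONVERT_SHLO_TO_SHARDY', '0')
    return value.strip().lower() in ('1', 'true', 'yes', 'on')
```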

And you're right — GSPMD itself is already supported in torch-xla. What I meant was GSPMD-based training, and I’ll clarify that in the updated message.

sshonTT force-pushed the sshon/xla-sharding-spec branch 2 times, most recently from c394d60 to b695da2 on July 24, 2025 19:50
sshonTT changed the title from "Apply v2 version of sharding format for XlaShardingSpec module" to "Add V2 sharding support and improve partition spec handling for multichip training" on Jul 24, 2025
initial_dims.append(1)

# 2. Start with the initial_dims.
dims = list(initial_dims)
hshahTT commented on this diff:

I think you can just rename `initial_dims` to `dims` and remove L179, since you're not using `initial_dims` after this line.

sshonTT (Author) replied:

Agreed! FWIW, it was working on the older branch (link), but not on the new one (link). I’m looking into it now. Unless it turns out to require changes in torch-xla, I probably won’t need anything further here, but if it does, I’ll ping you for a review.

Add V2 sharding support and improve partition spec handling for multi-chip training

These changes are required to support multi-chip training for real models on the torch-xla side.

- Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
- Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY - previous logic treated values like "0" or "false" as truthy.
- Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec.

The new logic now correctly handles cases that were previously unsupported:

  case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None)
           -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

  case 2: mesh_shape=(2,1,1,1), partition_spec=(0,)
           -> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

  case 3: mesh_shape=(2,4), partition_spec=(0,None)
           -> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]
sshonTT force-pushed the sshon/xla-sharding-spec branch from b695da2 to d5b1b3b on July 31, 2025 15:35
hshahTT merged commit 7d989d1 into master on Aug 2, 2025
hshahTT added a commit that referenced this pull request Aug 5, 2025
Add V2 sharding support and improve partition spec handling for multi-chip training (#2)

* Add V2 sharding support and improve partition spec handling for multi-chip training

These changes are required to support multi-chip training for real models on the torch-xla side.

- Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
- Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY - previous logic treated values like "0" or "false" as truthy.
- Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec.

The new logic now correctly handles cases that were previously unsupported:

  case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None)
           -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

  case 2: mesh_shape=(2,1,1,1), partition_spec=(0,)
           -> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

  case 3: mesh_shape=(2,4), partition_spec=(0,None)
           -> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]

* Fix formatting according to Torch-XLA style guide

---------

Co-authored-by: Het Shah <hshah@tenstorrent.com>
hshahTT added a commit that referenced this pull request Aug 30, 2025
Add V2 sharding support and improve partition spec handling for multi-chip training (#2)
hshahTT added a commit that referenced this pull request Sep 7, 2025
Add V2 sharding support and improve partition spec handling for multi-chip training (#2)