Add Pallas core_map guide #33048
Conversation
Summary of Changes: Hello @IvyZX, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request adds a comprehensive guide to Pallas core-specific programming using `core_map`.
Code Review
This PR adds a new guide for Pallas core_map on TPU. The guide is well-structured and covers important concepts from basic per-core kernels to more advanced topics like pipelining, scalar prefetch, and SparseCores. The examples are clear and helpful.
I've found a few issues, mainly minor typos in the documentation and a more significant issue with the barrier implementation in the first example. The barrier is incorrect and inefficient, which could be misleading for users. I've suggested a correct and efficient implementation.
Most of the typos are present in both the Jupyter notebook and the generated Markdown file. It would be best to fix them in the source notebook and regenerate the Markdown file.
> {
> "cell_type": "markdown",
> "source": [
> "# Pallas Core-specifc Programming"
> {
> "cell_type": "markdown",
> "source": [
> "In addition to the typical TPU device mesh, you need to make a mesh of cores. Consider this as an addition dimension called \"core\", with length 2, in addition to the 4-device mesh you work with. That is 8 cores in total."
Typo: 'addition' should be 'additional'.
| "In addition to the typical TPU device mesh, you need to make a mesh of cores. Consider this as an addition dimension called \"core\", with length 2, in addition to the 4-device mesh you work with. That is 8 cores in total." | |
| "In addition to the typical TPU device mesh, you need to make a mesh of cores. Consider this as an additional dimension called \"core\", with length 2, in addition to the 4-device mesh you work with. That is 8 cores in total." | 
| "\n", | ||
| "**Parallelize work per core**\n", | ||
| "\n", | ||
| "Since you are programming on the core level, you get to customize exactly how the work is splitted amongst cores. To do that, you need to:\n", | 
Typo: 'splitted' should be 'split'.
| "Since you are programming on the core level, you get to customize exactly how the work is splitted amongst cores. To do that, you need to:\n", | |
| "Since you are programming on the core level, you get to customize exactly how the work is split amongst cores. To do that, you need to:\n", | 
| "source": [ | ||
| "## Scalar prefetch\n", | ||
| "\n", | ||
| "The code below extended the kernel above but uses [scalar prefetch and dynamic block indexing](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html) to select a specific sub-slice of the input.\n", | 
Typo: 'extended' should be 'extends'.
| "The code below extended the kernel above but uses [scalar prefetch and dynamic block indexing](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html) to select a specific sub-slice of the input.\n", | |
| "The code below extends the kernel above but uses [scalar prefetch and dynamic block indexing](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html) to select a specific sub-slice of the input.\n", | 
| "source": [ | ||
| "## Mapping over SparseCores\n", | ||
| "\n", | ||
| "TPU v4 and above includes a [SparseCore](https://openxla.org/xla/sparsecore), which is specialized in sparse memory access and operations. This guide will not dive into the capabilities of SparseCore, but rather show how to run a program on SparseCore with same semantics and minimal changes from the TensorCore code.\n", | 
Grammar: 'with same semantics' should be 'with the same semantics'.
| "TPU v4 and above includes a [SparseCore](https://openxla.org/xla/sparsecore), which is specialized in sparse memory access and operations. This guide will not dive into the capabilities of SparseCore, but rather show how to run a program on SparseCore with same semantics and minimal changes from the TensorCore code.\n", | |
| "TPU v4 and above includes a [SparseCore](https://openxla.org/xla/sparsecore), which is specialized in sparse memory access and operations. This guide will not dive into the capabilities of SparseCore, but rather show how to run a program on SparseCore with the same semantics and minimal changes from the TensorCore code.\n", | 
> {
> "cell_type": "markdown",
> "source": [
> "The code below is very similar from the `add_one_kernel` we wrote earlier, except for a few differences:\n",
> * **Flexible pipelining**: You have the option to write pipelining communications on your own, instead of relying on Pallas grids and specs. This is helpful if your pipeline diverges from the standard "copy-in, compute & copy-out" pattern.
> * **Collectives**: Since `core_map` allows inter-core communications, it is especially helpful when writing collectives on the core level.
I thought pallas_call allows this too? Maybe it would help to mention that the way core-specific code in pallas is done is quite indirect and not user-friendly. You have to set the grid=(num_cores,) and mark that dimension as PARALLEL.
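A sketch of that indirect `pallas_call` route, assuming the current `pltpu.CompilerParams` spelling (older releases call it `pltpu.TPUCompilerParams`); the kernel and shapes are placeholders:

```python
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def add_one_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1

def add_one(x, num_cores=2):
  block = (x.shape[0] // num_cores, x.shape[1])
  return pl.pallas_call(
      add_one_kernel,
      # A leading grid dimension of size num_cores, marked "parallel", is what
      # lets the TPU compiler split the work across the Megacore's TensorCores.
      grid=(num_cores,),
      in_specs=[pl.BlockSpec(block, lambda i: (i, 0))],
      out_specs=pl.BlockSpec(block, lambda i: (i, 0)),
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      compiler_params=pltpu.CompilerParams(dimension_semantics=("parallel",)),
  )(x)
```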
> ## Environment setup
> Modern accelerators often have multiple cores under a device. For TPU chips higher than v4, every JAX device by default contains two TensorCores (aka. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). They also contain a [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore), consisting of many subcores.
Only v4 and v5p have megacore. v7 lacks megacore and instead comes in pairs of chips with separate HBM. Also I think only v5p/v6/v7 have sparsecore.
> for i in range(num_devices):
>   for j in range(num_cores):
>     pltpu.semaphore_signal(sem0, 1, device_id={'device': i, 'core': j})
> pltpu.semaphore_wait(sem0, num_devices * num_cores)
Why is it necessary to barrier with everything - can you just barrier with the cores you are computing with?
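For concreteness, a hedged sketch of the narrower barrier this question points at: each core signals only its peer core on the same device and then waits for a single signal, instead of looping over every (device, core) pair. It reuses the quoted `semaphore_signal` signature; whether this suffices depends on which cores the kernel actually communicates with.

```python
# Inside the core-local kernel body, with sem0 allocated as in the guide.
my_device = jax.lax.axis_index("device")
my_core = jax.lax.axis_index("core")
# Signal only the other TensorCore on this device...
pltpu.semaphore_signal(
    sem0, 1, device_id={"device": my_device, "core": 1 - my_core})
# ...and wait for exactly one signal instead of num_devices * num_cores.
pltpu.semaphore_wait(sem0, 1)
```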
> **Parallelize work per core**
> Since you are programming on the core level, you get to customize exactly how the work is splitted amongst cores. To do that, you need to:
how the work is splitted -> how the work is split
> 1. Provide an `index_map` function that, given the iteration indices, return *the slice* of the input data that shall be passed in.
> 1. On `BlockSpec`, wrap the corresponding dimension with `pl.BoundedSlice`, indicating the `index_map` function would return a slice instead of a iteration index on that dimension.
Could you have done the same thing with index_map = (core_idx * core_slc_size // 8 + i, j) and just use a normal (8, 128) block shape?
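A small sketch of this alternative: keep a plain `(8, 128)` block and fold the per-core offset into the index map instead of using `pl.BoundedSlice`. Here `core_slc_size` (the number of rows each core handles) is taken from the comment above; the guide's actual shapes and grid may differ.

```python
def index_map(i, j):
  core_idx = jax.lax.axis_index("core")
  # Shift the row-block index by this core's share of the rows
  # (core_slc_size rows == core_slc_size // 8 blocks of 8 rows).
  return core_idx * core_slc_size // 8 + i, j

in_spec = pl.BlockSpec((8, 128), index_map)
```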
> You could make a shortcut `kernel()` that wraps all the `shard_map`, `core_map` and `run_scoped` boilerplates.
> Some similar APIs are currently available in Pallas package, such as `plgpu.kernel` and `plsc.kernel`. A unified API may be released soon.
Let's just add this to Pallas now? Any thoughts @sharadmv?
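As a rough illustration only (not the guide's code and not an existing Pallas API), such a shortcut could look something like the sketch below, which hides the stateful-ref and `core_map` boilerplate behind a `pallas_call`-like wrapper. It assumes the `pl.run_state` + `pl.core_map` composition used elsewhere in Pallas; every other name is hypothetical.

```python
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(body, out_shape, mesh):
  """Hypothetical shortcut: run `body(x_ref, o_ref)` once per core in `mesh`."""
  def wrapped(x):
    def stateful(refs):
      x_ref, o_ref = refs
      @pl.core_map(mesh)
      def _():
        body(x_ref, o_ref)  # body sees HBM refs and does its own copies
    _, out = pl.run_state(stateful)(
        (x, jnp.zeros(out_shape.shape, out_shape.dtype)))
    return out
  return wrapped
```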
> @@ -0,0 +1,358 @@
> # Pallas Core-specifc Programming
typo: specifc -> specific
> * **Per-core level programming**: You write code for an TPU/GPU core, not for a JAX device. This is crucial if you want to specifically control a core, or how cores communicate and distribute work among one another.
> * **Flexible pipelining**: You have the option to write pipelining communications on your own, instead of relying on Pallas grids and specs. This is helpful if your pipeline diverges from the standard "copy-in, compute & copy-out" pattern.
We can already do this in pallas_call, by simply not using the grid.
| "timestamp": 1761945248463 | ||
| } | ||
| ], | ||
| "last_runtime": { | 
this should not be in the ipynb
> ## Environment setup
> Modern accelerators often have multiple cores under a device. For TPU chips higher than v4, every JAX device by default contains two TensorCores (aka. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). They also contain a [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore), consisting of many subcores.
I'd say "For chips such as TPU v5p" to be precise.
> Modern accelerators often have multiple cores under a device. For TPU chips higher than v4, every JAX device by default contains two TensorCores (aka. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). They also contain a [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore), consisting of many subcores.
> This guide was written on a v5p chip, which contains 4 devices (2 TensorCores each) and a SparseCore of 16 subcores.
"SparseCore with 16 subcores"
4 SparseCores, each with 16 (vector) subcores. We can link to https://openxla.org/xla/sparsecore#specifications_at_a_glance.
> `pl.core_map` allows you to write per-core local code, just as `jax.shard_map` allows you to write per-device code.
> In the example kernel below, each core has its own VMEM and semaphore allocations. As with normal kernel, you can initiate copy between HBM and VMEM refs using `async_copy`.
"can initiate copy" -> "can initiate copies"
"async_copy" -> "pl.async_copy"
> In the example kernel below, each core has its own VMEM and semaphore allocations. As with normal kernel, you can initiate copy between HBM and VMEM refs using `async_copy`.
> **Communication amongst cores**
"amongst" -> "between"
> * Call it inside a `pl.core_map`, which takes the TensorCore mesh.
> * You would need `collective_id` if there exists inter-core communications.
You will need collective_id for the barrier semaphore
> ## Pipelining with `core_map`
> Note that the kernel above only does simple copies and computes, without automatic pipelining via Pallas `grid` and `BlockSpec`. To do pipelining inside `core_map`, use `pltpu.emit_pipeline` inside the core-local kernel.
"computes" -> "compute"
> **Parallelize work per core**
> Since you are programming on the core level, you get to customize exactly how the work is splitted amongst cores. To do that, you need to:
"on the core level" -> "at the core level"
"splitted amongst cores" -> "split between cores"
> 1. Provide an `index_map` function that, given the iteration indices, return *the slice* of the input data that shall be passed in.
> 1. On `BlockSpec`, wrap the corresponding dimension with `pl.BoundedSlice`, indicating the `index_map` function would return a slice instead of a iteration index on that dimension.
Note that we don't strictly need BoundedSlice here. We can also use half of the original BlockSize and offset the index map. Also, emit_pipeline with core_axis_name automatically partitions the grid.
I think for this, the first example should be emit_pipeline with core_axis_name.
The second one could be a more custom splitting across cores.
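If that restructuring is adopted, the first example could reduce to something like the snippet below (reusing `inner` and the specs from the previous sketch); it assumes that passing `core_axis_name` is all that is needed for the automatic per-core grid split the reviewers describe.

```python
pipeline = pltpu.emit_pipeline(
    inner,
    grid=(num_blocks,),
    in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
    out_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
    core_axis_name="core",  # let emit_pipeline partition the grid over cores
)
pipeline(x_hbm_ref, o_hbm_ref)
```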
> ## Mapping over SparseCores
> TPU v4 and above includes a [SparseCore](https://openxla.org/xla/sparsecore), which is specialized in sparse memory access and operations. This guide will not dive into the capabilities of SparseCore, but rather show how to run a program on SparseCore with same semantics and minimal changes from the TensorCore code.
Note that TPU v5e does not have SparseCore.
> ## Environment setup
> Modern accelerators often have multiple cores under a device. For TPU chips higher than v4, every JAX device by default contains two TensorCores (aka. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). They also contain a [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore), consisting of many subcores.
Should we say "SparseCores", since every chip has at least 2?
> from functools import partial
>
> import jax
> from jax.sharding import NamedSharding, PartitionSpec as P
Nit: I think we have jax.P now.
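That is, assuming a recent enough JAX, the import could shrink to:

```python
import jax
from jax.sharding import NamedSharding

# jax.P is the newer alias for jax.sharding.PartitionSpec.
spec = jax.P("device")
```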
> **Communication amongst cores**
> Before making a inter-core communication, you may need to do a global barrier signal (`pltpu.semaphore_signal`), to make sure all the destination semaphores have been properly initialized.
Is it "may" or "must"? If "may", it will be useful to explain when this is in fact required.
> ## Mapping over SparseCores
> TPU v4 and above includes a [SparseCore](https://openxla.org/xla/sparsecore), which is specialized in sparse memory access and operations. This guide will not dive into the capabilities of SparseCore, but rather show how to run a program on SparseCore with same semantics and minimal changes from the TensorCore code.
Nit: should we say "SparseCores" here as well, instead of "a SparseCore" to highlight that it's >1 per chip.
> sc_mesh = plsc.VectorSubcoreMesh(
>     core_axis_name="core", subcore_axis_name="subcore",
>     num_cores=sc_info.num_cores
I wonder if we should default to sc_info.num_cores instead of requiring users to always query SC info?
> 1. You need to split the work amongst all subcores, so a few lines to compute the specific slice for each subcore.
> 1. SparseCore register computation allows smaller slices (`4x16` max for int32), so you need nested loops to iterate the slice during computation phase.
Nit: 2.
Also, should we reference sc_info.num_lanes here and have a single loop reading out (num_lanes,) vectors? 4x16 relies on unrolling in the SC compiler, which only really works for a handful of datatypes atm.