
Conversation

Collaborator

@IvyZX IvyZX commented Oct 31, 2025

No description provided.

@IvyZX IvyZX requested review from justinjfu and sharadmv October 31, 2025 21:26
@gemini-code-assist
Contributor

Summary of Changes

Hello @IvyZX, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request adds a comprehensive guide to Pallas core-specific programming using pl.core_map for TPUs. The guide covers fundamental concepts like per-core programming and inter-core communication, progresses to advanced pipelining and scalar prefetch techniques, and concludes with an example of mapping operations onto SparseCores, providing developers with detailed instructions and examples for fine-grained control over TPU hardware.

Highlights

  • Introduction to pl.core_map: The guide introduces pl.core_map for writing Pallas kernels, emphasizing its benefits over pallas_call for per-core programming, flexible pipelining, and inter-core collectives on TPUs.
  • Core-level Programming Examples: It provides practical examples, starting with a simple per-core kernel demonstrating VMEM and semaphore allocations, and inter-core communication using barriers and remote copies.
  • Pipelining and Work Splitting: The guide explains how to implement custom pipelining with pltpu.emit_pipeline and manually parallelize work across cores using index_map and pl.BoundedSlice.
  • Advanced Techniques: It covers advanced topics such as scalar prefetch and dynamic block indexing, showcasing how to use SMEM buffers and sync_copy for optimized data access.
  • SparseCore Integration: The guide demonstrates how to map operations over SparseCores, detailing the setup of VectorSubcoreMesh and handling work distribution across subcores for sparse memory access.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This PR adds a new guide for Pallas core_map on TPU. The guide is well-structured and covers important concepts from basic per-core kernels to more advanced topics like pipelining, scalar prefetch, and SparseCores. The examples are clear and helpful.

I've found a few issues, mainly minor typos in the documentation and a more significant issue with the barrier implementation in the first example. The barrier is incorrect and inefficient, which could be misleading for users. I've suggested a correct and efficient implementation.

Most of the typos are present in both the Jupyter notebook and the generated Markdown file. It would be best to fix them in the source notebook and regenerate the Markdown file.

{
"cell_type": "markdown",
"source": [
"# Pallas Core-specifc Programming"
Contributor

medium

There's a typo in 'specific'.

Suggested change
"# Pallas Core-specifc Programming"
"# Pallas Core-specific Programming"

{
"cell_type": "markdown",
"source": [
"In addition to the typical TPU device mesh, you need to make a mesh of cores. Consider this as an addition dimension called \"core\", with length 2, in addition to the 4-device mesh you work with. That is 8 cores in total."
Contributor

medium

Typo: 'addition' should be 'additional'.

Suggested change
"In addition to the typical TPU device mesh, you need to make a mesh of cores. Consider this as an addition dimension called \"core\", with length 2, in addition to the 4-device mesh you work with. That is 8 cores in total."
"In addition to the typical TPU device mesh, you need to make a mesh of cores. Consider this as an additional dimension called \"core\", with length 2, in addition to the 4-device mesh you work with. That is 8 cores in total."

"\n",
"**Parallelize work per core**\n",
"\n",
"Since you are programming on the core level, you get to customize exactly how the work is splitted amongst cores. To do that, you need to:\n",
Contributor

medium

Typo: 'splitted' should be 'split'.

Suggested change
"Since you are programming on the core level, you get to customize exactly how the work is splitted amongst cores. To do that, you need to:\n",
"Since you are programming on the core level, you get to customize exactly how the work is split amongst cores. To do that, you need to:\n",

"source": [
"## Scalar prefetch\n",
"\n",
"The code below extended the kernel above but uses [scalar prefetch and dynamic block indexing](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html) to select a specific sub-slice of the input.\n",
Contributor

medium

Typo: 'extended' should be 'extends'.

Suggested change
"The code below extended the kernel above but uses [scalar prefetch and dynamic block indexing](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html) to select a specific sub-slice of the input.\n",
"The code below extends the kernel above but uses [scalar prefetch and dynamic block indexing](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html) to select a specific sub-slice of the input.\n",

"source": [
"## Mapping over SparseCores\n",
"\n",
"TPU v4 and above includes a [SparseCore](https://openxla.org/xla/sparsecore), which is specialized in sparse memory access and operations. This guide will not dive into the capabilities of SparseCore, but rather show how to run a program on SparseCore with same semantics and minimal changes from the TensorCore code.\n",
Contributor

medium

Grammar: 'with same semantics' should be 'with the same semantics'.

Suggested change
"TPU v4 and above includes a [SparseCore](https://openxla.org/xla/sparsecore), which is specialized in sparse memory access and operations. This guide will not dive into the capabilities of SparseCore, but rather show how to run a program on SparseCore with same semantics and minimal changes from the TensorCore code.\n",
"TPU v4 and above includes a [SparseCore](https://openxla.org/xla/sparsecore), which is specialized in sparse memory access and operations. This guide will not dive into the capabilities of SparseCore, but rather show how to run a program on SparseCore with the same semantics and minimal changes from the TensorCore code.\n",

{
"cell_type": "markdown",
"source": [
"The code below is very similar from the `add_one_kernel` we wrote earlier, except for a few differences:\n",
Contributor

medium

Typo: 'similar from' should be 'similar to'.

Suggested change
"The code below is very similar from the `add_one_kernel` we wrote earlier, except for a few differences:\n",
"The code below is very similar to the `add_one_kernel` we wrote earlier, except for a few differences:\n",


* **Flexible pipelining**: You have the option to write pipelining communications on your own, instead of relying on Pallas grids and specs. This is helpful if your pipeline diverges from the standard "copy-in, compute & copy-out" pattern.

* **Collectives**: Since `core_map` allows inter-core communications, it is especially helpful when writing collectives on the core level.
Collaborator

I thought pallas_call allows this too? Maybe it would help to mention that the way core-specific code in pallas is done is quite indirect and not user-friendly. You have to set the grid=(num_cores,) and mark that dimension as PARALLEL.
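For thread readers, a minimal hedged sketch of that indirect route, assuming a 2-core Megacore; the kernel, names, and shapes below are illustrative, and `pltpu.CompilerParams` is spelled `TPUCompilerParams` in older JAX releases:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def per_core_add_one(x_ref, o_ref):
  # Each grid iteration handles the block owned by one core.
  o_ref[...] = x_ref[...] + 1.0

num_cores = 2
x = jnp.zeros((num_cores * 8, 128), jnp.float32)

out = pl.pallas_call(
    per_core_add_one,
    grid=(num_cores,),  # one grid dimension, one step per core
    in_specs=[pl.BlockSpec((8, 128), lambda c: (c, 0))],
    out_specs=pl.BlockSpec((8, 128), lambda c: (c, 0)),
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    compiler_params=pltpu.CompilerParams(dimension_semantics=("parallel",)),
)(x)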


## Environment setup

Modern accelerators often have multiple cores under a device. For TPU chips higher than v4, every JAX device by default contains two TensorCores (aka. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). They also contain a [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore), consisting of many subcores.
Collaborator

Only v4 and v5p have megacore. v7 lacks megacore and instead comes in pairs of chips with separate HBM. Also I think only v5p/v6/v7 have sparsecore.

for i in range(num_devices):
  for j in range(num_cores):
    pltpu.semaphore_signal(sem0, 1, device_id={'device': i, 'core': j})
pltpu.semaphore_wait(sem0, num_devices * num_cores)
Collaborator

Why is it necessary to barrier with everything - can you just barrier with the cores you are computing with?
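A hedged sketch of one narrower option, assuming the remote copies that follow stay on the local device (so only the sibling cores need to have initialized their semaphores), and reusing the `sem0`, `num_cores`, and axis names from the snippet above:

my_device = jax.lax.axis_index("device")   # this device's coordinate
for j in range(num_cores):
  pltpu.semaphore_signal(sem0, 1, device_id={'device': my_device, 'core': j})
pltpu.semaphore_wait(sem0, num_cores)      # one signal from each local core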


**Parallelize work per core**

Since you are programming on the core level, you get to customize exactly how the work is splitted amongst cores. To do that, you need to:
Collaborator

how the work is splitted -> how the work is split


1. Provide an `index_map` function that, given the iteration indices, return *the slice* of the input data that shall be passed in.

1. On `BlockSpec`, wrap the corresponding dimension with `pl.BoundedSlice`, indicating the `index_map` function would return a slice instead of a iteration index on that dimension.
Collaborator

Could you have done the same thing with index_map = (core_idx * core_slc_size // 8 + i, j) and just use a normal (8, 128) block shape?
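Spelled out as a hedged sketch (the sizes are illustrative, `core_slc_size` follows the naming in this comment, and the index map closes over the core index computed in the core-local body):

num_cores, total_rows = 2, 512            # illustrative sizes
core_slc_size = total_rows // num_cores   # rows handled by each core
core_idx = jax.lax.axis_index("core")     # computed inside the core_map body

def x_index_map(i, j):
  # Offset the row-block index by this core's band; no pl.BoundedSlice needed.
  return (core_idx * core_slc_size // 8 + i, j)

x_spec = pl.BlockSpec((8, 128), x_index_map)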


You could make a shortcut `kernel()` that wraps all the `shard_map`, `core_map` and `run_scoped` boilerplates.

Some similar APIs are currently available in Pallas package, such as `plgpu.kernel` and `plsc.kernel`. A unified API may be released soon.
Collaborator

Let's just add this to Pallas now? Any thoughts @sharadmv?

@@ -0,0 +1,358 @@
# Pallas Core-specifc Programming
Collaborator

typo: specifc -> specific


* **Per-core level programming**: You write code for an TPU/GPU core, not for a JAX device. This is crucial if you want to specifically control a core, or how cores communicate and distribute work among one another.

* **Flexible pipelining**: You have the option to write pipelining communications on your own, instead of relying on Pallas grids and specs. This is helpful if your pipeline diverges from the standard "copy-in, compute & copy-out" pattern.
Collaborator

We can already do this in pallas_call, by simply not using the grid.

"timestamp": 1761945248463
}
],
"last_runtime": {
Collaborator

this should not be in the ipynb


## Environment setup

Modern accelerators often have multiple cores under a device. For TPU chips higher than v4, every JAX device by default contains two TensorCores (aka. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). They also contain a [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore), consisting of many subcores.
Collaborator

I'd say "For chips such as TPU v5p" to be precise.


Modern accelerators often have multiple cores under a device. For TPU chips higher than v4, every JAX device by default contains two TensorCores (aka. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). They also contain a [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore), consisting of many subcores.

This guide was written on a v5p chip, which contains 4 devices (2 TensorCores each) and a SparseCore of 16 subcores.
Collaborator

"SparseCore with 16 subcores"

Collaborator

@superbobry superbobry Nov 3, 2025

4 SparseCores, each with 16 (vector) subcores. We can link to https://openxla.org/xla/sparsecore#specifications_at_a_glance.


`pl.core_map` allows you to write per-core local code, just as `jax.shard_map` allows you to write per-device code.

In the example kernel below, each core has its own VMEM and semaphore allocations. As with normal kernel, you can initiate copy between HBM and VMEM refs using `async_copy`.
Collaborator

"can initiate copy" -> "can initiate copies"
"async_copy" -> "pl.async_copy"


In the example kernel below, each core has its own VMEM and semaphore allocations. As with normal kernel, you can initiate copy between HBM and VMEM refs using `async_copy`.

**Communication amongst cores**
Collaborator

"amongst" -> "between"


* Call it inside a `pl.core_map`, which takes the TensorCore mesh.

* You would need `collective_id` if there exists inter-core communications.
Collaborator

You will need collective_id for the barrier semaphore


## Pipelining with `core_map`

Note that the kernel above only does simple copies and computes, without automatic pipelining via Pallas `grid` and `BlockSpec`. To do pipelining inside `core_map`, use `pltpu.emit_pipeline` inside the core-local kernel.
Collaborator

"computes" -> "compute"


**Parallelize work per core**

Since you are programming on the core level, you get to customize exactly how the work is splitted amongst cores. To do that, you need to:
Collaborator

"on the core level" -> "at the core level"

"splitted amongst cores" -> "split between cores"


1. Provide an `index_map` function that, given the iteration indices, return *the slice* of the input data that shall be passed in.

1. On `BlockSpec`, wrap the corresponding dimension with `pl.BoundedSlice`, indicating the `index_map` function would return a slice instead of a iteration index on that dimension.
Collaborator

Note that we don't strictly need BoundedSlice here. We can also use half of the original block size and offset the index map. Also, emit_pipeline with core_axis_name automatically partitions the grid.
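A hedged sketch of that second route, taking the `core_axis_name` argument from this comment (the exact `emit_pipeline` signature is not verified against the guide); `inner`, the sizes, and the HBM ref names are illustrative, and the call sits inside the core-local kernel body using the guide's `pl`/`pltpu` imports:

total_rows, total_cols = 512, 1024   # illustrative sizes

def inner(x_vmem, o_vmem):
  o_vmem[...] = x_vmem[...] + 1

pipeline = pltpu.emit_pipeline(
    inner,
    grid=(total_rows // 8, total_cols // 128),
    in_specs=[pl.BlockSpec((8, 128), lambda i, j: (i, j))],
    out_specs=[pl.BlockSpec((8, 128), lambda i, j: (i, j))],
    core_axis_name="core",  # emit_pipeline partitions the grid across cores
)
pipeline(x_hbm_ref, o_hbm_ref)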

Collaborator

I think for this, the first example should be emit_pipeline with core_axis_name.

The second one could be a more custom splitting across cores.


## Mapping over SparseCores

TPU v4 and above includes a [SparseCore](https://openxla.org/xla/sparsecore), which is specialized in sparse memory access and operations. This guide will not dive into the capabilities of SparseCore, but rather show how to run a program on SparseCore with same semantics and minimal changes from the TensorCore code.
Collaborator

Note that TPU v5e does not have SparseCore.


## Environment setup

Modern accelerators often have multiple cores under a device. For TPU chips higher than v4, every JAX device by default contains two TensorCores (aka. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). They also contain a [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore), consisting of many subcores.
Collaborator

Should we say "SparseCores", since every chip has at least 2?

from functools import partial

import jax
from jax.sharding import NamedSharding, PartitionSpec as P
Collaborator

Nit: I think we have jax.P now.


**Communication amongst cores**

Before making a inter-core communication, you may need to do a global barrier signal (`pltpu.semaphore_signal`), to make sure all the destination semaphores have been properly initialized.
Collaborator

@superbobry superbobry Nov 3, 2025

Is it "may" or "must"? If "may", it will be useful to explain when this is in fact required.


## Mapping over SparseCores

TPU v4 and above includes a [SparseCore](https://openxla.org/xla/sparsecore), which is specialized in sparse memory access and operations. This guide will not dive into the capabilities of SparseCore, but rather show how to run a program on SparseCore with same semantics and minimal changes from the TensorCore code.
Collaborator

Nit: should we say "SparseCores" here as well, instead of "a SparseCore" to highlight that it's >1 per chip.


sc_mesh = plsc.VectorSubcoreMesh(
    core_axis_name="core", subcore_axis_name="subcore",
    num_cores=sc_info.num_cores
Collaborator

I wonder if we should default to sc_info.num_cores instead of requiring users to always query SC info?


1. You need to split the work amongst all subcores, so a few lines to compute the specific slice for each subcore.

1. SparseCore register computation allows smaller slices (`4x16` max for int32), so you need nested loops to iterate the slice during computation phase.
Collaborator

Nit: 2.

Also, should we reference sc_info.num_lanes here and have a single loop reading out (num_lanes,) vectors? 4x16 relies on unrolling in the SC compiler, which only really works for a handful of datatypes atm.
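A hedged sketch of that single-loop shape, with hypothetical names (`subcore_slc_size` for the number of elements owned by this subcore, `x_vmem`/`o_vmem` for the subcore-local refs):

num_lanes = sc_info.num_lanes             # vector width of one subcore
for k in range(subcore_slc_size // num_lanes):
  sl = pl.ds(k * num_lanes, num_lanes)    # one (num_lanes,)-wide vector
  o_vmem[sl] = x_vmem[sl] + 1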
