Commit 30eecf6

superbobry authored and Google-ML-Automation committed
[pallas:triton] Removed the Triton prefix from TritonCompilerParams
All Triton-specific APIs are always used qualified, e.g. `plgpu.TritonCompilerParams`, so the prefix is redundant.

PiperOrigin-RevId: 764276165
1 parent fd28b2f commit 30eecf6

13 files changed (+52 −28 lines)
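The commit message's point about qualified use, shown concretely: a minimal sketch of a call site after the rename (the kernel, shapes, and parameter values here are illustrative, not taken from this commit):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import triton as plgpu

    def double_kernel(x_ref, o_ref):
      # Trivial elementwise kernel, present only to exercise compiler_params.
      o_ref[...] = x_ref[...] * 2

    f = pl.pallas_call(
        double_kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
        # Always written qualified, so plgpu.CompilerParams stays unambiguous.
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=2),
    )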

docs/jax.experimental.pallas.triton.rst

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@ Classes
 .. autosummary::
   :toctree: _autosummary
 
-  TritonCompilerParams
+  CompilerParams
 
 Functions
 ---------
@@ -19,4 +19,4 @@ Functions
 
   approx_tanh
   debug_barrier
-  elementwise_inline_asm
+  elementwise_inline_asm

docs/pallas/CHANGELOG.md

Lines changed: 8 additions & 0 deletions
@@ -13,6 +13,14 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## Unreleased
 
+* Deprecations
+
+  * {class}`jax.experimental.pallas.triton.TritonCompilerParams` has been
+    renamed to {class}`jax.experimental.pallas.triton.CompilerParams`. The
+    old name is deprecated and will be removed in a future release.
+
+## Released with jax 0.6.1
+
 * Removals
 
   * Removed previously deprecated {mod}`jax.experimental.pallas.gpu`. To use

jax/_src/pallas/pallas_call.py

Lines changed: 4 additions & 4 deletions
@@ -1499,7 +1499,7 @@ def pallas_call(
     interpret: Any = False,
     name: str | None = None,
     compiler_params: (
-        Mapping[Backend, CompilerParams] | CompilerParams | None
+        Mapping[Backend, "CompilerParams"] | "CompilerParams" | None
     ) = None,
     cost_estimate: CostEstimate | None = None,
     backend: Backend | None = None,
@@ -1550,7 +1550,7 @@
     compiler_params: Optional compiler parameters. The value should either be a
       backend-specific dataclass
       (:class:`jax.experimental.pallas.tpu.TPUCompilerParams`,
-      :class:`jax.experimental.pallas.triton.TritonCompilerParams`,
+      :class:`jax.experimental.pallas.triton.CompilerParams`,
       :class:`jax.experimental.pallas.mosaic_gpu.CompilerParams`) or a dict
       mapping backend name to the corresponding platform-specific dataclass.
     backend: Optional string literal one of ``"mosaic_tpu"``, ``"triton"`` or
@@ -1600,13 +1600,13 @@ def _normalize_compiler_params(
 ) -> Mapping[Backend, CompilerParams]:
   if compiler_params is None:
     return {}
-  if isinstance(compiler_params, pallas_core.CompilerParams):
+  if isinstance(compiler_params, CompilerParams):
     compiler_params = {compiler_params.BACKEND: compiler_params}
   assert isinstance(compiler_params, Mapping)
   for backend, params in compiler_params.items():
     if backend not in ["mosaic_tpu", "mosaic_gpu", "triton"]:
       raise ValueError(f"Unknown backend in compiler_params: {backend}")
-    if not isinstance(params, pallas_core.CompilerParams):
+    if not isinstance(params, CompilerParams):
       raise ValueError(
           f"Unexpected compiler_params for backend {backend}: {params}"
       )
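`_normalize_compiler_params` above accepts either a bare dataclass or a backend-keyed mapping; a minimal sketch of both forms (kernel and shapes are illustrative, not from this commit):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import triton as plgpu

    def copy_kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]

    out = jax.ShapeDtypeStruct((8, 128), jnp.float32)
    params = plgpu.CompilerParams(num_warps=4, num_stages=2)

    # Form 1: a bare dataclass, normalized internally to {params.BACKEND: params}.
    f = pl.pallas_call(copy_kernel, out_shape=out, compiler_params=params)

    # Form 2: an explicit mapping; keys outside
    # {"mosaic_tpu", "mosaic_gpu", "triton"} raise ValueError.
    g = pl.pallas_call(copy_kernel, out_shape=out, compiler_params={"triton": params})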

jax/_src/pallas/triton/core.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 from jax._src.pallas import core as pallas_core
 
 @dataclasses.dataclass(frozen=True)
-class TritonCompilerParams(pallas_core.CompilerParams):
+class CompilerParams(pallas_core.CompilerParams):
   """Compiler parameters for Triton.
 
   Attributes:
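Apart from the name, the class is unchanged; roughly what it looks like (a standalone sketch listing only the two fields exercised elsewhere in this diff, plus the BACKEND key consumed by _normalize_compiler_params — the real class subclasses pallas_core.CompilerParams and may carry further fields):

    import dataclasses
    from typing import ClassVar

    @dataclasses.dataclass(frozen=True)
    class CompilerParams:
      """Compiler parameters for Triton (abbreviated sketch)."""
      BACKEND: ClassVar[str] = "triton"  # mapping key used during normalization
      num_warps: int | None = None       # warps per block; the lowering defaults to 4
      num_stages: int | None = None      # software-pipelining stages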

jax/_src/pallas/triton/pallas_call_registration.py

Lines changed: 2 additions & 2 deletions
@@ -72,9 +72,9 @@ def pallas_call_lowering(
   [lowering_platform] = ctx.platforms or ctx.module_context.platforms
 
   if "triton" in compiler_params:
-    params = cast(triton_core.TritonCompilerParams, compiler_params["triton"])
+    params = cast(triton_core.CompilerParams, compiler_params["triton"])
   else:
-    params = triton_core.TritonCompilerParams()
+    params = triton_core.CompilerParams()
   num_warps = 4 if params.num_warps is None else params.num_warps
   num_stages = params.num_stages
   if num_stages is None:
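Unset fields fall back to defaults at lowering time rather than erroring; a small illustration mirroring the defaulting line above (variable names are mine, not the library's):

    from jax.experimental.pallas import triton as plgpu

    implicit = plgpu.CompilerParams()  # both fields default to None
    num_warps = 4 if implicit.num_warps is None else implicit.num_warps
    assert num_warps == 4  # the Triton lowering picks 4 warps when unset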

jax/experimental/pallas/ops/gpu/attention.py

Lines changed: 3 additions & 3 deletions
@@ -288,7 +288,7 @@ def mha(
       grid=grid_,
       in_specs=in_specs,
       out_specs=out_specs,
-      compiler_params=plgpu.TritonCompilerParams(
+      compiler_params=plgpu.CompilerParams(
           num_warps=num_warps_, num_stages=num_stages),
       out_shape=out_shape,
       debug=debug,
@@ -351,7 +351,7 @@ def _preprocess_backward(out, do, lse, block_q: int,
                    lambda i, j, k: (j, i, k, 0)),
       ],
       out_specs=pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)),
-      compiler_params=plgpu.TritonCompilerParams(num_warps=4, num_stages=3),
+      compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=3),
       out_shape=out_shape,
       debug=debug,
       interpret=interpret,
@@ -634,7 +634,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
       name="mha_backward",
       debug=debug,
       interpret=interpret,
-      compiler_params=plgpu.TritonCompilerParams(
+      compiler_params=plgpu.CompilerParams(
           num_warps=num_warps_, num_stages=2
       ),
   )(q, k, v, segment_ids, out, do, lse, delta)

jax/experimental/pallas/ops/gpu/decode_attention.py

Lines changed: 1 addition & 1 deletion
@@ -193,7 +193,7 @@ def decode_attn_unbatched(
           pl.BlockSpec((None, block_h), lambda i, j: (j, i)),  # l
           pl.BlockSpec((None, block_h), lambda i, j: (j, i)),  # m
       ],
-      compiler_params=plgpu.TritonCompilerParams(
+      compiler_params=plgpu.CompilerParams(
          num_warps=num_warps_, num_stages=num_stages
      ),
      out_shape=[

jax/experimental/pallas/ops/gpu/layer_norm.py

Lines changed: 4 additions & 4 deletions
@@ -94,7 +94,7 @@ def layer_norm_forward(
   ]
   method = pl.pallas_call(
       kernel,
-      compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
+      compiler_params=plgpu.CompilerParams(num_warps=num_warps),
       grid=(),
       out_shape=out_shape,
       debug=False,
@@ -215,7 +215,7 @@ def layer_norm_backward(
   out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
   method = pl.pallas_call(
       kernel,
-      compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
+      compiler_params=plgpu.CompilerParams(num_warps=num_warps),
       grid=(),
       out_shape=out_shape_dx,
       debug=False,
@@ -247,7 +247,7 @@
   grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
   method = pl.pallas_call(
       kernel,
-      compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
+      compiler_params=plgpu.CompilerParams(num_warps=num_warps),
       grid=grid_,
       out_shape=out_shape_dwbias,
       debug=False,
@@ -283,7 +283,7 @@ def layer_norm(
   out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
   method = pl.pallas_call(
       kernel,
-      compiler_params=plgpu.TritonCompilerParams(
+      compiler_params=plgpu.CompilerParams(
           num_warps=num_warps, num_stages=num_stages),
       grid=(),
       out_shape=out_shape,

jax/experimental/pallas/ops/gpu/paged_attention.py

Lines changed: 1 addition & 1 deletion
@@ -222,7 +222,7 @@ def paged_attention_unbatched(
       ],
       debug=debug,
       interpret=interpret,
-      compiler_params=plgpu.TritonCompilerParams(
+      compiler_params=plgpu.CompilerParams(
          num_warps=num_warps, num_stages=num_stages
      ),
      name=f"paged_attention_{block_h=}_{pages_per_compute_block=}",

jax/experimental/pallas/ops/gpu/rms_norm.py

Lines changed: 4 additions & 4 deletions
@@ -82,7 +82,7 @@ def rms_norm_forward(
   ]
   method = pl.pallas_call(
       kernel,
-      compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
+      compiler_params=plgpu.CompilerParams(num_warps=num_warps),
       grid=(),
       out_shape=out_shape,
       debug=False,
@@ -196,7 +196,7 @@ def rms_norm_backward(
   out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
   method = pl.pallas_call(
       kernel,
-      compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
+      compiler_params=plgpu.CompilerParams(num_warps=num_warps),
       grid=(),
       out_shape=out_shape_dx,
       debug=False,
@@ -228,7 +228,7 @@
   grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
   method = pl.pallas_call(
       kernel,
-      compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
+      compiler_params=plgpu.CompilerParams(num_warps=num_warps),
       grid=grid_,
       out_shape=out_shape_dwbias,
       debug=False,
@@ -264,7 +264,7 @@ def rms_norm(
   out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
   method = pl.pallas_call(
       kernel,
-      compiler_params=plgpu.TritonCompilerParams(
+      compiler_params=plgpu.CompilerParams(
           num_warps=num_warps, num_stages=num_stages
       ),
       grid=(),

jax/experimental/pallas/ops/gpu/softmax.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def softmax(
   kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row)
   f = pl.pallas_call(
       kernel,
-      compiler_params=plgpu.TritonCompilerParams(
+      compiler_params=plgpu.CompilerParams(
          num_warps=num_warps, num_stages=1),
      grid=(),
      out_shape=out_shape,

jax/experimental/pallas/triton.py

Lines changed: 17 additions & 1 deletion
@@ -14,7 +14,23 @@
 
 """Triton-specific Pallas APIs."""
 
-from jax._src.pallas.triton.core import TritonCompilerParams as TritonCompilerParams
+from jax._src.pallas.triton.core import CompilerParams as CompilerParams
 from jax._src.pallas.triton.primitives import approx_tanh as approx_tanh
 from jax._src.pallas.triton.primitives import debug_barrier as debug_barrier
 from jax._src.pallas.triton.primitives import elementwise_inline_asm as elementwise_inline_asm
+
+import typing as _typing  # pylint: disable=g-import-not-at-top
+if _typing.TYPE_CHECKING:
+  TritonCompilerParams = CompilerParams
+else:
+  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
+  _deprecations = {
+      # Deprecated on May 27th 2025.
+      "TritonCompilerParams": (
+          "TritonCompilerParams is deprecated, use CompilerParams instead.",
+          CompilerParams,
+      ),
+  }
+  __getattr__ = _deprecation_getattr(__name__, _deprecations)
+  del _deprecation_getattr
+del _typing
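What this shim looks like from the outside; a minimal sketch (it assumes jax._src.deprecations emits a DeprecationWarning, with the message text coming from the mapping above):

    import warnings
    from jax.experimental.pallas import triton as plgpu

    with warnings.catch_warnings(record=True) as caught:
      warnings.simplefilter("always")
      old_cls = plgpu.TritonCompilerParams  # resolved via the module __getattr__

    assert old_cls is plgpu.CompilerParams  # the old name aliases the new class
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)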

tests/pallas/ops_test.py

Lines changed: 4 additions & 4 deletions
@@ -1646,7 +1646,7 @@ def kernel(x_ref, o_ref):
 
   @unittest.skipIf(
       sys.platform == "win32",
-      "plgpu_triton.TritonCompilerParams unavailable on Windows",
+      "plgpu_triton.CompilerParams unavailable on Windows",
   )
   def test_debug_print(self):
     self.skip_if_mosaic_gpu()
@@ -1661,7 +1661,7 @@ def test_debug_print(self):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
-        compiler_params=plgpu_triton.TritonCompilerParams(
+        compiler_params=plgpu_triton.CompilerParams(
             num_warps=1, num_stages=1
         ),
     )
@@ -1677,7 +1677,7 @@ def kernel(x_ref, o_ref):
 
   @unittest.skipIf(
       sys.platform == "win32",
-      "plgpu_triton.TritonCompilerParams unavailable on Windows",
+      "plgpu_triton.CompilerParams unavailable on Windows",
   )
   def test_debug_print_with_values(self):
     if jtu.test_device_matches(["tpu"]):
@@ -1690,7 +1690,7 @@ def test_debug_print_with_values(self):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
-        compiler_params=plgpu_triton.TritonCompilerParams(
+        compiler_params=plgpu_triton.CompilerParams(
            num_warps=1, num_stages=1
        ),
    )
