[pallas:triton] Removed the Triton prefix from TritonCompilerParams #29023

Merged: 1 commit, May 28, 2025

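This change renames the Triton compiler-parameters dataclass from `TritonCompilerParams` to `CompilerParams`, matching the Mosaic GPU spelling, and keeps the old name as a deprecated alias (see the `jax/experimental/pallas/triton.py` hunk below). A minimal before/after sketch of a call site; the kernel, shapes, and parameter values are illustrative and not taken from this PR, and running it as written needs a GPU where Pallas lowers through Triton:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu


def double_kernel(x_ref, o_ref):
  # Trivial elementwise kernel, just to have a call site.
  o_ref[...] = x_ref[...] * 2

x = jnp.arange(8, dtype=jnp.float32)

# Before this PR (still works, but now emits a DeprecationWarning):
#   compiler_params=plgpu.TritonCompilerParams(num_warps=4, num_stages=2)

out = pl.pallas_call(
    double_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=2),
)(x)
```
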
4 changes: 2 additions & 2 deletions docs/jax.experimental.pallas.triton.rst
@@ -9,7 +9,7 @@ Classes
.. autosummary::
:toctree: _autosummary

TritonCompilerParams
CompilerParams

Functions
---------
@@ -19,4 +19,4 @@ Functions

approx_tanh
debug_barrier
elementwise_inline_asm
elementwise_inline_asm
8 changes: 8 additions & 0 deletions docs/pallas/CHANGELOG.md
@@ -13,6 +13,14 @@ Remember to align the itemized text with the first line of an item within a list

## Unreleased

* Deprecations

* {class}`jax.experimental.pallas.triton.TritonCompilerParams` has been
renamed to {class}`jax.experimental.pallas.triton.CompilerParams`. The
old name is deprecated and will be removed in a future release.

## Released with jax 0.6.1

* Removals

* Removed previously deprecated {mod}`jax.experimental.pallas.gpu`. To use
8 changes: 4 additions & 4 deletions jax/_src/pallas/pallas_call.py
@@ -1499,7 +1499,7 @@ def pallas_call(
interpret: Any = False,
name: str | None = None,
compiler_params: (
Mapping[Backend, CompilerParams] | CompilerParams | None
Mapping[Backend, "CompilerParams"] | "CompilerParams" | None
) = None,
cost_estimate: CostEstimate | None = None,
backend: Backend | None = None,
@@ -1550,7 +1550,7 @@ def pallas_call(
compiler_params: Optional compiler parameters. The value should either be a
backend-specific dataclass
(:class:`jax.experimental.pallas.tpu.TPUCompilerParams`,
:class:`jax.experimental.pallas.triton.TritonCompilerParams`,
:class:`jax.experimental.pallas.triton.CompilerParams`,
:class:`jax.experimental.pallas.mosaic_gpu.CompilerParams`) or a dict
mapping backend name to the corresponding platform-specific dataclass.
backend: Optional string literal one of ``"mosaic_tpu"``, ``"triton"`` or
@@ -1600,13 +1600,13 @@ def _normalize_compiler_params(
) -> Mapping[Backend, CompilerParams]:
if compiler_params is None:
return {}
if isinstance(compiler_params, pallas_core.CompilerParams):
if isinstance(compiler_params, CompilerParams):
compiler_params = {compiler_params.BACKEND: compiler_params}
assert isinstance(compiler_params, Mapping)
for backend, params in compiler_params.items():
if backend not in ["mosaic_tpu", "mosaic_gpu", "triton"]:
raise ValueError(f"Unknown backend in compiler_params: {backend}")
if not isinstance(params, pallas_core.CompilerParams):
if not isinstance(params, CompilerParams):
raise ValueError(
f"Unexpected compiler_params for backend {backend}: {params}"
)
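The `_normalize_compiler_params` hunk above accepts either a single backend dataclass or a mapping keyed by backend name (`"mosaic_tpu"`, `"mosaic_gpu"`, or `"triton"`); the dataclass form is folded into a one-entry mapping keyed by its `BACKEND` attribute. A rough sketch of the two accepted spellings (values are illustrative):

```python
from jax.experimental.pallas import triton as plgpu

# Form 1: a single backend-specific dataclass.
params = plgpu.CompilerParams(num_warps=8)

# Form 2: a dict mapping backend name to its dataclass, useful when a
# single pallas_call may lower to different backends.
params_by_backend = {"triton": plgpu.CompilerParams(num_warps=8)}

# Either value can be passed as pl.pallas_call(..., compiler_params=...);
# internally the dataclass form is normalized to {params.BACKEND: params},
# i.e. {"triton": params} here.
```
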
2 changes: 1 addition & 1 deletion jax/_src/pallas/triton/core.py
@@ -21,7 +21,7 @@
from jax._src.pallas import core as pallas_core

@dataclasses.dataclass(frozen=True)
class TritonCompilerParams(pallas_core.CompilerParams):
class CompilerParams(pallas_core.CompilerParams):
"""Compiler parameters for Triton.

Attributes:
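`CompilerParams` remains a frozen dataclass, so instances are immutable; a small illustrative check (field values are arbitrary):

```python
import dataclasses
from jax.experimental.pallas import triton as plgpu

params = plgpu.CompilerParams(num_warps=4, num_stages=2)
assert dataclasses.is_dataclass(params)

try:
  params.num_warps = 8  # frozen=True makes attribute assignment fail
except dataclasses.FrozenInstanceError:
  pass
```
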
4 changes: 2 additions & 2 deletions jax/_src/pallas/triton/pallas_call_registration.py
@@ -72,9 +72,9 @@ def pallas_call_lowering(
[lowering_platform] = ctx.platforms or ctx.module_context.platforms

if "triton" in compiler_params:
params = cast(triton_core.TritonCompilerParams, compiler_params["triton"])
params = cast(triton_core.CompilerParams, compiler_params["triton"])
else:
params = triton_core.TritonCompilerParams()
params = triton_core.CompilerParams()
num_warps = 4 if params.num_warps is None else params.num_warps
num_stages = params.num_stages
if num_stages is None:
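The lowering hunk above also shows the defaulting behaviour: when no `"triton"` entry is present in `compiler_params`, an empty `CompilerParams()` is used and `num_warps` falls back to 4 (`num_stages` is resolved just below the hunk). Illustrative only, assuming unset fields default to `None` as the check in the hunk suggests:

```python
from jax.experimental.pallas import triton as plgpu

params = plgpu.CompilerParams()  # no fields set
num_warps = 4 if params.num_warps is None else params.num_warps
assert num_warps == 4
```
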
6 changes: 3 additions & 3 deletions jax/experimental/pallas/ops/gpu/attention.py
@@ -288,7 +288,7 @@ def mha(
grid=grid_,
in_specs=in_specs,
out_specs=out_specs,
compiler_params=plgpu.TritonCompilerParams(
compiler_params=plgpu.CompilerParams(
num_warps=num_warps_, num_stages=num_stages),
out_shape=out_shape,
debug=debug,
@@ -351,7 +351,7 @@ def _preprocess_backward(out, do, lse, block_q: int,
lambda i, j, k: (j, i, k, 0)),
],
out_specs=pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)),
compiler_params=plgpu.TritonCompilerParams(num_warps=4, num_stages=3),
compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=3),
out_shape=out_shape,
debug=debug,
interpret=interpret,
@@ -634,7 +634,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
name="mha_backward",
debug=debug,
interpret=interpret,
compiler_params=plgpu.TritonCompilerParams(
compiler_params=plgpu.CompilerParams(
num_warps=num_warps_, num_stages=2
),
)(q, k, v, segment_ids, out, do, lse, delta)
2 changes: 1 addition & 1 deletion jax/experimental/pallas/ops/gpu/decode_attention.py
@@ -193,7 +193,7 @@ def decode_attn_unbatched(
pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l
pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m
],
compiler_params=plgpu.TritonCompilerParams(
compiler_params=plgpu.CompilerParams(
num_warps=num_warps_, num_stages=num_stages
),
out_shape=[
8 changes: 4 additions & 4 deletions jax/experimental/pallas/ops/gpu/layer_norm.py
@@ -94,7 +94,7 @@ def layer_norm_forward(
]
method = pl.pallas_call(
kernel,
compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
grid=(),
out_shape=out_shape,
debug=False,
@@ -215,7 +215,7 @@ def layer_norm_backward(
out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(
kernel,
compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
grid=(),
out_shape=out_shape_dx,
debug=False,
@@ -247,7 +247,7 @@ def layer_norm_backward(
grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
method = pl.pallas_call(
kernel,
compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
grid=grid_,
out_shape=out_shape_dwbias,
debug=False,
@@ -283,7 +283,7 @@ def layer_norm(
out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(
kernel,
compiler_params=plgpu.TritonCompilerParams(
compiler_params=plgpu.CompilerParams(
num_warps=num_warps, num_stages=num_stages),
grid=(),
out_shape=out_shape,
2 changes: 1 addition & 1 deletion jax/experimental/pallas/ops/gpu/paged_attention.py
@@ -222,7 +222,7 @@ def paged_attention_unbatched(
],
debug=debug,
interpret=interpret,
compiler_params=plgpu.TritonCompilerParams(
compiler_params=plgpu.CompilerParams(
num_warps=num_warps, num_stages=num_stages
),
name=f"paged_attention_{block_h=}_{pages_per_compute_block=}",
8 changes: 4 additions & 4 deletions jax/experimental/pallas/ops/gpu/rms_norm.py
@@ -82,7 +82,7 @@ def rms_norm_forward(
]
method = pl.pallas_call(
kernel,
compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
grid=(),
out_shape=out_shape,
debug=False,
@@ -196,7 +196,7 @@ def rms_norm_backward(
out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(
kernel,
compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
grid=(),
out_shape=out_shape_dx,
debug=False,
@@ -228,7 +228,7 @@ def rms_norm_backward(
grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
method = pl.pallas_call(
kernel,
compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
grid=grid_,
out_shape=out_shape_dwbias,
debug=False,
@@ -264,7 +264,7 @@ def rms_norm(
out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(
kernel,
compiler_params=plgpu.TritonCompilerParams(
compiler_params=plgpu.CompilerParams(
num_warps=num_warps, num_stages=num_stages
),
grid=(),
2 changes: 1 addition & 1 deletion jax/experimental/pallas/ops/gpu/softmax.py
@@ -80,7 +80,7 @@ def softmax(
kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row)
f = pl.pallas_call(
kernel,
compiler_params=plgpu.TritonCompilerParams(
compiler_params=plgpu.CompilerParams(
num_warps=num_warps, num_stages=1),
grid=(),
out_shape=out_shape,
18 changes: 17 additions & 1 deletion jax/experimental/pallas/triton.py
@@ -14,7 +14,7 @@

"""Triton-specific Pallas APIs."""

from jax._src.pallas.triton.core import TritonCompilerParams as TritonCompilerParams
from jax._src.pallas.triton.core import CompilerParams as CompilerParams
from jax._src.pallas.triton.primitives import approx_tanh as approx_tanh
from jax._src.pallas.triton.primitives import debug_barrier as debug_barrier
from jax._src.pallas.triton.primitives import elementwise_inline_asm as elementwise_inline_asm

import typing as _typing # pylint: disable=g-import-not-at-top
if _typing.TYPE_CHECKING:
TritonCompilerParams = CompilerParams
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
_deprecations = {
# Deprecated on May 27th 2025.
"TritonCompilerParams": (
"TritonCompilerParams is deprecated, use CompilerParams instead.",
CompilerParams,
),
}
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
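
The shim above keeps the old name importable: under `typing.TYPE_CHECKING` it is a plain alias, while at runtime a module-level `__getattr__` installed by `deprecation_getattr` resolves `TritonCompilerParams` to `CompilerParams` and emits a `DeprecationWarning`. A rough sketch of what a caller observes:

```python
import warnings
from jax.experimental.pallas import triton as plgpu

with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  cls = plgpu.TritonCompilerParams  # triggers the module-level __getattr__

assert cls is plgpu.CompilerParams
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```
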
8 changes: 4 additions & 4 deletions tests/pallas/ops_test.py
@@ -1646,7 +1646,7 @@ def kernel(x_ref, o_ref):

@unittest.skipIf(
sys.platform == "win32",
"plgpu_triton.TritonCompilerParams unavailable on Windows",
"plgpu_triton.CompilerParams unavailable on Windows",
)
def test_debug_print(self):
self.skip_if_mosaic_gpu()
@@ -1661,7 +1661,7 @@
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
compiler_params=plgpu_triton.TritonCompilerParams(
compiler_params=plgpu_triton.CompilerParams(
num_warps=1, num_stages=1
),
)
@@ -1677,7 +1677,7 @@ def kernel(x_ref, o_ref):

@unittest.skipIf(
sys.platform == "win32",
"plgpu_triton.TritonCompilerParams unavailable on Windows",
"plgpu_triton.CompilerParams unavailable on Windows",
)
def test_debug_print_with_values(self):
if jtu.test_device_matches(["tpu"]):
@@ -1690,7 +1690,7 @@ def test_debug_print_with_values(self):
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
compiler_params=plgpu_triton.TritonCompilerParams(
compiler_params=plgpu_triton.CompilerParams(
num_warps=1, num_stages=1
),
)