From 252032a3683b193091911cfc6d482545ef198a0f Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 5 Aug 2024 04:23:15 -0700 Subject: [PATCH] [pallas] Improve error and debugging messages with source locations Document the `name` argument to `pallas_call` and supplement it with source location information for the kernel function. Pass all this as the `name_and_src_info` parameter to the `pallas_call_p` primitive. Added some more information to the `if debug` prints. Set the MLIR module names so that the debug dumps are named properly. Changed `import pallas.core as pl_core` to `... as pallas_core` for consistency in a couple of modules. PiperOrigin-RevId: 659506675 --- jax/_src/pallas/core.py | 46 +++++++++++- jax/_src/pallas/mosaic/lowering.py | 74 ++++++++++--------- .../pallas/mosaic/pallas_call_registration.py | 15 +++- jax/_src/pallas/mosaic_gpu/lowering.py | 14 ++-- .../mosaic_gpu/pallas_call_registration.py | 7 +- jax/_src/pallas/pallas_call.py | 63 ++++++++-------- jax/_src/pallas/triton/lowering.py | 10 ++- .../pallas/triton/pallas_call_registration.py | 9 ++- jax/experimental/mosaic/gpu/__init__.py | 7 +- tests/pallas/pallas_test.py | 43 ++++++++++- tests/pallas/pallas_vmap_test.py | 2 +- tests/pallas/tpu_pallas_test.py | 16 ++-- 12 files changed, 205 insertions(+), 101 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 12b3f8aa3078..08c1c20e344b 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -36,6 +36,7 @@ from jax._src import state from jax._src import tree_util from jax._src import util +from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.state import discharge as state_discharge import jax.numpy as jnp @@ -56,9 +57,47 @@ def __repr__(self): Grid = Union[NamedGrid, TupleGrid] StaticGrid = tuple[int, ...] GridMappingGrid = tuple[int | DynamicGridDim, ...] -SrcInfoStr = str # function_name at filename:linenumber OriginStr = str # The origin of a block spec, e.g. input[2]["field"] + +@dataclasses.dataclass(frozen=True) +class NameAndSrcInfo: + #: The name of the pallas_call or the name of the kernel function. + name: str + #: The source info, and the name of the kernel function if not in `name`. + src_info: str + + def __str__(self): + return f"{self.name}{' ' if self.src_info else ''}{self.src_info}" + __repr__ = __str__ + + replace = dataclasses.replace + + + @staticmethod + def from_pallas_call(pallas_call_name: str | None, + src_info: str | None) -> NameAndSrcInfo: + """Formats the name and the source info. + + Args: + pallas_call_name: The `name` argument to pallas_call. + src_info: The result of `api_util.fun_sourceinfo(kernel)`, in the form + "{function_name} at {file_name}:{line_number}". + """ + if pallas_call_name is not None: + pallas_call_name = mlir._module_name_regex.sub("_", pallas_call_name) + if src_info is None: + return NameAndSrcInfo( + "unknown" if pallas_call_name is None else pallas_call_name, + "") + if pallas_call_name is not None: + return NameAndSrcInfo(pallas_call_name, + f"for kernel function {src_info}") + src_info_parts = src_info.split(" ") + return NameAndSrcInfo(src_info_parts[0], + " ".join(src_info_parts[1:])) + + # Pytrees of jax.ShapeDtypeStruct ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...] @@ -268,7 +307,7 @@ class BlockMapping: block_shape: tuple[Mapped | int, ...]
block_aval: AbstractMemoryRef # The block ref aval index_map_jaxpr: jax_core.ClosedJaxpr - index_map_src_info: SrcInfoStr + index_map_src_info: NameAndSrcInfo indexing_mode: IndexingMode array_shape_dtype: jax.ShapeDtypeStruct # The whole array origin: OriginStr @@ -534,7 +573,8 @@ def _convert_block_spec_to_block_mapping( lu.wrap_init(index_map_func), index_map_tree) debug = pe.debug_info(index_map_func, index_map_tree, index_map_out_tree_thunk, False, "pallas_call index_map") - index_map_src_info = debug.func_src_info or "" + index_map_src_info = NameAndSrcInfo.from_pallas_call(None, + debug.func_src_info) with tracing_grid_env(grid, mapped_dims): jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun, index_map_avals, diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index c77189ac8eb5..e1bd21b9b69a 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -49,7 +49,7 @@ from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector from jax._src.pallas import pallas_call -from jax._src.pallas import core as pl_core +from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core @@ -71,7 +71,7 @@ NDIndexer = indexing.NDIndexer TPUMemorySpace = tpu_core.TPUMemorySpace -MemorySpace = pl_core.MemorySpace | TPUMemorySpace +MemorySpace = pallas_core.MemorySpace | TPUMemorySpace VMEM = tpu_core.TPUMemorySpace.VMEM SMEM = tpu_core.TPUMemorySpace.SMEM # Booleans are stored as the following type in memrefs. @@ -105,7 +105,7 @@ class LoweringContext: grid_names: tuple[Hashable, ...] | None mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions. user_grid_indices: Sequence[ir.Value] | None - block_shapes: list[tuple[int | pl_core.Mapped, ...]] + block_shapes: list[tuple[int | pallas_core.Mapped, ...]] name_stack: source_info_util.NameStack mesh_context: MeshContext | None replace = dataclasses.replace @@ -136,7 +136,7 @@ class LoweringRuleContext: lowering_context: LoweringContext avals_in: Sequence[jax_core.AbstractValue] avals_out: Sequence[jax_core.AbstractValue] - block_shapes: Sequence[tuple[int | pl_core.Mapped, ...] | None] + block_shapes: Sequence[tuple[int | pallas_core.Mapped, ...] 
| None] replace = dataclasses.replace @@ -145,9 +145,9 @@ def _memory_space_to_tpu_memspace(memory_space: MemorySpace | None ) -> ir.Attribute: if memory_space is None: memory_space = VMEM - elif memory_space == pl_core.MemorySpace.ERROR: + elif memory_space == pallas_core.MemorySpace.ERROR: memory_space = SMEM - elif memory_space == pl_core.MemorySpace.INDEX: + elif memory_space == pallas_core.MemorySpace.INDEX: memory_space = SMEM return ir.Attribute.parse(f"#tpu.memory_space<{memory_space}>") @@ -252,10 +252,10 @@ def _get_aval_physical_dtype_shape(aval): def _get_arg_type( aval, - block_mapping: pl_core.BlockMapping | None, + block_mapping: pallas_core.BlockMapping | None, ): memory_space = None - if isinstance(aval, pl_core.AbstractMemoryRef): + if isinstance(aval, pallas_core.AbstractMemoryRef): memory_space = aval.memory_space # We assume unannotated memory refs are in VMEM if memory_space is None: @@ -265,7 +265,7 @@ def _get_arg_type( # TODO(necula): clean this None block_mapping if block_mapping is None: return aval_to_ir_type(aval, memory_space=memory_space), aval.shape - shape = tuple(1 if b is pl_core.mapped else b for b in block_mapping.block_shape) + shape = tuple(1 if b is pallas_core.mapped else b for b in block_mapping.block_shape) return ( aval_to_ir_type(aval, shape=shape, memory_space=memory_space), block_mapping.block_shape, @@ -277,7 +277,7 @@ class MosaicGridMapping: grid: tuple[int, ...] | None grid_names: tuple[Hashable, ...] | None jaxpr: jax_core.Jaxpr - block_mappings: tuple[pl_core.BlockMapping | None, ...] + block_mappings: tuple[pallas_core.BlockMapping | None, ...] mapped_dims: tuple[int, ...] scalar_prefetch_types: tuple[ir.Type, ...] operand_types: tuple[ir.Type, ...] @@ -289,7 +289,7 @@ class MosaicGridMapping: mesh_info: MeshInfo | None get_grid_indices: Callable | None - def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping, + def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, dimension_semantics: tuple[str, ...] | None, mesh: mesh_lib.Mesh | None): self.grid = grid_mapping.grid @@ -340,7 +340,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping, for aval in scratch_avals ) self.grid_types, _ = unzip2([ - _get_arg_type(pl_core.index_map_grid_aval, None) + _get_arg_type(pallas_core.index_map_grid_aval, None) for _ in range(len(self.grid)) ]) self._prepare_mesh_info(mesh) @@ -418,9 +418,11 @@ class MeshInfo: def lower_jaxpr_to_module( lowering_context: mlir.LoweringRuleContext, ctx: ir.Context, - grid_mapping: pl_core.GridMapping, + grid_mapping: pallas_core.GridMapping, jaxpr: jax_core.Jaxpr, + *, dimension_semantics: tuple[str | None, ...] | None, + name_and_src_info: pallas_core.NameAndSrcInfo, mesh: mesh_lib.Mesh | None = None, for_verification: bool = False, ) -> tuple[Module, tuple[Any, ...]]: @@ -432,7 +434,8 @@ def lower_jaxpr_to_module( bm.has_trivial_window()): continue def err_details(): - return (f"Block spec for {bm.origin} has block shape " + return (f"Block spec for {bm.origin} in pallas_call {name_and_src_info} " + "has block shape " f"{bm.block_shape}, array shape {bm.array_shape_dtype.shape}, " # TODO(necula): add index_map source location info f"and index_map returning {bm.index_map_jaxpr.jaxpr.outvars}, in " @@ -460,7 +463,7 @@ def err_details(): "only blocks having the same block shape as the array shape " "and a trivial index_map (returning all 0s)." 
+ err_details()) - unmapped_bs = [1 if bs is pl_core.mapped else bs for bs in bm.block_shape] + unmapped_bs = [1 if bs is pallas_core.mapped else bs for bs in bm.block_shape] bs0, as0 = unmapped_bs[-1], bm.array_shape_dtype.shape[-1] if rank >= 2: bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2] @@ -507,6 +510,9 @@ def err_details(): jaxpr, grid_mapping, dimension_semantics, mesh) mosaic_grid_mapping.maybe_compress_grid() m = ir.Module.create() + attrs = m.operation.attributes + module_name = name_and_src_info.name + attrs["sym_name"] = ir.StringAttr.get(module_name) sym_tab = ir.SymbolTable(m.operation) func_op = lower_jaxpr_to_func( ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping, @@ -534,7 +540,7 @@ def err_details(): ) assert mlir_func.verify(), mlir_func block_shape = [ - 1 if b is pl_core.mapped else b for b in bm.block_shape + 1 if b is pallas_core.mapped else b for b in bm.block_shape ] # If we have an extended dtype, we need to add the block shape for the # remaining physical dtype. @@ -544,7 +550,7 @@ def err_details(): window_bounds=window_shape, transform_indices=ir.FlatSymbolRefAttr.get(func_name), ) - if isinstance(bm.indexing_mode, pl_core.Unblocked): + if isinstance(bm.indexing_mode, pallas_core.Unblocked): if bm.indexing_mode.padding is None: pad_low = pad_high = [0] * len(bm.block_shape) else: @@ -557,7 +563,7 @@ def err_details(): sym_tab.insert(mlir_func) func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params) static_grid = [ - MLIR_DYNAMIC if b is pl_core.dynamic_grid_dim else b for b in grid + MLIR_DYNAMIC if b is pallas_core.dynamic_grid_dim else b for b in grid ] func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid) @@ -911,7 +917,7 @@ def _make_index(s): def _maybe_cast_to_index(cast_to_index, x): if cast_to_index: return _make_index(x) - return _ensure_mlir_value(x, aval=pl_core.index_map_grid_aval) + return _ensure_mlir_value(x, aval=pallas_core.index_map_grid_aval) def _index_to_start_size_stride( @@ -940,7 +946,7 @@ def _index_to_start_size_stride( def _indexer_to_start_size_stride( indexer: NDIndexer, - ref_block_shape: tuple[int | pl_core.Mapped, ...], + ref_block_shape: tuple[int | pallas_core.Mapped, ...], *, cast_to_index: bool, ) -> tuple[ @@ -948,7 +954,7 @@ def _indexer_to_start_size_stride( tuple[int | ir.Value, ...], tuple[int, ...], tuple[bool, ...], - tuple[int | pl_core.Mapped, ...], + tuple[int | pallas_core.Mapped, ...], ]: indices_iter = iter(indexer.indices) starts, sizes, strides, squeeze_dims = [], [], [], [] @@ -960,7 +966,7 @@ def _indexer_to_start_size_stride( 1, True, ) - if s is pl_core.mapped + if s is pallas_core.mapped else _index_to_start_size_stride(next(indices_iter), cast_to_index) ) starts.append(start) @@ -982,9 +988,9 @@ def _indexer_to_start_size_stride( def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef, indexer: NDIndexer, - ref_block_shape: tuple[int | pl_core.Mapped, ...] - ) -> tuple[ir.Value, tuple[int | pl_core.Mapped, ...], - tuple[int | pl_core.Mapped, ...]]: + ref_block_shape: tuple[int | pallas_core.Mapped, ...] 
+ ) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...], + tuple[int | pallas_core.Mapped, ...]]: assert ref_block_shape is not None target_shape = indexer.get_indexer_shape() starts, sizes, strides, squeeze_dims, ref_block_shape = ( @@ -1216,7 +1222,7 @@ def _masked_swap_lowering_rule( mem_slice_shape.insert(i, 1) mem_slice_shape_iter = iter(mem_slice_shape) mem_slice_shape = [ - 1 if b is pl_core.mapped else next(mem_slice_shape_iter) + 1 if b is pallas_core.mapped else next(mem_slice_shape_iter) for b in ref_block_shape ] mem_aval = aval_out.update(shape=tuple(mem_slice_shape)) @@ -2126,8 +2132,8 @@ def _run_body(i, args): if unroll != 1: raise NotImplementedError( f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.") - lbd = _ensure_mlir_value(start, pl_core.index_map_grid_aval) - ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pl_core.index_map_grid_aval)) + lbd = _ensure_mlir_value(start, pallas_core.index_map_grid_aval) + ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pallas_core.index_map_grid_aval)) step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))) for_op = scf.ForOp(lbd, ubd, step, args) with ir.InsertionPoint(for_op.body): @@ -2525,7 +2531,7 @@ def _bitcast_convert_type_lowering_rule( lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value: - if isinstance(aval, pl_core.AbstractMemoryRef): + if isinstance(aval, pallas_core.AbstractMemoryRef): memspace = ir.Attribute.parse(f"#tpu.memory_space<{aval.memory_space}>") if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype): assert aval.memory_space == TPUMemorySpace.SEMAPHORE @@ -2574,8 +2580,8 @@ def _linearize_mesh_indices(*indices): return sum(a * b for a, b in zip(indices, mesh_strides)) lower_ctx = LoweringRuleContext( lowering_context=ctx.lowering_context, - avals_in=[pl_core.index_map_grid_aval] * len(device_ids), - avals_out=[pl_core.index_map_grid_aval], + avals_in=[pallas_core.index_map_grid_aval] * len(device_ids), + avals_out=[pallas_core.index_map_grid_aval], block_shapes=(None,) * len(device_ids), ) return lower_fun(_linearize_mesh_indices, multiple_results=False)( @@ -2855,7 +2861,7 @@ def _shard_map_discharge_rule( rewrite, ): del out_avals, auto, in_names, out_names, check_rep, rewrite - if not isinstance(mesh, pl_core.PallasMesh): + if not isinstance(mesh, pallas_core.PallasMesh): raise NotImplementedError("Mesh must be a PallasMesh") if len(mesh.shape) > 1: raise NotImplementedError("Mesh must be 1D") @@ -2867,9 +2873,9 @@ def body(*args): out = pallas_call.pallas_call( body, out_shape=in_avals, - in_specs=[pl_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)] + in_specs=[pallas_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)] * len(in_avals), - out_specs=[pl_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)] + out_specs=[pallas_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)] * len(in_avals), input_output_aliases={i: i for i in range(len(in_avals))}, grid=((core_axis_name, num_cores),), diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 862ab5ff6463..95951978967c 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -66,7 +66,7 @@ def pallas_call_tpu_lowering_rule( ctx: mlir.LoweringRuleContext, *in_nodes, jaxpr: jax_core.Jaxpr, - name: str, + name_and_src_info: core.NameAndSrcInfo, grid_mapping: 
core.GridMapping, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, @@ -75,6 +75,7 @@ def pallas_call_tpu_lowering_rule( """Lowers a pallas_call to a Mosaic TPU custom call.""" del interpret if debug: + print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:") print(jaxpr) if "mosaic_params" in compiler_params: # TODO(slebedev): Remove this branch after July 12th 2024. @@ -106,9 +107,11 @@ def lower_module(for_verification: bool): return lowering.lower_jaxpr_to_module( ctx, mlir_ctx, grid_mapping, jaxpr, dimension_semantics=dimension_semantics, mesh=mesh, - for_verification=for_verification) + for_verification=for_verification, + name_and_src_info=name_and_src_info) mosaic_module, extra_args = lower_module(for_verification=False) if debug: + print(f"\nThe Mosaic module for pallas_call {name_and_src_info}:") print(mosaic_module) num_extra_args = len(extra_args) num_dyn_bounds = grid_mapping.num_dynamic_grid_bounds @@ -132,6 +135,7 @@ def lower_module(for_verification: bool): verification_module, num_devices, num_cores ) if promela_dump_path == "stdout": + print(f"The Promela model for pallas_call {name_and_src_info}:") print(model) else: if promela_dump_path == "sponge": @@ -142,7 +146,10 @@ def lower_module(for_verification: bool): " --jax_pallas_dump_promela_to=sponge" ) dump_ctx = tempfile.NamedTemporaryFile( - mode="w", prefix=name + "-", suffix=".pml", dir=promela_dump_path, delete=False, + mode="w", + prefix=name_and_src_info.name + "-", + suffix=".pml", + dir=promela_dump_path, delete=False, ) with dump_ctx as f: f.write(model) @@ -173,7 +180,7 @@ def _maybe_cast_inputs(*args): module=mosaic_module, out_type=kernel_out_avals, backend="tpu", - kernel_name=name, + kernel_name=name_and_src_info.name, cost_estimate=mosaic_params.get("cost_estimate"), vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"), flags=mosaic_params.get("flags"), diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 254b237ff371..7dfd0258c761 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -31,7 +31,7 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect -from jax._src.pallas import core as pl_core +from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src.state import primitives as sp from jax.experimental.mosaic import gpu as mosaic_gpu @@ -53,7 +53,7 @@ @dataclasses.dataclass class ModuleContext: name: str - grid_mapping: pl_core.GridMapping + grid_mapping: pallas_core.GridMapping runtime_smem: ir.Value # ir.MemRefType smem_used_bytes: int @@ -117,7 +117,7 @@ class LoweringRuleContext: module_context: ModuleContext avals_in: Sequence[jax_core.ShapedArray] avals_out: Sequence[jax_core.ShapedArray] - block_shapes: list[tuple[int | pl_core.Mapped, ...]] | None + block_shapes: list[tuple[int | pallas_core.Mapped, ...]] | None replace = dataclasses.replace @@ -142,9 +142,9 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name def lower_jaxpr_to_module( - grid_mapping: pl_core.GridMapping, + grid_mapping: pallas_core.GridMapping, jaxpr: jax_core.Jaxpr, - name: str, + name_and_src_info: pallas_core.NameAndSrcInfo, compiler_params: dict[str, Any], ) -> LoweringResult: in_structs = tuple(grid_mapping.in_shapes) @@ -180,7 +180,8 @@ def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers): barrier.wait() - module_ctx = 
ModuleContext(name, grid_mapping, runtime_smem, smem_used_bytes=0) + module_ctx = ModuleContext(name_and_src_info.name, + grid_mapping, runtime_smem, smem_used_bytes=0) _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, None, buffers_smem) for b_gmem, b_smem in zip(out_buffers_gmem, out_buffers_smem): @@ -210,6 +211,7 @@ def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers): *extra_smem_scratch, mgpu.TMABarrier(), ), + module_name=name_and_src_info.name, ) return LoweringResult(module, grid, gmem_scratch_bytes, out_structs) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 080034c368fe..1409800fdb30 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -30,7 +30,7 @@ def pallas_call_lowering( ctx: mlir.LoweringRuleContext, *args, jaxpr: jax_core.Jaxpr, - name: str, + name_and_src_info: pallas_core.NameAndSrcInfo, interpret: bool, debug: bool, input_output_aliases: tuple[tuple[int, int], ...], @@ -48,16 +48,19 @@ def pallas_call_lowering( ) if debug: + print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:") print(jaxpr) + print(f"The grid mapping for pallas_call {name_and_src_info}:") print(grid_mapping) lowering_result = lowering.lower_jaxpr_to_module( grid_mapping, jaxpr, - name, + name_and_src_info, compiler_params, ) if debug: + print(f"\nThe Mosaic GPU module for pallas_call {name_and_src_info}:") print(lowering_result.module.operation) module = lowering_result.module diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index d1b69f424e6c..619110593edc 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -167,12 +167,12 @@ def _pallas_call_impl(*args, **kwargs): def _pallas_call_impl_interpret( *args, jaxpr: jax_core.Jaxpr, - name: str, + name_and_src_info: pallas_core.NameAndSrcInfo, debug: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, compiler_params: Any): - del compiler_params, name + del compiler_params # If we're in interpreter mode, we *scan* over the grid and eval the # discharged jaxpr. dynamic_grid_args, args = split_list( # type: ignore @@ -188,6 +188,7 @@ def _pallas_call_impl_interpret( with grid_mapping.trace_env(): discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ()) if debug: + print(f"\nJaxpr for the kernel in pallas_call {name_and_src_info}:") print(discharged_jaxpr) out = _initialize_output_vals(grid_mapping.block_mappings_output, args, input_output_aliases) @@ -301,7 +302,7 @@ def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_): for bm in grid_mapping.block_mappings_output) pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) -def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, +def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name_and_src_info, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping, debug, interpret, compiler_params: Any): if grid_mapping.num_dynamic_grid_bounds: @@ -336,8 +337,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, effs.append(eff) jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs) if debug: + print(f"\nThe jaxpr for the jvp of pallas_call {name_and_src_info}:") print(jvp_jaxpr) - # TODO(necula): does this work with consts?
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)]) jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms) jvp_grid_mapping = grid_mapping.replace( @@ -349,7 +350,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, *primals, *tangents, jaxpr=jvp_jaxpr, - name=f"{name}_jvp", + name_and_src_info=name_and_src_info.replace( + name=f"{name_and_src_info.name}_jvp"), grid_mapping=jvp_grid_mapping, interpret=interpret, debug=debug, @@ -428,7 +430,7 @@ def _batch_with_explicit_loop( dims: Sequence[int | batching.NotMapped], *, jaxpr: jax_core.Jaxpr, - name: str, + name_and_src_info: pallas_core.NameAndSrcInfo, grid_mapping: GridMapping, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, @@ -493,7 +495,7 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: batch_out = pallas_call_p.bind( *batch_args, jaxpr=jaxpr, - name=name, + name_and_src_info=name_and_src_info, grid_mapping=grid_mapping, input_output_aliases=input_output_aliases, debug=debug, @@ -520,7 +522,7 @@ def _pallas_call_batching_rule( dims, *, jaxpr: jax_core.Jaxpr, - name: str, + name_and_src_info: pallas_core.NameAndSrcInfo, grid_mapping: GridMapping, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, @@ -542,7 +544,7 @@ def _maybe_squeeze_out_bdim( out = pallas_call_p.bind( *args, jaxpr=jaxpr, - name=name, + name_and_src_info=name_and_src_info, grid_mapping=grid_mapping, input_output_aliases=input_output_aliases, debug=debug, @@ -573,7 +575,7 @@ def _maybe_squeeze_out_bdim( args=dynamic_grid_args + args, dims=dynamic_grid_dims + dims, jaxpr=jaxpr, - name=name, + name_and_src_info=name_and_src_info, grid_mapping=grid_mapping, input_output_aliases=input_output_aliases, debug=debug, @@ -605,7 +607,7 @@ def _maybe_squeeze_out_bdim( args=scalar_args + args, dims=scalar_bdims + bdims, jaxpr=jaxpr, - name=name, + name_and_src_info=name_and_src_info, grid_mapping=grid_mapping, input_output_aliases=input_output_aliases, debug=debug, @@ -660,7 +662,8 @@ def _maybe_squeeze_out_bdim( *dynamic_grid_args, *args, jaxpr=jaxpr, - name=f"batched_{name}", + name_and_src_info=name_and_src_info.replace( + name=f"{name_and_src_info.name}_batched"), grid_mapping=batched_grid_mapping, input_output_aliases=input_output_aliases, debug=debug, @@ -836,7 +839,7 @@ def _ensure_2d_error_shape(arg): @weakref_lru_cache def _trace_kernel_to_jaxpr(fun: Callable, - fun_src_info: pallas_core.SrcInfoStr, + name_and_src_info: pallas_core.NameAndSrcInfo, grid_mapping: GridMapping, kernel_avals: tuple[pallas_core.AbstractMemRef, ...], kernel_in_tree: tree_util.PyTreeDef, @@ -855,23 +858,17 @@ def _trace_kernel_to_jaxpr(fun: Callable, consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c)) for c in consts] raise ValueError( - f"The kernel function {fun_src_info} in a " - "pallas_call should not capture constants. You should pass them " - f"as inputs. It captures constants of shapes: {consts_avals}") + f"The kernel function in the pallas_call {name_and_src_info} " + f"captures constants {consts_avals}. " + "You should pass them as inputs") kernel_out_tree = out_tree_thunk() if kernel_out_tree != tree_util.tree_structure(None): raise ValueError( - f"The kernel function {fun_src_info} in a " - f"pallas_call should return None. " - f"It returns a PyTree: {kernel_out_tree}") + f"The kernel function in the pallas_call {name_and_src_info} " + f"should return None. 
It returns a PyTree: {kernel_out_tree}") return jaxpr -def _extract_function_name(f: Callable, name: str | None) -> str: - if name is None: - name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func" - return name - _PALLAS_USE_MOSAIC_GPU = config.bool_flag( "jax_pallas_use_mosaic_gpu", @@ -1009,7 +1006,11 @@ def pallas_call( grid whose body is the kernel lowered as a JAX function. This does not require a TPU or a GPU, and is the only way to run Pallas kernels on CPU. This is useful for debugging. - name: TO BE DOCUMENTED. + name: if present, specifies the name to use for this kernel call in + debugging and error messages. To this name we append the file and line + where the kernel function is defined, e.g.: + `{name} for kernel function {kernel_name} at {file}:{line}`. + If missing, then we use `{kernel_name} at {file}:{line}`. compiler_params: TO BE DOCUMENTED. Returns: @@ -1017,7 +1018,9 @@ def pallas_call( invoke the Pallas kernel. """ - name = _extract_function_name(kernel, name) + kernel_src_info = api_util.fun_sourceinfo(kernel) + name_and_src_info = pallas_core.NameAndSrcInfo.from_pallas_call( + name, kernel_src_info) if compiler_params is None: compiler_params = {} @@ -1058,16 +1061,15 @@ def wrapped(*args): kernel_fun_sig = api_util.fun_signature(kernel) arg_names = None - kernel_src_info: pallas_core.SrcInfoStr = "" if kernel_fun_sig: kernel_debug_info = api_util.debug_info( "pallas_call kernel", - api_util.fun_sourceinfo(kernel), + kernel_src_info, kernel_fun_sig, [1] * len(kernel_fun_sig.parameters), {}, (), ()) if kernel_debug_info: arg_names = kernel_debug_info.arg_names - kernel_src_info = kernel_debug_info.func_src_info + del kernel_debug_info in_origins = tuple(in_path_to_input_origin(p, arg_names) for p in in_paths) out_origins = tuple(f"outputs{tree_util.keystr(p)}" for p in out_paths) @@ -1105,7 +1107,8 @@ def wrapped(*args): index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands]) out_flat = pallas_call_p.bind( *dynamic_grid_bounds, *index_args, *rest_args, - jaxpr=jaxpr, name=name, + jaxpr=jaxpr, + name_and_src_info=name_and_src_info, debug=debug, interpret=interpret, grid_mapping=grid_mapping, diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index b8c043993ab6..852ac714d3c9 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -266,7 +266,7 @@ def _check_tensor_size(shape: tuple[int | pallas_core.Mapped, ...]): def lower_jaxpr_to_triton_module( jaxpr: jax_core.Jaxpr, grid_mapping: GridMapping, - name: str, + name_and_src_info: pallas_core.NameAndSrcInfo, platform: str ) -> LoweringResult: if grid_mapping.num_dynamic_grid_bounds: @@ -283,6 +283,9 @@ def lower_jaxpr_to_triton_module( ) with _new_ir_context(), ir.Location.unknown(): module = ir.Module.create() + attrs = module.operation.attributes + module_name = name_and_src_info.name + attrs["sym_name"] = ir.StringAttr.get(module_name) param_types = [ tt_dialect.PointerType.get(_dtype_to_ir_type(var.aval.dtype), 1) for var in jaxpr.invars @@ -290,7 +293,7 @@ def lower_jaxpr_to_triton_module( assert len(jaxpr.outvars) == 0 fn_type = ir.FunctionType.get(param_types, []) fn = tt_dialect.FuncOp( - name, + name_and_src_info.name, ir.TypeAttr.get(fn_type), sym_visibility="public", res_attrs=ir.DictAttr.get(dict(noinline=ir.BoolAttr.get(False))), @@ -310,7 +313,8 @@ def lower_jaxpr_to_triton_module( if i not in grid_mapping.vmapped_dims ] ctx = ModuleContext( - name, grid_mapping, local_program_ids,
mlir.TracebackCaches(), platform + name_and_src_info.name, + grid_mapping, local_program_ids, mlir.TracebackCaches(), platform ) if grid_mapping.num_index_operands: raise NotImplementedError( diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 4bc71a0441ae..e89e1323c3af 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -42,7 +42,7 @@ def pallas_call_lowering( ctx: mlir.LoweringRuleContext, *in_nodes, jaxpr: jax_core.Jaxpr, - name: str, + name_and_src_info: pallas_core.NameAndSrcInfo, interpret: bool, debug: bool, input_output_aliases: tuple[tuple[int, int], ...], @@ -67,14 +67,17 @@ def pallas_call_lowering( num_stages = triton_params.pop("num_stages", 3) if debug: + print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:") print(jaxpr) + print(f"The grid mapping for pallas_call {name_and_src_info}:") print(grid_mapping) lowering_result = lowering.lower_jaxpr_to_triton_module( - jaxpr, grid_mapping, name, lowering_platform + jaxpr, grid_mapping, name_and_src_info, lowering_platform ) module_op = lowering_result.module.operation if debug: + print(f"\nThe Triton module for pallas_call {name_and_src_info}:") print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True)) grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid) @@ -86,7 +89,7 @@ def pallas_call_lowering( buf = io.BytesIO() module_op.write_bytecode(buf) backend_config = dict( - name=ir.StringAttr.get(name), + name=ir.StringAttr.get(name_and_src_info.name), ir=ir.StringAttr.get(buf.getvalue()), num_stages=mlir.i32_attr(num_stages), num_warps=mlir.i32_attr(num_warps), diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 8933e02abf5a..9d4068745d3a 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -690,6 +690,7 @@ def _lower_as_gpu_kernel( in_shapes: tuple[Any, ...], out_shape, smem_scratch_shape: ShapeTree | Union[ShapeTree], + module_name: str, prof_spec: profiler.ProfilerSpec | None = None, ): ptr_ty = ir.Type.parse("!llvm.ptr") @@ -714,6 +715,8 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block)) module = ir.Module.create() + attrs = module.operation.attributes + attrs["sym_name"] = ir.StringAttr.get(module_name) with ir.InsertionPoint(module.body): _declare_runtime_functions() gmem_scratch_bytes = 0 @@ -772,6 +775,7 @@ def as_gpu_kernel( smem_scratch_shape: ShapeTree | Union[ShapeTree], prof_spec: profiler.ProfilerSpec | None = None, cluster: tuple[int, int, int] = (1, 1, 1), + module_name: str = "unknown", ): if isinstance(in_shape, list): in_shape = tuple(in_shape) @@ -780,7 +784,8 @@ def as_gpu_kernel( module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = ( _lower_as_gpu_kernel( - body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, prof_spec + body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, + module_name, prof_spec ) ) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index bc3f5a9e9121..63779319b0b9 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -27,12 +27,14 @@ import jax from jax import lax from jax import random +from jax._src import api_util from jax._src import checkify from jax._src import config from jax._src import dtypes from jax._src import test_util as jtu from
jax._src.lax.control_flow.for_loop import for_loop from jax._src.lib import version as jaxlib_version +from jax._src.pallas import core as pallas_core from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax.experimental import pallas as pl import jax.numpy as jnp @@ -507,7 +509,7 @@ def kernel(src, dst): with self.assertRaisesRegex( ValueError, - "The kernel function .* should not capture constants"): + "The kernel function .* captures constants"): kernel(x) def test_vector_slicing(self): @@ -712,7 +714,7 @@ def my_kernel(x_ref, o1_ref, o2_ref): out_shape=(a, a)) with self.assertRaisesRegex( ValueError, - "The kernel function my_kernel at .*pallas_test.py:.* in a pallas_call should return None"): + "The kernel function .* my_kernel at .*pallas_test.py:.* should return None"): f(a) def test_pallas_call_kernel_with_no_signature_returns_something(self): @@ -721,7 +723,7 @@ def test_pallas_call_kernel_with_no_signature_returns_something(self): out_shape=a) with self.assertRaisesRegex( ValueError, - "The kernel function .* at .*pallas_test.py:.* in a pallas_call should return None"): + "The kernel function .* at .*pallas_test.py:.* should return None"): f(a) def test_pallas_call_in_specs_not_a_sequence(self): @@ -825,7 +827,6 @@ def test_pallas_call_out_specs_mismatch_shape(self): ".* `out_specs` is a tuple of length 1 but `out_shape` is a tuple of length 2.*", re.DOTALL)): f(a) - def test_pallas_call_block_shape_ndim_mismatch(self): a = np.arange(256, dtype=np.int32) f = self.pallas_call(lambda x_ref, o1_ref: None, @@ -885,6 +886,40 @@ def test_pallas_call_input_output_aliases_errors(self): out_shape=[jax.ShapeDtypeStruct(x.shape, jnp.float32)], input_output_aliases={1: 0})(x, x) + def test_name_and_src_info(self): + def the_kernel(): return None + ns1 = pallas_core.NameAndSrcInfo.from_pallas_call( + "my_name", api_util.fun_sourceinfo(the_kernel)) + self.assertEqual("my_name", ns1.name) + self.assertIn("the_kernel", ns1.src_info) + self.assertIn("pallas_test.py:", ns1.src_info) + self.assertRegex( + str(ns1), + "my_name for kernel function the_kernel at .*pallas_test.py:.*") + + ns2 = pallas_core.NameAndSrcInfo.from_pallas_call( + None, + api_util.fun_sourceinfo(the_kernel)) + self.assertEqual("the_kernel", ns2.name) + self.assertIn("pallas_test.py:", ns2.src_info) + self.assertRegex( + str(ns2), + "the_kernel at .*pallas_test.py:.*") + + ns3 = pallas_core.NameAndSrcInfo.from_pallas_call("my_name", None) + self.assertEqual("my_name", ns3.name) + self.assertEqual("", ns3.src_info) + self.assertEqual(str(ns3), "my_name") + + ns4 = pallas_core.NameAndSrcInfo.from_pallas_call("my name with spaces", + None) + self.assertEqual("my_name_with_spaces", ns4.name) + self.assertEqual("", ns4.src_info) + + ns5 = pallas_core.NameAndSrcInfo.from_pallas_call(None, None) + self.assertEqual("unknown", ns5.name) + self.assertEqual("", ns5.src_info) + class ApiErrorInterpreterTest(ApiErrorTest): INTERPRET = True diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py index 724285abbbca..af8299e31689 100644 --- a/tests/pallas/pallas_vmap_test.py +++ b/tests/pallas/pallas_vmap_test.py @@ -147,7 +147,7 @@ def kernel(src, dst): with self.assertRaisesRegex( ValueError, - "The kernel function .* should not capture constants"): + "The kernel function .* captures constants"): kernel(x) def test_vmap_of_kernel_with_input_output_aliases(self): diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 59a81200a6aa..358b81dbd621 100644 --- 
a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1884,10 +1884,6 @@ def kernel(x_ref, o_ref): class PallasCallTraceTest(PallasBaseTest): - def parse_debug_string(self, debug_string): - jaxpr, mlir = debug_string.split('module') - return {'jaxpr': jaxpr, 'mlir': mlir} - def test_trace_start_stop_match(self): def kernel(o_ref): with jax.named_scope('scope1'): @@ -1900,10 +1896,10 @@ def kernel(o_ref): debug=True, )() # TODO(justinfu): Add an official lowering API to get the MLIR. - mlir = self.parse_debug_string(msg.getvalue())['mlir'] + debug_string = msg.getvalue() - num_start = mlir.count('tpu.trace_start') - num_stop = mlir.count('tpu.trace_stop') + num_start = debug_string.count('tpu.trace_start') + num_stop = debug_string.count('tpu.trace_stop') self.assertEqual(num_start, 1) self.assertEqual(num_stop, 1) @@ -1926,10 +1922,10 @@ def scope2(): debug=True, )() # TODO(justinfu): Add an official lowering API to get the MLIR. - mlir = self.parse_debug_string(msg.getvalue())['mlir'] + debug_string = msg.getvalue() - num_start = mlir.count('tpu.trace_start') - num_stop = mlir.count('tpu.trace_stop') + num_start = debug_string.count('tpu.trace_start') + num_stop = debug_string.count('tpu.trace_stop') self.assertEqual(num_start, 2) self.assertEqual(num_stop, 2)
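
Usage sketch (illustrative, not part of the patch): how the newly documented `name` argument surfaces in error and debug messages. The kernel body, shapes, and the `<file>:<line>` placeholders below are assumptions.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  # Copy the whole input block to the output block.
  o_ref[...] = x_ref[...]

copy = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    # Messages now read: "my_copy for kernel function copy_kernel at <file>:<line>".
    name="my_copy",
    # With debug=True, lowering prints, e.g.,
    # "The kernel jaxpr for pallas_call my_copy for kernel function ...".
    debug=True,
)
copy(jnp.ones((8, 128), jnp.float32))

Omitting `name=` makes the messages fall back to `copy_kernel at <file>:<line>`, per `NameAndSrcInfo.from_pallas_call`.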
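
A behavior sketch of `NameAndSrcInfo.from_pallas_call` itself, mirroring `test_name_and_src_info` above (these are internal modules, so the import paths are an assumption and may shift):

from jax._src import api_util
from jax._src.pallas import core as pallas_core

def the_kernel(): return None

# An explicit name is sanitized for use as an MLIR module name
# (spaces and other non-identifier characters become "_").
ns = pallas_core.NameAndSrcInfo.from_pallas_call(
    "my name with spaces", api_util.fun_sourceinfo(the_kernel))
print(ns.name)  # my_name_with_spaces
print(ns)       # my_name_with_spaces for kernel function the_kernel at .../pallas_test.py:NN

# Without an explicit name, the kernel function name is promoted to `name`.
print(pallas_core.NameAndSrcInfo.from_pallas_call(
    None, api_util.fun_sourceinfo(the_kernel)))  # the_kernel at .../pallas_test.py:NN

# With neither, we get the placeholder.
print(pallas_core.NameAndSrcInfo.from_pallas_call(None, None))  # unknown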