[pallas] Improve error and debugging messages with source locations
Document the `name` argument to `pallas_call` and supplement it with source location information for the kernel function.
Pass all this as the `name_and_src_info` parameter to the `pallas_call_p` primitive.

Add more information to the `if debug` prints.

Set the MLIR module names so that the debug dumps are named properly.

Also change `import pallas.core as pl_core` to `import pallas.core as pallas_core` in a couple of modules, for consistency.
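
For illustration, a minimal sketch of passing the `name` argument (the kernel, shapes, and name here are hypothetical; when `name` is omitted, the kernel function's name and source location are used instead):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  # Trivial kernel; its name and source location back up `name`.
  o_ref[...] = x_ref[...]

x = jnp.ones((8, 128), jnp.float32)
y = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    name="my_copy",  # surfaces in error messages and debug dump names
)(x)
```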

PiperOrigin-RevId: 659506675
gnecula authored and jax authors committed Aug 5, 2024
1 parent b2a469b commit 252032a
Showing 12 changed files with 205 additions and 101 deletions.
46 changes: 43 additions & 3 deletions jax/_src/pallas/core.py
@@ -36,6 +36,7 @@
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
import jax.numpy as jnp
@@ -56,9 +57,47 @@ def __repr__(self):
Grid = Union[NamedGrid, TupleGrid]
StaticGrid = tuple[int, ...]
GridMappingGrid = tuple[int | DynamicGridDim, ...]
SrcInfoStr = str # function_name at filename:linenumber
OriginStr = str # The origin of a block spec, e.g. input[2]["field"]


@dataclasses.dataclass(frozen=True)
class NameAndSrcInfo:
  #: The name of the pallas_call or the name of the kernel function.
  name: str
  #: The source info, and the name of the kernel function if not in `name`.
  src_info: str

  def __str__(self):
    return f"{self.name}{' ' if self.src_info else ''}{self.src_info}"
  __repr__ = __str__

  replace = dataclasses.replace

  @staticmethod
  def from_pallas_call(pallas_call_name: str | None,
                       src_info: str | None) -> NameAndSrcInfo:
    """Formats the name and the source info.

    Args:
      pallas_call_name: The `name` argument to pallas_call.
      src_info: The result of `api_util.fun_source_info(kernel)`, in the form
        "{function_name} at {file_name}:{line_number}".
    """
    if pallas_call_name is not None:
      pallas_call_name = mlir._module_name_regex.sub("_", pallas_call_name)
    if src_info is None:
      return NameAndSrcInfo(
          "unknown" if pallas_call_name is None else pallas_call_name,
          "")
    if pallas_call_name is not None:
      return NameAndSrcInfo(pallas_call_name,
                            f"for kernel function {src_info}")
    src_info_parts = src_info.split(" ")
    return NameAndSrcInfo(src_info_parts[0],
                          " ".join(src_info_parts[1:]))


# Pytrees of jax.ShapeDtypeStruct
ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...]

@@ -268,7 +307,7 @@ class BlockMapping:
block_shape: tuple[Mapped | int, ...]
block_aval: AbstractMemoryRef # The block ref aval
index_map_jaxpr: jax_core.ClosedJaxpr
index_map_src_info: SrcInfoStr
index_map_src_info: NameAndSrcInfo
indexing_mode: IndexingMode
array_shape_dtype: jax.ShapeDtypeStruct # The whole array
origin: OriginStr
Expand Down Expand Up @@ -534,7 +573,8 @@ def _convert_block_spec_to_block_mapping(
lu.wrap_init(index_map_func), index_map_tree)
debug = pe.debug_info(index_map_func, index_map_tree, index_map_out_tree_thunk,
False, "pallas_call index_map")
index_map_src_info = debug.func_src_info or "<unknown>"
index_map_src_info = NameAndSrcInfo.from_pallas_call(None,
debug.func_src_info)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun,
index_map_avals,
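
To make the formatting rules in `NameAndSrcInfo.from_pallas_call` concrete, a small sketch of its three cases (internal API; the names and paths below are made up, and the source-info string follows the `"{function_name} at {file_name}:{line_number}"` format from the docstring):

```python
from jax._src.pallas.core import NameAndSrcInfo

# Name and kernel source info both present: the name leads, and the
# source info is attributed to the kernel function.
print(NameAndSrcInfo.from_pallas_call(
    "my_copy", "copy_kernel at my_module.py:12"))
# -> my_copy for kernel function copy_kernel at my_module.py:12

# No user-supplied name: the kernel function name becomes the name.
print(NameAndSrcInfo.from_pallas_call(
    None, "copy_kernel at my_module.py:12"))
# -> copy_kernel at my_module.py:12

# Neither is available.
print(NameAndSrcInfo.from_pallas_call(None, None))
# -> unknown
```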
74 changes: 40 additions & 34 deletions jax/_src/pallas/mosaic/lowering.py
@@ -49,7 +49,7 @@
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from jax._src.pallas import pallas_call
from jax._src.pallas import core as pl_core
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic import core as tpu_core
@@ -71,7 +71,7 @@

NDIndexer = indexing.NDIndexer
TPUMemorySpace = tpu_core.TPUMemorySpace
MemorySpace = pl_core.MemorySpace | TPUMemorySpace
MemorySpace = pallas_core.MemorySpace | TPUMemorySpace
VMEM = tpu_core.TPUMemorySpace.VMEM
SMEM = tpu_core.TPUMemorySpace.SMEM
# Booleans are stored as the following type in memrefs.
@@ -105,7 +105,7 @@ class LoweringContext:
grid_names: tuple[Hashable, ...] | None
mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions.
user_grid_indices: Sequence[ir.Value] | None
block_shapes: list[tuple[int | pl_core.Mapped, ...]]
block_shapes: list[tuple[int | pallas_core.Mapped, ...]]
name_stack: source_info_util.NameStack
mesh_context: MeshContext | None
replace = dataclasses.replace
@@ -136,7 +136,7 @@ class LoweringRuleContext:
lowering_context: LoweringContext
avals_in: Sequence[jax_core.AbstractValue]
avals_out: Sequence[jax_core.AbstractValue]
block_shapes: Sequence[tuple[int | pl_core.Mapped, ...] | None]
block_shapes: Sequence[tuple[int | pallas_core.Mapped, ...] | None]

replace = dataclasses.replace

@@ -145,9 +145,9 @@ def _memory_space_to_tpu_memspace(memory_space: MemorySpace | None
) -> ir.Attribute:
if memory_space is None:
memory_space = VMEM
elif memory_space == pl_core.MemorySpace.ERROR:
elif memory_space == pallas_core.MemorySpace.ERROR:
memory_space = SMEM
elif memory_space == pl_core.MemorySpace.INDEX:
elif memory_space == pallas_core.MemorySpace.INDEX:
memory_space = SMEM
return ir.Attribute.parse(f"#tpu.memory_space<{memory_space}>")

@@ -252,10 +252,10 @@ def _get_aval_physical_dtype_shape(aval):

def _get_arg_type(
aval,
block_mapping: pl_core.BlockMapping | None,
block_mapping: pallas_core.BlockMapping | None,
):
memory_space = None
if isinstance(aval, pl_core.AbstractMemoryRef):
if isinstance(aval, pallas_core.AbstractMemoryRef):
memory_space = aval.memory_space
# We assume unannotated memory refs are in VMEM
if memory_space is None:
@@ -265,7 +265,7 @@ def _get_arg_type(
# TODO(necula): clean this None block_mapping
if block_mapping is None:
return aval_to_ir_type(aval, memory_space=memory_space), aval.shape
shape = tuple(1 if b is pl_core.mapped else b for b in block_mapping.block_shape)
shape = tuple(1 if b is pallas_core.mapped else b for b in block_mapping.block_shape)
return (
aval_to_ir_type(aval, shape=shape, memory_space=memory_space),
block_mapping.block_shape,
@@ -277,7 +277,7 @@ class MosaicGridMapping:
grid: tuple[int, ...] | None
grid_names: tuple[Hashable, ...] | None
jaxpr: jax_core.Jaxpr
block_mappings: tuple[pl_core.BlockMapping | None, ...]
block_mappings: tuple[pallas_core.BlockMapping | None, ...]
mapped_dims: tuple[int, ...]
scalar_prefetch_types: tuple[ir.Type, ...]
operand_types: tuple[ir.Type, ...]
@@ -289,7 +289,7 @@ class MosaicGridMapping:
mesh_info: MeshInfo | None
get_grid_indices: Callable | None

def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping,
dimension_semantics: tuple[str, ...] | None,
mesh: mesh_lib.Mesh | None):
self.grid = grid_mapping.grid
@@ -340,7 +340,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
for aval in scratch_avals
)
self.grid_types, _ = unzip2([
_get_arg_type(pl_core.index_map_grid_aval, None)
_get_arg_type(pallas_core.index_map_grid_aval, None)
for _ in range(len(self.grid))
])
self._prepare_mesh_info(mesh)
@@ -418,9 +418,11 @@ class MeshInfo:
def lower_jaxpr_to_module(
lowering_context: mlir.LoweringRuleContext,
ctx: ir.Context,
grid_mapping: pl_core.GridMapping,
grid_mapping: pallas_core.GridMapping,
jaxpr: jax_core.Jaxpr,
*,
dimension_semantics: tuple[str | None, ...] | None,
name_and_src_info: pallas_core.NameAndSrcInfo,
mesh: mesh_lib.Mesh | None = None,
for_verification: bool = False,
) -> tuple[Module, tuple[Any, ...]]:
@@ -432,7 +434,8 @@ def lower_jaxpr_to_module(
bm.has_trivial_window()):
continue
def err_details():
return (f"Block spec for {bm.origin} has block shape "
return (f"Block spec for {bm.origin} in pallas_call {name_and_src_info} "
"has block shape "
f"{bm.block_shape}, array shape {bm.array_shape_dtype.shape}, "
# TODO(necula): add index_map source location info
f"and index_map returning {bm.index_map_jaxpr.jaxpr.outvars}, in "
@@ -460,7 +463,7 @@ def err_details():
"only blocks having the same block shape as the array shape "
"and a trivial index_map (returning all 0s)." + err_details())

unmapped_bs = [1 if bs is pl_core.mapped else bs for bs in bm.block_shape]
unmapped_bs = [1 if bs is pallas_core.mapped else bs for bs in bm.block_shape]
bs0, as0 = unmapped_bs[-1], bm.array_shape_dtype.shape[-1]
if rank >= 2:
bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2]
@@ -507,6 +510,9 @@ def err_details():
jaxpr, grid_mapping, dimension_semantics, mesh)
mosaic_grid_mapping.maybe_compress_grid()
m = ir.Module.create()
attrs = m.operation.attributes
module_name = name_and_src_info.name
attrs["sym_name"] = ir.StringAttr.get(module_name)
sym_tab = ir.SymbolTable(m.operation)
func_op = lower_jaxpr_to_func(
ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
@@ -534,7 +540,7 @@ def err_details():
)
assert mlir_func.verify(), mlir_func
block_shape = [
1 if b is pl_core.mapped else b for b in bm.block_shape
1 if b is pallas_core.mapped else b for b in bm.block_shape
]
# If we have an extended dtype, we need to add the block shape for the
# remaining physical dtype.
@@ -544,7 +550,7 @@ def err_details():
window_bounds=window_shape,
transform_indices=ir.FlatSymbolRefAttr.get(func_name),
)
if isinstance(bm.indexing_mode, pl_core.Unblocked):
if isinstance(bm.indexing_mode, pallas_core.Unblocked):
if bm.indexing_mode.padding is None:
pad_low = pad_high = [0] * len(bm.block_shape)
else:
@@ -557,7 +563,7 @@ def err_details():
sym_tab.insert(mlir_func)
func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
static_grid = [
MLIR_DYNAMIC if b is pl_core.dynamic_grid_dim else b for b in grid
MLIR_DYNAMIC if b is pallas_core.dynamic_grid_dim else b for b in grid
]
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)

@@ -911,7 +917,7 @@ def _make_index(s):
def _maybe_cast_to_index(cast_to_index, x):
if cast_to_index:
return _make_index(x)
return _ensure_mlir_value(x, aval=pl_core.index_map_grid_aval)
return _ensure_mlir_value(x, aval=pallas_core.index_map_grid_aval)


def _index_to_start_size_stride(
Expand Down Expand Up @@ -940,15 +946,15 @@ def _index_to_start_size_stride(

def _indexer_to_start_size_stride(
indexer: NDIndexer,
ref_block_shape: tuple[int | pl_core.Mapped, ...],
ref_block_shape: tuple[int | pallas_core.Mapped, ...],
*,
cast_to_index: bool,
) -> tuple[
tuple[ir.Value, ...],
tuple[int | ir.Value, ...],
tuple[int, ...],
tuple[bool, ...],
tuple[int | pl_core.Mapped, ...],
tuple[int | pallas_core.Mapped, ...],
]:
indices_iter = iter(indexer.indices)
starts, sizes, strides, squeeze_dims = [], [], [], []
@@ -960,7 +966,7 @@ def _indexer_to_start_size_stride(
1,
True,
)
if s is pl_core.mapped
if s is pallas_core.mapped
else _index_to_start_size_stride(next(indices_iter), cast_to_index)
)
starts.append(start)
@@ -982,9 +988,9 @@

def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
indexer: NDIndexer,
ref_block_shape: tuple[int | pl_core.Mapped, ...]
) -> tuple[ir.Value, tuple[int | pl_core.Mapped, ...],
tuple[int | pl_core.Mapped, ...]]:
ref_block_shape: tuple[int | pallas_core.Mapped, ...]
) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...],
tuple[int | pallas_core.Mapped, ...]]:
assert ref_block_shape is not None
target_shape = indexer.get_indexer_shape()
starts, sizes, strides, squeeze_dims, ref_block_shape = (
Expand Down Expand Up @@ -1216,7 +1222,7 @@ def _masked_swap_lowering_rule(
mem_slice_shape.insert(i, 1)
mem_slice_shape_iter = iter(mem_slice_shape)
mem_slice_shape = [
1 if b is pl_core.mapped else next(mem_slice_shape_iter)
1 if b is pallas_core.mapped else next(mem_slice_shape_iter)
for b in ref_block_shape
]
mem_aval = aval_out.update(shape=tuple(mem_slice_shape))
@@ -2126,8 +2132,8 @@ def _run_body(i, args):
if unroll != 1:
raise NotImplementedError(
f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.")
lbd = _ensure_mlir_value(start, pl_core.index_map_grid_aval)
ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pl_core.index_map_grid_aval))
lbd = _ensure_mlir_value(start, pallas_core.index_map_grid_aval)
ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pallas_core.index_map_grid_aval))
step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
for_op = scf.ForOp(lbd, ubd, step, args)
with ir.InsertionPoint(for_op.body):
@@ -2525,7 +2531,7 @@ def _bitcast_convert_type_lowering_rule(
lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule

def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value:
if isinstance(aval, pl_core.AbstractMemoryRef):
if isinstance(aval, pallas_core.AbstractMemoryRef):
memspace = ir.Attribute.parse(f"#tpu.memory_space<{aval.memory_space}>")
if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype):
assert aval.memory_space == TPUMemorySpace.SEMAPHORE
@@ -2574,8 +2580,8 @@ def _linearize_mesh_indices(*indices):
return sum(a * b for a, b in zip(indices, mesh_strides))
lower_ctx = LoweringRuleContext(
lowering_context=ctx.lowering_context,
avals_in=[pl_core.index_map_grid_aval] * len(device_ids),
avals_out=[pl_core.index_map_grid_aval],
avals_in=[pallas_core.index_map_grid_aval] * len(device_ids),
avals_out=[pallas_core.index_map_grid_aval],
block_shapes=(None,) * len(device_ids),
)
return lower_fun(_linearize_mesh_indices, multiple_results=False)(
@@ -2855,7 +2861,7 @@ def _shard_map_discharge_rule(
rewrite,
):
del out_avals, auto, in_names, out_names, check_rep, rewrite
if not isinstance(mesh, pl_core.PallasMesh):
if not isinstance(mesh, pallas_core.PallasMesh):
raise NotImplementedError("Mesh must be a PallasMesh")
if len(mesh.shape) > 1:
raise NotImplementedError("Mesh must be 1D")
@@ -2867,9 +2873,9 @@ def body(*args):
out = pallas_call.pallas_call(
body,
out_shape=in_avals,
in_specs=[pl_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)]
in_specs=[pallas_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)]
* len(in_avals),
out_specs=[pl_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)]
out_specs=[pallas_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)]
* len(in_avals),
input_output_aliases={i: i for i in range(len(in_avals))},
grid=((core_axis_name, num_cores),),
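
The `sym_name` attribute set in `lower_jaxpr_to_module` above is what names the MLIR debug dumps. A standalone sketch of the same pattern with the MLIR Python bindings bundled in jaxlib (`"my_copy"` is a placeholder; the printed module body is elided):

```python
from jax._src.lib.mlir import ir

with ir.Context(), ir.Location.unknown():
  m = ir.Module.create()
  # Name the module so dump files pick up a meaningful name.
  m.operation.attributes["sym_name"] = ir.StringAttr.get("my_copy")
  print(m)  # module @my_copy { ... }
```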