Skip to content

Commit

Permalink
Merge pull request jax-ml#26262 from gnecula:debug_info_one
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 722684417
  • Loading branch information
Google-ML-Automation committed Feb 3, 2025
2 parents aa64372 + c70de6d commit 7e35391
Show file tree
Hide file tree
Showing 24 changed files with 161 additions and 178 deletions.
6 changes: 3 additions & 3 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def foo(x, y):
@wraps(fun)
@api_boundary
def fun_remat(*args, **kwargs):
debug = api_util.tracing_debug_info(
debug = api_util.debug_info(
"checkpoint / remat", fun,
args, kwargs, static_argnums=static_argnums)
fun_, args = _remat_static_argnums(fun, static_argnums, args)
Expand Down Expand Up @@ -418,7 +418,7 @@ def new_fun(*dyn_args, **kwargs):
def _trace_to_jaxpr(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug: lu.TracingDebugInfo
debug: core.DebugInfo
) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
try:
Expand Down Expand Up @@ -447,7 +447,7 @@ def f_(*args):
args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs)

debug_info = api_util.tracing_debug_info("saved_residuals", f, args, kwargs)
debug_info = api_util.debug_info("saved_residuals", f, args, kwargs)
out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1],
return_shape=True)(*in_leaves)
assert isinstance(out, tuple)
Expand Down
14 changes: 7 additions & 7 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@
from jax._src import xla_bridge as xb
from jax._src.core import eval_jaxpr, shaped_abstractify, ShapedArray
from jax._src.api_util import (
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
result_paths, flat_out_axes)
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, debug_info,
result_paths, flat_out_axes)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
Expand Down Expand Up @@ -452,7 +452,7 @@ def value_and_grad_f(*args, **kwargs):
raise TypeError(f"differentiating with respect to {argnums=} requires at least "
f"{max_argnum + 1} positional arguments to be passed by the caller, "
f"but got only {len(args)} positional arguments.")
dbg = tracing_debug_info('value_and_grad', fun, args, kwargs)
dbg = debug_info('value_and_grad', fun, args, kwargs)

f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
f_partial, dyn_args = argnums_partial(f, argnums, args,
Expand Down Expand Up @@ -1426,7 +1426,7 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")

dbg = tracing_debug_info(
dbg = debug_info(
"pmap", fun, args, kwargs,
static_argnums=static_broadcasted_tuple)

Expand Down
24 changes: 11 additions & 13 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
prefix_errors)
from jax._src.tree_util import _replace_nones
from jax._src import linear_util as lu
from jax._src.linear_util import TracingDebugInfo
from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
Unhashable, safe_zip)
from jax._src import traceback_util
Expand Down Expand Up @@ -582,7 +581,7 @@ def api_hook(fun, tag: str):
return fun


def tracing_debug_info(
def debug_info(
traced_for: str,
fun: Callable,
args: Sequence[Any],
Expand All @@ -594,14 +593,14 @@ def tracing_debug_info(
# TODO(necula): check if we really need this, e.g., to speed up tracing.
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> TracingDebugInfo:
) -> core.DebugInfo:
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)
return core.DebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)


def fun_signature(fun: Callable) -> inspect.Signature | None:
Expand All @@ -619,7 +618,7 @@ def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable):

# TODO(mattjj): make this function internal to this module
def fun_sourceinfo(fun: Callable) -> str:
# See TracingDebugInfo.fun_src_info
# See DebugInfo.fun_src_info
res = getattr(fun, "__fun_sourceinfo__", None)
if res is not None: return res
while isinstance(fun, partial):
Expand Down Expand Up @@ -684,20 +683,19 @@ def result_paths(_fun, _store, *args, **kwargs):

# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
trace_debug: TracingDebugInfo | None,
debug: core.DebugInfo | None,
result_paths: tuple[str, ...] | None = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
if trace_debug is None:
if debug is None:
return jaxpr
# TODO(necula): re-enable this safety check
# assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
if result_paths is None:
result_paths = trace_debug.result_paths_thunk() # type: ignore
debug_info = core.JaxprDebugInfo(
trace_debug.traced_for, trace_debug.func_src_info,
trace_debug.arg_names, tuple(result_paths)) # type: ignore
return jaxpr.replace(debug_info=debug_info)
if result_paths is not None:
debug = debug._replace(result_paths=tuple(result_paths))
else:
debug = debug.resolve_result_paths()
return jaxpr.replace(debug_info=debug)

def hoist_obj_attrs(f, flat_args):
idxs, objs, flat_args_ = [], [], []
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,7 @@ def checked_fun(*args, **kwargs):
in_tree = jtu.tree_structure(((), {}))
closed_f = lambda: f(*args, **kwargs)
# stage:
debug = api_util.tracing_debug_info("checkify", f, args, kwargs)
debug = api_util.debug_info("checkify", f, args, kwargs)
fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
debug_info=debug),
in_tree)
Expand Down
36 changes: 6 additions & 30 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,31 +82,7 @@
no_effects: Effects = effects.no_effects


# TODO(necula): make this an extension of TracingDebugInfo
class JaxprDebugInfo(NamedTuple):
# An extension of lu.TracingDebugInfo; see comments there
traced_for: str
func_src_info: str
arg_names: tuple[str | None, ...]
# This is formed after tracing, when we have concrete `result_paths`
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)

def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
"""Get the arg_names with a safety check."""
if len(self.arg_names) == expected:
return self.arg_names
else:
# TODO(necula): this should not happen
return (None,) * expected

def safe_result_paths(self, expected: int) -> tuple[str | None, ...]:
"""Get the result_paths with a safety check."""
if len(self.result_paths) == expected:
return self.result_paths
else:
# TODO(necula): this should not happen
return ("",) * expected

DebugInfo = lu.DebugInfo

class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
Expand All @@ -117,7 +93,7 @@ class Jaxpr:
_outvars: list[Atom]
_eqns: list[JaxprEqn]
_effects: Effects
_debug_info: JaxprDebugInfo | None
_debug_info: DebugInfo | None

@property
def constvars(self) -> list[Var]:
Expand All @@ -140,13 +116,13 @@ def effects(self) -> Effects:
return self._effects

@property
def debug_info(self) -> JaxprDebugInfo | None:
def debug_info(self) -> DebugInfo | None:
return self._debug_info

def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
effects: Effects = no_effects,
debug_info: JaxprDebugInfo | None = None):
debug_info: DebugInfo | None = None):
"""
Args:
constvars: list of variables introduced for constants. Array constants are
Expand All @@ -157,14 +133,14 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
eqns: list of equations.
effects: set of effects. The effects on a jaxpr are a superset of the
union of the effects for each equation.
debug_info: optional JaxprDebugInfo.
debug_info: optional DebugInfo.
"""
self._constvars = list(constvars)
self._invars = list(invars)
self._outvars = list(outvars)
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info
self._debug_info = debug_info and debug_info.resolve_result_paths()
# TODO(necula): re-enable these safety checks
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/custom_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __call__(self, *args, **kwargs):
raise AttributeError(
f"No batching rule defined for custom_vmap function {fun_name} "
"using def_vmap.")
debug = api_util.tracing_debug_info("custom_vmap", self.fun, args, {})
debug = api_util.debug_info("custom_vmap", self.fun, args, {})
args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(
lu.wrap_init(self.fun, debug_info=debug),
Expand Down
12 changes: 6 additions & 6 deletions jax/_src/custom_dce.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ def __call__(self, *args, **kwargs):
"def_dce."
)
rule_name = util.fun_name(self.dce_rule)
debug = api_util.tracing_debug_info("custom_dce", self.fun,
args, {},
static_argnums=self.static_argnums)
debug_rule = api_util.tracing_debug_info("custom_dce_rule", self.dce_rule,
args, {},
static_argnums=self.static_argnums)
debug = api_util.debug_info("custom_dce", self.fun,
args, {},
static_argnums=self.static_argnums)
debug_rule = api_util.debug_info("custom_dce_rule", self.dce_rule,
args, {},
static_argnums=self.static_argnums)
args = api_util.resolve_kwargs(self.fun, args, kwargs)
if self.static_argnums:
static_argnums = set(self.static_argnums)
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/custom_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,9 +468,9 @@ def def_partition(self, partition, infer_sharding_from_operands=None,

def __call__(self, *args, **kwargs):
args = _resolve_kwargs(self.fun, args, kwargs)
debug = api_util.tracing_debug_info("custom_partitioning", self.fun,
args, kwargs,
static_argnums=self.static_argnums)
debug = api_util.debug_info("custom_partitioning", self.fun,
args, kwargs,
static_argnums=self.static_argnums)
if self.static_argnums:
static_argnums = set(self.static_argnums)
args = tuple(x if i in static_argnums else x for i, x in enumerate(args))
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _linearize_jaxpr(
jaxpr: core.ClosedJaxpr,
nonzeros: tuple[bool, ...]
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]:
dbg = lu.TracingDebugInfo.from_jaxpr(jaxpr)
dbg = jaxpr.jaxpr.debug_info
primal_trace = pe.DynamicJaxprTrace(dbg)
tangent_trace = pe.DynamicJaxprTrace(dbg)
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
Expand Down
34 changes: 17 additions & 17 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, JaxprEqnContext)
from jax._src.state.types import AbstractRef
from jax._src import tree_util
from jax._src.tree_util import (PyTreeDef, treedef_tuple,
tree_flatten, tree_structure)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
Expand Down Expand Up @@ -932,7 +931,7 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr,
in_unknowns: Sequence[bool],
instantiate: bool | Sequence[bool]):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
debug_info=lu.TracingDebugInfo.from_jaxpr(jaxpr))
debug_info=jaxpr.jaxpr.debug_info)

cell = []
def fun(*known_vals_in):
Expand Down Expand Up @@ -1334,10 +1333,11 @@ def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:

def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
dbg = jaxpr.debug_info and core.JaxprDebugInfo(
dbg = jaxpr.debug_info and core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
jaxpr.debug_info.arg_names,
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)),
used_outputs) if b))
new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg)
config.enable_checks.value and core.check_jaxpr(new_jaxpr)
return new_jaxpr
Expand Down Expand Up @@ -1422,10 +1422,12 @@ def write(x: Atom, b: bool) -> None:
eqns = new_eqns[::-1]
jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)

dbg = jaxpr.debug_info and core.JaxprDebugInfo(
dbg = jaxpr.debug_info and core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b),
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
tuple(v for v, b in zip(jaxpr.debug_info.safe_arg_names(len(used_inputs)),
used_inputs) if b),
tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)),
used_outputs) if b))
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
config.enable_checks.value and core.check_jaxpr(new_jaxpr)

Expand Down Expand Up @@ -1623,9 +1625,9 @@ class JaxprStackFrame:
attrs_tracked: list[tuple[Any, str]]
attrs_inits: list
attrs_vars: list[Var]
debug_info: lu.TracingDebugInfo | None
debug_info: core.DebugInfo | None

def __init__(self, debug_info: lu.TracingDebugInfo | None):
def __init__(self, debug_info: core.DebugInfo | None):
self.gensym = core.gensym()
self.tracer_to_var = {}
self.constid_to_tracer = {}
Expand Down Expand Up @@ -1809,7 +1811,7 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
class DynamicJaxprTrace(core.Trace):
__slots__ = ("frame",)

def __init__(self, debug_info: lu.TracingDebugInfo | None):
def __init__(self, debug_info: core.DebugInfo | None):
self.frame = JaxprStackFrame(debug_info)

def invalidate(self):
Expand Down Expand Up @@ -2114,7 +2116,7 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
def trace_to_jaxpr_dynamic(
fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: lu.TracingDebugInfo | None = None,
debug_info: core.DebugInfo | None = None,
*,
keep_inputs: list[bool] | None = None,
) -> tuple[Jaxpr, list[AbstractValue], list[Any],
Expand All @@ -2137,7 +2139,7 @@ def trace_to_jaxpr_dynamic(
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked

def _check_no_returned_refs(
dbg: lu.TracingDebugInfo | None,
dbg: core.DebugInfo | None,
out_tracers: Sequence[DynamicJaxprTracer]
) -> None:
if not config.mutable_array_checks.value: return
Expand All @@ -2148,10 +2150,8 @@ def _check_no_returned_refs(
raise ValueError(
f"function returned a mutable array reference of type {a.str_short()}, "
"but mutable array references cannot be returned.")
loc = (f' at output tree path {tree_util.keystr(ls[i])}' # type: ignore
if (dbg.result_paths_thunk and
(ls := dbg.result_paths_thunk()) and
ls[i]) else '')
result_paths = dbg.resolve_result_paths().safe_result_paths(len(out_tracers))
loc = f' at output tree path {result_paths[i]}'
frame = t._trace.frame
v = frame.tracer_to_var.get(id(t))
eqn = next((e for e in frame.eqns if v in e.outvars), None)
Expand All @@ -2172,7 +2172,7 @@ def _check_no_returned_refs(

@profiler.annotate_function
def trace_to_jaxpr_dynamic2(
fun: lu.WrappedFun, debug_info: lu.TracingDebugInfo | None = None
fun: lu.WrappedFun, debug_info: core.DebugInfo | None = None
) -> tuple[Jaxpr, OutputType, list[Any]]:

trace = DynamicJaxprTrace(debug_info)
Expand Down
Loading

0 comments on commit 7e35391

Please sign in to comment.