1 change: 1 addition & 0 deletions flax/nnx/rnglib.py
@@ -57,6 +57,7 @@ def _to_keyless(


def _function_to_method(random_f):
@functools.wraps(random_f)
def rngs_random_method(self: Rngs | RngStream, *args, **kwargs) -> jax.Array:
return random_f(self(), *args, **kwargs)

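This decorator matters downstream: summary.py below labels each call with f.__name__, so without functools.wraps every RngStream sampler would otherwise show up under the generic wrapper name. A minimal standalone sketch of the effect (uniform here is just a stand-in, not the library code):

import functools

def to_method(random_f):
  @functools.wraps(random_f)  # copies __name__/__doc__ from random_f onto the wrapper
  def method(self, *args, **kwargs):
    return random_f(*args, **kwargs)
  return method

def uniform(shape):
  "Stand-in for a jax.random sampler."
  return shape

assert to_method(uniform).__name__ == 'uniform'  # would be 'method' without wraps
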
149 changes: 75 additions & 74 deletions flax/nnx/summary.py
@@ -19,6 +19,8 @@
import io
import typing as tp
from types import MappingProxyType
import functools
import itertools

import jax
import numpy as np
@@ -34,6 +36,7 @@

from functools import wraps


try:
from IPython import get_ipython

@@ -130,14 +133,14 @@ def __str__(self):

@dataclasses.dataclass
class CallInfo:
call_order: int
object_id: int
type: type
path: statelib.PathParts
input_args: tuple[tp.Any, ...]
input_kwargs: dict[str, tp.Any]
inputs_repr: str
outputs: tp.Any
flops: int | None = None
vjp_flops: int | None = None
flops: int | None
vjp_flops: int | None

class SimpleObjectRepr:
def __init__(self, obj: tp.Any):
@@ -168,32 +171,6 @@ def inner(state, *args, **kwargs):
return f(model, *args, **kwargs)
return jax.vjp(inner, state, *args, **kwargs)

def _get_call_info(jitted, method_name, node_stats, obj, compute_flops: bool, *args, **kwargs):
e = jitted.lower(obj, *args, **kwargs)
flops = _get_flops(e) if compute_flops else None
outputs = e.lowered.out_info[2]
output_repr = jax.tree.map(_to_dummy_array, outputs)
input_args_info, input_kwargs_info = jax.tree.map(
_to_dummy_array, (args, kwargs)
)
object_id: int = getattr(obj, '_nnx_tabulate_id')
node_info = node_stats[object_id]
assert node_info is not None
path = node_info.path
if method_name != '__call__':
path = (*path, method_name)

return CallInfo(
object_id=object_id,
type=type(obj),
path=path,
input_args=input_args_info,
input_kwargs=input_kwargs_info,
outputs=output_repr,
flops=flops,
)


def filter_rng_streams(row: CallInfo):
return not issubclass(row.type, nnx.RngStream)

@@ -206,13 +183,64 @@ def _create_obj_env(object_types):
result[(obj_type, name)] = top_method
return result

def _argsave(tracer_args, f):
def _get_inputs_repr(args, kwargs):
input_args, input_kwargs = jax.tree.map(
_to_dummy_array, (args, kwargs)
)
inputs_repr = ''
if input_args:
if len(input_args) == 1 and not input_kwargs:
inputs_repr += _as_yaml_str(input_args[0])
else:
inputs_repr += _as_yaml_str(input_args)
if input_kwargs:
inputs_repr += '\n'
if input_kwargs:
inputs_repr += _as_yaml_str(input_kwargs)
return inputs_repr

def _save_call_info(counter, tracer_args, f, node_stats, compute_flops, compute_vjp_flops, seen):
"Wrap a function to save its arguments"
n = f.__name__

# Used when computing vjp flops
def do_vjp(*args, **kwargs):
primals, f_vjp = jax.vjp(f, *args, **kwargs)
return f_vjp(primals)

method_name = f.__name__

@functools.partial(jax.jit)
def jit_f(graphdef, state):
args, kwargs = nnx.merge(graphdef, state)
return f(*args, **kwargs)

@wraps(f)
def wrapper(obj, *args, **kwargs):
tracer_args.append((obj, n, args, kwargs))
return f(obj, *args, **kwargs)
inputs_repr = _get_inputs_repr(args, kwargs)
object_id = getattr(obj, '_nnx_tabulate_id')
node_info = node_stats[object_id]
path = node_info.path
if method_name != '__call__':
path = (*path, method_name)
identifier = (inputs_repr, object_id)
counter_val = next(counter)
graphdef, state = nnx.split(((obj, *args), kwargs))
if compute_flops:
lowered = jit_f.lower(graphdef, state)
flops = _get_flops(lowered)
outputs = lowered.out_info
else:
flops = None
outputs = jit_f(graphdef, state)
if identifier not in seen:
seen.add(identifier)
output_repr = jax.tree.map(_to_dummy_array, outputs)
vjp_flops = _get_flops(jax.jit(do_vjp).lower(
obj, *args, **kwargs)) if compute_vjp_flops else None
tracer_args.append(
CallInfo(counter_val, object_id, type(obj), path, inputs_repr,
output_repr, flops, vjp_flops))
return jit_f(graphdef, state)
return wrapper

def _overwrite_methods(env):
@@ -367,39 +395,22 @@ def tabulate(
# iteration over methods easier.
env = _create_obj_env(object_types)

# Modify all the object's methods to save their Tracer arguments.
# tracer_args contains (object, name, args, kwargs) tuples.
tracer_args: list[tuple[tp.Any, str, tuple, dict[str, tp.Any]]] = []
saver_env = {k: _argsave(tracer_args, v) for k,v in env.items()}
_overwrite_methods(saver_env)

# Add JIT calculation to each method. We can extract flops and output info from
# the lowered JITs. We'll only call these jitted values, which guarantees
# that each method will only be traced (and added to the table) once.
jits = {} # Maps (class, method_name) to jit
for key, value in saver_env.items():
jits[key] = nnx.jit(value)
# Information is recorded in post-order, but should be presented as a pre-order traversal.
# This keeps track of the order of calls.
counter = itertools.count(0)

# Modify all the object's methods to save their lowered JIT representations.
rows : list[CallInfo] = []
seen : set = set()
jits = {k: _save_call_info(counter, rows, v, node_stats, compute_flops, compute_vjp_flops, seen)
for k, v in env.items()}
_overwrite_methods(jits)

# Trace the top function (which indirectly traces all the others)
jits[(type(obj), method)].trace(obj, *input_args, **input_kwargs)
jits[(type(obj), method)](obj, *input_args, **input_kwargs)

# Get call_info
rows : list[CallInfo] = [_get_call_info(
jits[(type(object), name)], name, node_stats, object,
compute_flops, *args, **kwargs)
for (object, name, args, kwargs) in tracer_args]

# Add VJP flops if required. This needs to be done separately because calls to `_pure_nnx_vjp`
# can result in tracing the jitted functions a second time if there's shared structure.
# This would add items to `tracer_args`, resulting in duplicate rows in the table.
if compute_vjp_flops:
for i, row in enumerate(rows):
object, method_name, args, kwargs = tracer_args[i]
def do_vjp(*args, **kwargs):
primals, f_vjp = _pure_nnx_vjp(jits[(type(object), method_name)], *args, **kwargs)
return f_vjp(primals)
row.vjp_flops = _get_flops(jax.jit(do_vjp).lower(object, *args, **kwargs))
# Sort call info in pre-order traversal order
rows.sort(key=lambda x: x.call_order)

# Restore the object's original methods
_overwrite_methods(env)
@@ -436,17 +447,7 @@ def do_vjp(*args, **kwargs):
path_str = '/'.join(map(str, row.path))
col_reprs.append(path_str)
col_reprs.append(row.type.__name__)
inputs_repr = ''
if row.input_args:
input_args = row.input_args
if len(row.input_args) == 1 and not row.input_kwargs:
input_args = row.input_args[0]
inputs_repr += _as_yaml_str(input_args)
if inputs_repr and row.input_kwargs:
inputs_repr += '\n'
if row.input_kwargs:
inputs_repr += _as_yaml_str(row.input_kwargs)
col_reprs.append(inputs_repr)
col_reprs.append(row.inputs_repr)
col_reprs.append(_as_yaml_str(row.outputs))
if compute_flops:
col_reprs.append(str(row.flops))
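The wrapper above numbers every call with the shared itertools counter but only records a row the first time a given (inputs_repr, object_id) pair is seen, and tabulate later sorts the rows by call_order. Reduced to plain Python, the bookkeeping looks roughly like this (record and the row strings are illustrative, not the module's API):

import itertools

counter = itertools.count(0)
rows, seen = [], set()

def record(identifier, row):
  order = next(counter)        # every call advances the pre-order counter
  if identifier not in seen:   # only the first call per (inputs, object) adds a row
    seen.add(identifier)
    rows.append((order, row))

record(('x: float32[4,2]', 1), 'Model.__call__')
record(('x: float32[4,2]', 2), 'Block.__call__')
record(('x: float32[4,2]', 2), 'Block.__call__')  # shared module, same inputs: no new row
rows.sort(key=lambda r: r[0])  # present rows in call (pre-order) order
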
43 changes: 41 additions & 2 deletions tests/nnx/summary_test.py
@@ -41,9 +41,10 @@ def __call__(self, x):

foo = Foo(nnx.Rngs(0))
x = jnp.ones((1, 32))
table_repr = nnx.tabulate(
table_repr_ = nnx.tabulate(
foo, x, console_kwargs=CONSOLE_TEST_KWARGS
).splitlines()
)
table_repr = table_repr_.splitlines()

self.assertIn('Foo Summary', table_repr[0])
self.assertIn('path', table_repr[2])
@@ -224,6 +225,32 @@ def __call__(self, x):
# We should see 3 calls per block, plus one overall call
self.assertEqual(sum([s.startswith("├─") for s in table.splitlines()]), 7)

def test_time_complexity(self):
counter = []

class Block(nnx.Module):
def __init__(self, rngs):
self.linear = nnx.Linear(2, 2, rngs=rngs)

def __call__(self, x):
counter.append(1)
return self.linear(x)

class Model(nnx.Module):
def __init__(self, rngs):
for d in range(10):
setattr(self, f"linear{d}", Block(rngs))

def __call__(self, x):
for d in range(10):
x = getattr(self, f"linear{d}")(x)
return x

m = Model(nnx.Rngs(0))
x = jnp.ones((4, 2))
nnx.tabulate(m, x, compute_flops=True, compute_vjp_flops=False)
self.assertEqual(len(counter), 10)

def test_shared(self):
class Block(nnx.Module):
def __init__(self, linear: nnx.Linear, *, rngs):
@@ -295,5 +322,17 @@ def __call__(self, x):
self.assertEqual(module.hooked_param.get_metadata('description'), 'Custom parameter')
self.assertEqual(module.hooked_param.get_metadata('trainable'), True)

def test_tabulate_concrete_shape(self):
class Net(nnx.Module):
def __init__(self):
self.rngs = nnx.Rngs(0)

def __call__(self, x):
return self.rngs.uniform((x.shape[0], 10))

net = Net()
x = jnp.zeros((4, 8))
nnx.tabulate(net, x, console_kwargs={"width": 200})

if __name__ == '__main__':
absltest.main()
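End-to-end, the new path means a tabulate call traces each distinct (module, inputs) pair only once, even when flops are requested. A minimal usage sketch along the lines of test_time_complexity above:

import jax.numpy as jnp
from flax import nnx

class Block(nnx.Module):
  def __init__(self, rngs):
    self.linear = nnx.Linear(2, 2, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)

m = Block(nnx.Rngs(0))
x = jnp.ones((4, 2))
# Returns the rendered table as a string, with flops / vjp_flops columns included.
print(nnx.tabulate(m, x, compute_flops=True, compute_vjp_flops=True))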