Skip to content

Commit 5d41f6a

Browse files
casaro authored and Flax Authors committed
Sow top activations based on absolute value.
PiperOrigin-RevId: 743453927
1 parent 68b69f7 commit 5d41f6a

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

examples/gemma/sow_lib.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import dataclasses
1818
from flax import nnx
1919
import jax
20+
import jax.numpy as jnp
2021

2122

2223
@jax.tree_util.register_dataclass
@@ -55,10 +56,10 @@ def merge(self, decoding_step, layer: nnx.Module):
5556
]
5657
else:
5758
step_value = getattr(layer, field.name).value[0]
58-
except AttributeError:
59+
except AttributeError as exc:
5960
raise ValueError(
6061
f'Intermediate {field.name} is not in the step intermediates.'
61-
)
62+
) from exc
6263
# This logic is the same for all intermediates. The second dimension is
6364
# the length dimension, where we want to merge the intermediates from
6465
# multiple steps.
@@ -94,8 +95,10 @@ def merge(self, decoding_step, transformer: nnx.Module):
9495
self.embeddings = self.embeddings.at[:, decoding_step + 1, ...].set(
9596
transformer.embeddings.value[0][:, 0, ...]
9697
)
97-
except AttributeError:
98-
raise ValueError('Embeddings are not in the step intermediates.')
98+
except AttributeError as exc:
99+
raise ValueError(
100+
'Embeddings are not in the step intermediates.'
101+
) from exc
99102
if len(self.layers) != len(transformer.layers):
100103
raise ValueError(
101104
'Number of layers in the transformer and intermediates do not match.'
@@ -167,9 +170,10 @@ def maybe_sow_mlp_hidden_topk(
167170
activations: jax.Array,
168171
module: nnx.Module,
169172
):
170-
"""Sows top-k activations in a mlp hidden layer if configured."""
173+
"""Sows top-absolute-k activations in a mlp hidden layer if configured."""
171174
if self.mlp_hidden_topk:
172-
values, indices = jax.lax.top_k(activations, self.mlp_hidden_topk)
175+
_, indices = jax.lax.top_k(jnp.abs(activations), self.mlp_hidden_topk)
176+
values = jnp.take_along_axis(activations, indices, axis=-1)
173177
module.sow(nnx.Intermediate, 'hidden_topk_values', values)
174178
module.sow(nnx.Intermediate, 'hidden_topk_indices', indices)
175179

0 commit comments

Comments
 (0)