|
17 | 17 | import dataclasses
|
18 | 18 | from flax import nnx
|
19 | 19 | import jax
|
| 20 | +import jax.numpy as jnp |
20 | 21 |
|
21 | 22 |
|
22 | 23 | @jax.tree_util.register_dataclass
|
@@ -55,10 +56,10 @@ def merge(self, decoding_step, layer: nnx.Module):
|
55 | 56 | ]
|
56 | 57 | else:
|
57 | 58 | step_value = getattr(layer, field.name).value[0]
|
58 |
| - except AttributeError: |
| 59 | + except AttributeError as exc: |
59 | 60 | raise ValueError(
|
60 | 61 | f'Intermediate {field.name} is not in the step intermediates.'
|
61 |
| - ) |
| 62 | + ) from exc |
62 | 63 | # This logic is the same for all intermediates. The second dimension is
|
63 | 64 | # the length dimension, where we want to merge the intermediates from
|
64 | 65 | # multiple steps.
|
@@ -94,8 +95,10 @@ def merge(self, decoding_step, transformer: nnx.Module):
|
94 | 95 | self.embeddings = self.embeddings.at[:, decoding_step + 1, ...].set(
|
95 | 96 | transformer.embeddings.value[0][:, 0, ...]
|
96 | 97 | )
|
97 |
| - except AttributeError: |
98 |
| - raise ValueError('Embeddings are not in the step intermediates.') |
| 98 | + except AttributeError as exc: |
| 99 | + raise ValueError( |
| 100 | + 'Embeddings are not in the step intermediates.' |
| 101 | + ) from exc |
99 | 102 | if len(self.layers) != len(transformer.layers):
|
100 | 103 | raise ValueError(
|
101 | 104 | 'Number of layers in the transformer and intermediates do not match.'
|
@@ -167,9 +170,10 @@ def maybe_sow_mlp_hidden_topk(
|
167 | 170 | activations: jax.Array,
|
168 | 171 | module: nnx.Module,
|
169 | 172 | ):
|
170 |
| - """Sows top-k activations in a mlp hidden layer if configured.""" |
| 173 | + """Sows top-absolute-k activations in a mlp hidden layer if configured.""" |
171 | 174 | if self.mlp_hidden_topk:
|
172 |
| - values, indices = jax.lax.top_k(activations, self.mlp_hidden_topk) |
| 175 | + _, indices = jax.lax.top_k(jnp.abs(activations), self.mlp_hidden_topk) |
| 176 | + values = jnp.take_along_axis(activations, indices, axis=-1) |
173 | 177 | module.sow(nnx.Intermediate, 'hidden_topk_values', values)
|
174 | 178 | module.sow(nnx.Intermediate, 'hidden_topk_indices', indices)
|
175 | 179 |
|
|
0 commit comments