Skip to content

Commit dc06292

Browse files
author
Flax Authors
committed
Merge pull request #4639 from google:flaxlib-types
PiperOrigin-RevId: 750628750
2 parents d3ec212 + 334f383 commit dc06292

File tree

9 files changed

+492
-239
lines changed

9 files changed

+492
-239
lines changed

benchmarks/nnx_simple_training.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,11 @@ def test_step_nnx(model: MLP, batch):
116116
loss = jnp.mean((y - y_pred) ** 2)
117117
return {'loss': loss}
118118

119-
cached_train_step_nnx = nnx.cached_partial(train_step_nnx, model, optimizer)
120-
cached_test_step_nnx = nnx.cached_partial(test_step_nnx, model)
121-
122119
for step, batch in enumerate(dataset(X, Y, batch_size)):
123-
cached_train_step_nnx(batch)
120+
train_step_nnx(model, optimizer, batch)
124121

125122
if step % 1000 == 0:
126-
logs = cached_test_step_nnx((X, Y))
123+
logs = test_step_nnx(model, (X, Y))
127124

128125
if step >= total_steps - 1:
129126
break

0 commit comments

Comments
 (0)