
Commit 00b1d67

danielsuo authored and Flax Authors committed

[flax:examples:mnist] Add get_fake_batch and get_apply_fn_and_args methods to train.py.

Preparing Flax examples for programmatic testing and benchmarking.

PiperOrigin-RevId: 815862166

1 parent: 771eadb

File tree: 2 files changed (+35, -3 lines)

examples/mnist/main.py
Lines changed: 1 addition & 2 deletions

@@ -22,12 +22,11 @@
 from absl import flags
 from absl import logging
 from clu import platform
+import train
 import jax
 from ml_collections import config_flags
 import tensorflow as tf
 
-import train
-
 
 FLAGS = flags.FLAGS
 
examples/mnist/train.py
Lines changed: 34 additions & 1 deletion

@@ -20,6 +20,7 @@
 
 # See issue #620.
 # pytype: disable=wrong-keyword-args
+from typing import Any
 
 from absl import logging
 from flax import linen as nn
@@ -51,6 +52,38 @@ def __call__(self, x):
     return x
 
 
+def get_fake_batch(batch_size: int) -> Any:
+  """Returns fake data for the given batch size.
+
+  Args:
+    batch_size: The global batch size to generate.
+
+  Returns:
+    A properly sharded global batch of data.
+  """
+  rng = jax.random.PRNGKey(0)
+  images = jax.random.randint(rng, (batch_size, 28, 28, 1), 0, 255, jnp.uint8)
+  labels = jax.random.randint(rng, (batch_size,), 0, 10, jnp.int32)
+  return images, labels
+
+
+def get_apply_fn_and_args(
+    config: ml_collections.ConfigDict,
+) -> tuple[Any, tuple[Any, ...], dict[str, Any], tuple[Any, ...]]:
+  """Returns the apply function and args for the given config.
+
+  Args:
+    config: The training configuration.
+
+  Returns:
+    A tuple of the apply function, args and kwargs for the apply function, and
+    any metadata the training loop needs.
+  """
+  state = create_train_state(jax.random.key(0), config)
+  batch = get_fake_batch(config.batch_size)
+  return apply_model, (state, *batch), dict(), ()
+
+
 @jax.jit
 def apply_model(state, images, labels):
   """Computes gradients, loss and accuracy for a single batch."""
@@ -145,7 +178,7 @@ def train_and_evaluate(
         state, test_ds['image'], test_ds['label']
     )
 
-    logging.info(
+    logging.info(  # pytype: disable=logging-not-lazy
        'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,'
        ' test_accuracy: %.2f'
        % (
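
The commit message says these helpers prepare the example for programmatic testing and benchmarking. As a rough illustration only (this harness is hypothetical and not part of the commit; the config import assumes examples/mnist/configs/default.py is importable from the example directory), a benchmark driver might consume get_apply_fn_and_args like this:

# Hypothetical benchmark driver, not part of this commit.
import time

import jax

import train
from configs import default  # assumed: examples/mnist/configs/default.py

config = default.get_config()

# The helper returns the jitted step function plus prebuilt args/kwargs
# wrapping a fake batch of config.batch_size examples, so no real dataset
# needs to be downloaded.
apply_fn, args, kwargs, _metadata = train.get_apply_fn_and_args(config)

# One warm-up call so JIT compilation is excluded from the timing.
jax.block_until_ready(apply_fn(*args, **kwargs))

start = time.perf_counter()
for _ in range(100):
  out = apply_fn(*args, **kwargs)
jax.block_until_ready(out)
print(f'mean step: {(time.perf_counter() - start) / 100 * 1e3:.3f} ms')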
