reduce log #1022
Conversation
Force-pushed 7f1931b to 8e7d230
/gemini review
Code Review
This pull request aims to reduce logging noise during training. It does so by logging the training mesh information only once and logging the JIT compilation cache size only when it changes to a previously unseen value. These are good changes for improving log readability.
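For illustration, a minimal sketch of this log-deduplication pattern (the class and attribute names other than _jit_cache are illustrative assumptions, not the trainer's actual code):

```python
import logging


class Trainer:
    """Illustrative trainer fragment showing the deduplicated logging."""

    def __init__(self):
        self._mesh_logged = False  # the training mesh is logged only once
        self._jit_cache = set()    # compilation cache sizes already reported

    def _maybe_log_compile_info(self, mesh, cache_size):
        # Log the training mesh only on the first call.
        if not self._mesh_logged:
            logging.info("Training with mesh: %s", mesh)
            self._mesh_logged = True
        # Log the compiled train_step cache size only when it changes to a
        # value that has not been seen before.
        if cache_size not in self._jit_cache:
            self._jit_cache.add(cache_size)
            logging.info("Compiled train_step cache size: %s", cache_size)
```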
My review includes two points:
- A potential issue with the lifecycle of the newly introduced _jit_cache, which could lead to missed log messages about recompilations when the trainer is reused.
- A minor correction to a pytype: disable comment.
self._buffered_eval_metrics: MetricsBuffer | None = None
self.training_hooks = None
self.data_hooks = None
self._jit_cache = set()
This cache is used to track JIT compilation cache sizes to avoid verbose logging. However, it appears this set is never cleared during the lifetime of the trainer instance. If the trainer is reused in a way that causes recompilations (e.g., by calling with_loss_fn or with_gen_model_input_fn), this cache will not be reset. This could lead to misleading logs, as new compilations might not be reported if their cache size has been seen before.
Consider clearing this set in the clear_jit_cache method, alongside resetting _jitted_train_step_fn and _jitted_eval_step_fn.
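A minimal sketch of the suggested fix, assuming clear_jit_cache currently only resets the two jitted step functions (the surrounding trainer class is omitted):

```python
def clear_jit_cache(self):
    """Resets cached compilations so later recompilations are logged again."""
    # Existing behaviour (per the review): drop the jitted step functions.
    self._jitted_train_step_fn = None
    self._jitted_eval_step_fn = None
    # Suggested addition: forget previously seen cache sizes so that new
    # compilations after the reset are reported rather than silently skipped.
    self._jit_cache.clear()
```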
| "Training with mesh: %s. Compiled train_step cache size: %s", | ||
| pxla.thread_resources.env.physical_mesh, | ||
| train_step.jitted_fn._cache_size(), # pytype: disable=attribute-error,protected-access | ||
| cache_size = train_step.jitted_fn._cache_size() # pytype: disable=attribute-error |
The pytype: disable comment is missing protected-access. Accessing _cache_size is an access to a protected member, so that suppression should be restored for correct static analysis.
Suggested change:
- cache_size = train_step.jitted_fn._cache_size()  # pytype: disable=attribute-error
+ cache_size = train_step.jitted_fn._cache_size()  # pytype: disable=attribute-error,protected-access
Force-pushed 8e7d230 to 8cc7bfb
Force-pushed 8cc7bfb to f925cdb
PiperOrigin-RevId: 863438904
Force-pushed f925cdb to f237d7c
reduce log