Conversation

@copybara-service

reduce log

@abheesht17
Collaborator

/gemini review

@gemini-code-assist (bot) left a comment

Code Review

This pull request aims to reduce logging noise during training. It achieves this by ensuring the training mesh information is logged only once, and the JIT compilation cache size is logged only when it changes to a previously unseen value. These are good changes for improving log readability.

My review includes two points:

  1. A potential issue with the lifecycle of the newly introduced _jit_cache, which could lead to missed log messages about recompilations when the trainer is reused.
  2. A minor correction to a pytype: disable comment.

self._buffered_eval_metrics: MetricsBuffer | None = None
self.training_hooks = None
self.data_hooks = None
self._jit_cache = set()


medium

This cache is used to track JIT compilation cache sizes to avoid verbose logging. However, it appears this set is never cleared during the lifetime of the trainer instance. If the trainer is reused in a way that causes recompilations (e.g., by calling with_loss_fn or with_gen_model_input_fn), this cache will not be reset. This could lead to misleading logs, as new compilations might not be reported if their cache size has been seen before.

Consider clearing this set in the clear_jit_cache method, alongside resetting _jitted_train_step_fn and _jitted_eval_step_fn.
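The deduplicated-logging pattern the review describes can be sketched as below. This is an illustrative reconstruction, not the PR's actual code: the trainer skeleton and `_maybe_log_cache_size` helper are hypothetical; only the names `_jit_cache`, `_jitted_train_step_fn`, and `clear_jit_cache` come from the diff and comment.

```python
import logging

logger = logging.getLogger(__name__)


class Trainer:
    """Minimal sketch of the log-once-per-unseen-cache-size pattern."""

    def __init__(self):
        self._jitted_train_step_fn = None
        # Cache sizes that have already been logged.
        self._jit_cache: set[int] = set()

    def _maybe_log_cache_size(self, cache_size: int) -> None:
        # Log only when the JIT cache reaches a size we have not seen
        # before, i.e. a new compilation happened.
        if cache_size not in self._jit_cache:
            self._jit_cache.add(cache_size)
            logger.info("Compiled train_step cache size: %s", cache_size)

    def clear_jit_cache(self) -> None:
        # Reset the jitted function *and* the logging cache, so that
        # recompilations after trainer reuse are reported again -- the
        # fix this review comment suggests.
        self._jitted_train_step_fn = None
        self._jit_cache.clear()
```

Without the `self._jit_cache.clear()` call, a trainer that is reset and recompiled would silently skip log lines for any cache size it had seen before.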

"Training with mesh: %s. Compiled train_step cache size: %s",
pxla.thread_resources.env.physical_mesh,
train_step.jitted_fn._cache_size(), # pytype: disable=attribute-error,protected-access
cache_size = train_step.jitted_fn._cache_size() # pytype: disable=attribute-error


medium

The pytype: disable comment is missing protected-access. Accessing _cache_size touches a protected member, so the suppression should cover protected-access as well; adding it back keeps the static-analysis suppression accurate.

Suggested change
cache_size = train_step.jitted_fn._cache_size() # pytype: disable=attribute-error
cache_size = train_step.jitted_fn._cache_size() # pytype: disable=attribute-error,protected-access

PiperOrigin-RevId: 863438904