Cache AoT compilation result #927
Conversation
axlearn/common/trainer.py
Outdated
# excessive recompilations. Note: this could introduce overhead to training due to
# pre-compilation checks (such as sharding checks) that increase the step time for some
# models. Note that this cache is always disabled at steps when xsc is enabled.
disable_python_train_step_cache: Optional[bool] = None
Let's avoid negative boolean fields, as enable_python_train_step_cache=True is more readable than disable_python_train_step_cache=False.
Suggested change:
- disable_python_train_step_cache: Optional[bool] = None
+ enable_python_train_step_cache: Optional[bool] = None
Done.
axlearn/common/trainer.py
Outdated
self._compiled_train_step = compiled_jit_train_step_fn
return self._compiled_train_step
Should we leave self._compiled_train_step unchanged when with_xsc is true?
Suggested change:
- self._compiled_train_step = compiled_jit_train_step_fn
- return self._compiled_train_step
+ return compiled_jit_train_step_fn
Good point, yes.
axlearn/common/trainer.py
Outdated
if cfg.cache_python_train_step is True:
    raise ValueError("cache_python_train_step cannot be True when xsc is enabled.")
Actually cache_python_train_step is still useful even with XSC, since the non-XSC steps will be cached and only the XSC step needs to be recompiled. The XSC step does not run often, so maybe this is OK?
Agreed.
axlearn/common/trainer.py
Outdated
# pre-compilation checks (such as sharding checks) that increase the step time for some
# models. Note that this cache is always disabled at steps when xsc is enabled.
# Defaults to None, which is interpreted as True.
cache_python_train_step: Optional[bool] = None
nit -- suggested change:
- cache_python_train_step: Optional[bool] = None
+ cache_compiled_train_step: Optional[bool] = None
Done.
This improves the step time of some models on v6e by 2x. See the comments of cache_compiled_train_step.
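For context, the mechanism being cached is JAX's ahead-of-time (AoT) compilation path. A minimal sketch, using a toy function rather than a real train step: lowering traces the function for concrete argument shapes, and compile() yields a reusable executable, so repeated calls skip tracing and compilation entirely.

```python
# Minimal AoT compilation sketch with JAX's public lower()/compile() API.
# train_step here is a toy stand-in for a real training step function.
import jax
import jax.numpy as jnp


def train_step(params, x):
    # Toy "train step": a single scalar computation.
    return params * x + 1.0


# Lower for concrete input avals, then compile to a reusable executable.
lowered = jax.jit(train_step).lower(jnp.float32(2.0), jnp.float32(3.0))
compiled = lowered.compile()  # subsequent calls reuse this executable
out = compiled(jnp.float32(2.0), jnp.float32(3.0))
```

Caching the object returned by compile() (as this PR does for the train step) is what avoids paying the lowering and compilation cost, plus any pre-compilation checks, on every step.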