
bfloat16 support, and an attempt at homogenizing model_dtype & precision #54

Merged: 17 commits merged into main on Jul 10, 2024

Conversation

francoishernandez (Member) commented on Jul 3, 2024

bfloat16

[Training curves: cross-entropy (xent) vs. steps, and speed vs. relative time]

It seems to work relatively plug-n-play, but we might need to adapt a few things optimizer-wise:

  • fusedadam does not seem to be supported;
  • the loss lags a bit behind compared to fp16/fp32.

We might investigate some bf16-specific implementations, e.g. https://github.com/arogozhnikov/adamw_bfloat16
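For reference, here is a minimal sketch of what the "plug-n-play" bf16 path amounts to (module and shapes are made up; this is not the eole trainer code): cast the model to bfloat16 and optimize with a plain torch.optim.AdamW, since FusedAdam did not appear to accept bf16 parameters.

```python
# Hedged sketch: full-bf16 training with a standard AdamW
# (FusedAdam is skipped since it does not seem to support bf16 parameters).
import torch

model = torch.nn.Linear(1024, 1024).to(device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
target = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

loss = torch.nn.functional.mse_loss(model(x), target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

Keeping both the parameters and the Adam moments in bf16 loses mantissa bits on small updates, which may be why the loss lags slightly behind fp16/fp32; bf16-aware optimizers such as the one linked above keep extra precision in the update step to compensate.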

precision // model_dtype homogenization

Previously, model_dtype was used for training, with a "precision" deduced and applied depending on other settings (optimizer), while precision was set in PredictConfig for inference. This PR factors precision up to the common RunningConfig level, and dtype (the actual dtype the model is cast to for training) is deduced under the same conditions as before. A rough sketch of the intended layout follows the TODO list below.

TODOs:

  • check refactoring did not break inference;
  • clarify int8 specific case handling (-> done via dtype computed_field);
  • investigate bf16 optimization;
  • add some validation if needed (e.g. fusedadam + bf16 incompatibility);
  • add some docs/FAQ page with various precision/dtype related specificities?
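As announced above, a rough sketch of the intended config layout; the class name, field names and the int8 rule are assumptions based on this discussion, not the actual eole code.

```python
# Hedged sketch of the precision/dtype factorization: one user-facing
# precision setting at the RunningConfig level, with the actual model dtype
# deduced via a pydantic computed_field.
from typing import Literal

import torch
from pydantic import BaseModel, ConfigDict, computed_field


class RunningConfig(BaseModel):
    # arbitrary_types_allowed lets the computed field return a torch.dtype
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Single precision setting shared by training and inference.
    compute_dtype: Literal["fp32", "fp16", "bf16", "int8"] = "fp16"

    @computed_field
    @property
    def dtype(self) -> torch.dtype:
        """Actual dtype the model is cast to; the int8 (quantization)
        case still casts/stores weights in fp16 here."""
        return {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "int8": torch.float16,
        }[self.compute_dtype]
```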

francoishernandez added the enhancement (New feature or request) and refactor (Some refactoring, aesthetic or cleanup code changes) labels on Jul 3, 2024.
francoishernandez (Member, Author) commented:

93158fe enables amp for the bfloat16 case, which seems to work fine.
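For context, this is roughly what the amp path amounts to (a sketch under assumptions, not the actual change in 93158fe): weights stay in fp32 and only the forward/loss computation runs under a bfloat16 autocast context; unlike fp16 amp, no GradScaler is needed because bf16 keeps the fp32 exponent range.

```python
# Hedged sketch: bf16 mixed precision via torch.autocast with fp32 weights.
import torch

model = torch.nn.Linear(1024, 1024).cuda()  # weights stay fp32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 1024, device="cuda")
target = torch.randn(8, 1024, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = torch.nn.functional.mse_loss(model(x), target)

loss.backward()   # gradients accumulate in fp32
optimizer.step()
optimizer.zero_grad()
```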


francoishernandez changed the title from "[WIP] bfloat16 support, and an attempt at homogenizing model_dtype & precision" to "bfloat16 support, and an attempt at homogenizing model_dtype & precision" on Jul 4, 2024.
francoishernandez (Member, Author) commented on Jul 4, 2024

TODO

  • rename precision to compute_dtype
  • rename dtype to storage_dtype (or model_dtype?)

francoishernandez marked this pull request as ready for review on July 4, 2024.
vince62s (Contributor) commented on Jul 4, 2024

For xlm-roberta-xl (and xxl), which are natively fp32, I added this to convert them to fp16:

eole_safetensor[key] = eole_safetensor[key].to(torch.float16)

Since we can convert any kind of model (and more and more are in bf16), I think we could keep the original dtype by default and add a flag to force storage in another dtype.
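A small sketch of that idea (the helper and its force_dtype argument are hypothetical, not an existing eole option): keep each tensor in its source dtype by default and only cast when the user explicitly asks for it.

```python
# Hypothetical helper illustrating the suggestion above: preserve the
# original checkpoint dtype unless a storage dtype is explicitly forced.
import torch


def store_tensor(eole_safetensor, key, tensor, force_dtype=None):
    """force_dtype: optional torch.dtype (e.g. torch.float16) overriding the
    source dtype; None keeps the original (fp32, bf16, ...) unchanged."""
    eole_safetensor[key] = tensor.to(force_dtype) if force_dtype else tensor
```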

francoishernandez merged commit 81318aa into main on Jul 10, 2024; 4 checks passed.
francoishernandez deleted the bf16_support branch on February 7, 2025.