feat: add gemma2b variants #1835
base: main
Conversation
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1835
Note: links to docs will display an error until the docs builds have been completed. As of commit 2c216de with merge base 57ab583: 1 new failure. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thank you for adding this! Just took a quick and very non-exhaustive first pass to leave a few comments, will get back to it with a full review later today.
logger = logging.getLogger(__name__)

class Gemma2Attention(nn.Module):
Since we support flex attention, which supports soft capping, would it make sense to just force gemma2 users to use flex attention instead of implementing this module?
Flex Attention is only supported on A100 or better, right? I don't think we can make the assumption that our users will have that.
Hello everyone,
I just pushed a new commit which includes all changes discussed with @ebsmothers.
I also implemented a flex attention version but I could not make it work properly.
The default implementation (not using FlexAttention) seems to be working (I only launched the single lora pipeline, please see the attached logs log_gemma2-2b-single-lora_1729498141.txt).
I would appreciate some help on the FlexAttention implementation. Here is what I am struggling with.
If I run the following code on my A6000 GPU with torch 2.5:
import torch
from torch.nn.attention.flex_attention import (
    create_block_mask,
    flex_attention,
)

WINDOW_SIZE = None
CAPPING = 50.0
SCALE = 12.0


def get_gemma2_flex_score_mask(sliding_window_size, softcapping, query_pre_attn_scalar):
    def sliding_window_causal_mask(b, h, q_idx, kv_idx):
        """Causal mask and sliding window as proposed here:
        https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
        """
        causal_mask = q_idx >= kv_idx
        if sliding_window_size is None:
            # if no sliding window, return the causal mask alone
            return causal_mask
        else:
            windowed_mask = q_idx - kv_idx <= sliding_window_size
            return causal_mask & windowed_mask

    def soft_capping_with_scaling(score, b, h, q_idx, kv_idx):
        if query_pre_attn_scalar is None:
            # usual scaling is already included in FlexAttention
            score = score / softcapping
            score = torch.tanh(score)  # tanh_approx(score)
            return score * softcapping
        else:
            score = score / softcapping * query_pre_attn_scalar**-0.5
            score = torch.tanh(score)  # tanh_approx(score)
            return score * softcapping

    return sliding_window_causal_mask, soft_capping_with_scaling


# Compile the flex_attention function
flex_attention = torch.compile(flex_attention, dynamic=False)

B = 4
H = 8
S = 117
D = 256

mask_mod, score_mod = get_gemma2_flex_score_mask(WINDOW_SIZE, CAPPING, SCALE)

query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True)
gradOut = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

block_mask = create_block_mask(
    mask_mod=mask_mod, B=1, H=1, Q_LEN=S, KV_LEN=S, device=query.device
)

out = flex_attention(query, key, value, score_mod=score_mod, block_mask=block_mask)
print(out.shape)
The code runs fine if I don't compile flex attention (i.e. if I comment out flex_attention = torch.compile(flex_attention, dynamic=False)), but it raises this error otherwise:
BackendCompilerFailed: backend='inductor' raised: OutOfResources: out of resource: shared memory, Required: 114688, Hardware limit: 101376. Reducing block sizes or
num_stages may help.
So I disabled compilation and the code seems to be running, but very slowly (48s per iteration vs 1-2s with the non-flex implementation).
Could you help me understand what is going on?
Tagging @RdoubleA and @felipemello1 for their thoughts.
Just checking: which size Gemma-2 model are you testing with?
The logs I shared are from Gemma 2 2B; the code snippet is independent of the Gemma architecture, it's just a toy example.
I am currently running the QLoRA single-device pipeline with 9B (without flex attention); I'll share the logs tomorrow (I'll also push the changes to the recipe, as there are typos in the output path, etc.).
cc @drisspg @yanboliang regarding the comment above: is it possible to update the default kernel options in flex for better support of A6000/3090/4090s?
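For context, flex_attention accepts a kernel_options dict that can override the tile sizes of the generated Triton kernel. A minimal sketch of lowering them so the kernel fits in less shared memory might look like the following; the option names ("BLOCK_M", "BLOCK_N") and the values here are assumptions for illustration, not the exact settings used in this thread.

import torch
from torch.nn.attention.flex_attention import flex_attention

compiled_flex = torch.compile(flex_attention, dynamic=False)

q = torch.randn(4, 8, 128, 256, device="cuda", dtype=torch.float16)
k = torch.randn(4, 8, 128, 256, device="cuda", dtype=torch.float16)
v = torch.randn(4, 8, 128, 256, device="cuda", dtype=torch.float16)

# Smaller tiles mean less shared memory per block, at some cost in throughput.
out = compiled_flex(q, k, v, kernel_options={"BLOCK_M": 64, "BLOCK_N": 64})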
With the given kernel options, flex attention can be compiled and the code runs (9b lora single device training). However, the code is terribly slow (29 tokens per second) and the loss turns to nan after one batch:
Step 1 | loss:87.6104507446289 lr:2.0000000000000003e-06 tokens_per_second_per_gpu:21.504354449638843
Step 2 | loss:nan lr:4.000000000000001e-06 tokens_per_second_per_gpu:29.156293709460176
I don't understand what I am doing wrong. The only obvious optimisation I see is to create a single block mask and share it across layers, whereas I am currently recreating the same block mask for every layer (line 593 in gemma2/_attention.py). Nevertheless, I do not think that this is the current bottleneck.
Wouldn't it be better to go with the simpler implementation for now and switch to FlexAttention once it works on more GPUs? Or at least leave the choice of implementation to the end user and default to the classical one?
Yeah, I see what you're saying. I repro'd this on my end too, so it is not a function of any custom kernel configs you're using. Let me look into this a bit more, but in the meantime it seems like we shouldn't enable the flex version until we figure this out.
I have updated the code to keep the flex attention implementation but disable it for now, until we find a solution.
Hey sorry just catching up:
So 2.5 should not require a multiple of 128 for the sequence length. It is unfortunately pretty common for consumer GPUs to hit the shared-memory issue. I have a PR, pytorch/pytorch#137959, to drop the default block sizes, but I still need to debug the failing test.
As for the slowness: the tanh instruction is expected to be very slow compared to the inline assembly variant: https://github.com/pytorch-labs/attention-gym/blob/36f8bd5ded5b3469f7892099590bb2405cc8f744/attn_gym/mods/softcapping.py#L92.
It is actually quite hard to know generically what block sizes should be used, since the amount of shared memory depends on the captured buffers. I am working on a better solution, but that is going to take some time unfortunately.
(force-pushed fd79f85 → e999572, then e999572 → 6f89920)
I have pushed changes to the recipes for 9b and 27b (typos in the folder names).
# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full
Did some quick math: I guess this will take at least 216GB of total memory (54GB params + 54GB gradients + 108GB optimizer states for AdamW), which means that to run on 4 devices we'd need people to be using A100s. I wonder whether we can use an 8-bit optimizer + optimizer-in-backward to get us down to a more reasonable peak VRAM here.
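As a rough sketch of that arithmetic (assuming bf16 parameters and gradients and two AdamW states stored in the same dtype; activations and other overhead are ignored):

# Back-of-the-envelope memory estimate for full finetuning of a 27B model.
params = 27e9
bytes_per_el = 2  # bf16

params_gb = params * bytes_per_el / 1e9     # ~54 GB
grads_gb = params * bytes_per_el / 1e9      # ~54 GB
adamw_gb = 2 * params * bytes_per_el / 1e9  # exp_avg + exp_avg_sq, ~108 GB

print(params_gb + grads_gb + adamw_gb)      # ~216 GB, before activations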
Does 8-bit work with distributed?
Oh yeah duh.. there may be some issues with bitsandbytes optimizers on that front. I just tried out ao low-precision optimizers and it seems to work (though haven't resumed from intermediate checkpoint). Also there may be a compile dep there. Anyways if it's too much hassle we can consider it separately, don't wanna increase the scope of this already substantial PR more than necessary
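For reference, a minimal sketch of what trying a torchao low-precision optimizer might look like; the import path torchao.prototype.low_bit_optim is an assumption based on torchao's prototype API around this time and may have moved since.

import torch
from torchao.prototype.low_bit_optim import AdamW8bit  # assumed import path

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
opt = AdamW8bit(model.parameters(), lr=2e-5)  # optimizer states kept in 8 bits

loss = model(torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)).sum()
loss.backward()
opt.step()
opt.zero_grad()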
What should I do here? Change something, or expect users to change parameters according to their hardware?
Sorry missed this comment before now. I think it's fine to leave this as you have it and revisit these details in a later PR
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/gemma-2b/
Suggested change: checkpoint_dir: /tmp/gemma-2b/ → checkpoint_dir: /tmp/gemma-2-2b/
Done!
if query_pre_attn_scalar is not None:
    self.scaling = query_pre_attn_scalar**-0.5
else:
    self.scaling = self.head_dim**-0.5
I think you need to add self.cache_enabled = False here (then set it to True at the end of setup_cache), otherwise this will error out. But this is kind of a gotcha; it's not obvious that you need this. cc @SalmanMohammadi, we should think about how to make this more obvious to someone adding their own attention layer.
Hmm, I added a comment to indicate why it's in the init (maybe @Optimox forked before then?)
# this flag indicates whether to update the kv-cache during forward
# passes. when disabled, we can have the cache setup but still
# perform normal forward passes
self.cache_enabled = False
Could we be clearer here? I agree we could do with a comment in setup_caches explaining that you actually need to do this if you'd like to use the caches you've just set up.
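For illustration, a minimal sketch of the pattern being discussed (loosely modeled on torchtune's MultiHeadAttention; the toy cache and simplified method signatures here are assumptions, not the real API):

import torch
from torch import nn


class ToyAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.kv_cache = None
        # Set up front so forward() can check the flag even before caches exist.
        self.cache_enabled = False

    def setup_cache(self) -> None:
        # In the real module this allocates a KVCache with batch size, max_seq_len, etc.
        self.kv_cache = {"k": None, "v": None}
        # Only once the cache exists do we let forward passes update it.
        self.cache_enabled = True

    def forward(self, k: torch.Tensor, v: torch.Tensor) -> None:
        if self.kv_cache is not None and self.cache_enabled:
            self.kv_cache["k"], self.kv_cache["v"] = k, v


attn = ToyAttention()
attn(torch.ones(1), torch.ones(1))  # no-op: cache not set up / not enabled
attn.setup_cache()
attn(torch.ones(1), torch.ones(1))  # now the toy cache is updated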
Yes, I think I forked before this change; will make the change tomorrow, thank you!
Done!
k = self.k_norm(k)

# Update key-value cache
if self.kv_cache is not None:
Suggested change: if self.kv_cache is not None: → if self.kv_cache is not None and self.cache_enabled:
This should complement the cache-enabled changes earlier and match the other attention module.
Done!
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 8
Are we confident this'll fit on a single device?
Changed batch size to 2 and gradient accumulation to 8. What is the expected GPU? Is there a CI running everything? Otherwise I guess each user should be responsible for tuning the batch size to something suitable for their GPU, no?
Generally we try to ship configs which we know will work on some common hardware configuration (see examples here: https://github.com/pytorch/torchtune?tab=readme-ov-file#memory-and-training-speed), so users can maintain the expectation that they can get started without any painful OOMs. Then they are free to play with the configs. We should make sure this config works with e.g. 1xA100. Let me know if you need a hand here.
@SalmanMohammadi I do not have easy access to an A100; I would appreciate it if someone could run the code for the 27B-parameter model and let me know what batch size I should set.
I'll have a quick look when we're ready to land. We can also reasonably mirror the batch size from the config of another similarly sized model already in the codebase.
# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/2B_full
Maybe it's just me, but when I try to run these distributed recipes I am hitting AssertionError: FSDP requires named DeviceMesh dims for ND parallelism. It looks to me like we are actually entering _init_sharded_param with a DTensor (see here), which does not happen with our other recipes. Need to figure out why this would be happening.
Ah I think I cracked the case. See here
Big mistake, thank you for catching that!
The logs after this fix look much better than previously for the 9b single lora pipeline!
log_gemma2-2b-single-lora_1729937021.txt
""" | ||
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) | ||
|
||
mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) |
This needs to be inside the for loop
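To illustrate the difference, a toy sketch, not the actual torchtune builder; make_mlp below is a hypothetical stand-in for gemma_mlp.

import torch.nn as nn

embed_dim, intermediate_dim, num_layers = 16, 64, 4


def make_mlp(dim: int, hidden_dim: int) -> nn.Module:
    # hypothetical stand-in for gemma_mlp; the exact MLP shape is irrelevant here
    return nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, dim))


# Constructed outside the loop: every layer ends up sharing one MLP instance (and its weights).
shared_mlp = make_mlp(embed_dim, intermediate_dim)
shared_layers = nn.ModuleList(shared_mlp for _ in range(num_layers))
assert shared_layers[0] is shared_layers[1]

# Constructed inside the loop: each layer gets its own independently initialized MLP.
layers = nn.ModuleList()
for _ in range(num_layers):
    layers.append(make_mlp(embed_dim, intermediate_dim))
assert layers[0] is not layers[1]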
Done!
path: /tmp/gemma-2-27b/tokenizer.model

# Dataset
dataset:
Sorry to potentially be a pain in the ass here. We have a parallel PR (#1872) which is helping standardize our configs and better expose the features we have. This means we always have packed: False in dataset, and log_peak_memory_stats: True and compile: False below, for every one of our configs.
Would it be annoying to ask if we could update these in the same way while we're here, please?
Done, I have updated all the configs to match the other PR!
@@ -27,6 +27,4 @@
    "lora_gemma_7b",
    "qlora_gemma_2b",
    "qlora_gemma_7b",
    "gemma_hf_to_tune",
Good catch : )
(force-pushed dd4cf33 → 54a237c)
    flex_causal_sliding_window,
    flex_tanh_soft_capping_with_scaling,
)

logger = logging.getLogger(__name__)
nit: Why is this style of logger proliferating? We should be calling get_logger from our utils.
It's just a copy-paste on my side; let me know if you want me to change that in this PR.
You can just change it to torchtune.utils.get_logger, but no strong preference here. Either way, we should clean up other usages in a follow-up.
Hi @Optimox, sorry for the delay here. Given that the flex attention version is still not working properly, how do you feel about pulling it out of this PR? Then we can revisit in a follow-up. For context, we are going to be cutting a release soon (targeting code freeze tomorrow), so we don't want to block getting this in on something that we can address in a follow-up. Let me know if this makes sense to you.
@ebsmothers yes, no problem! What is the best way of handling this? Adding a new commit deleting the flex attention part of this branch, or creating a new PR without the flex attention part?
@Optimox honestly, whatever is easiest for you. I imagine just a commit deleting the flex code would be simplest, but feel free to do whatever makes sense to you!
@ebsmothers I have removed the flex attention implementation from the code; let me know if there are still other changes to make!
Gemma 2 and the original Gemma use different normalization under the same parameter name,
so their state dicts must be differentiated in order to map the weights correctly.
They are essentially the same except for the "model.layers.{}.post_attention_layernorm.weight" key.
See discussion here: https://github.com/pytorch/torchtune/pull/1835#discussion_r1803410251
Thanks for documenting this
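To make the point concrete, a purely hypothetical sketch of why one shared mapping would be ambiguous; the torchtune-side key names below are invented for illustration and are not the real ones.

# The same HF key must land on different torchtune modules in Gemma vs Gemma 2,
# so each model family needs its own state-dict mapping.
_GEMMA_HF_TO_TUNE = {
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",  # invented name
}
_GEMMA2_HF_TO_TUNE = {
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.post_attn_norm.scale",  # invented name
}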
:nosignatures:

gemma2.gemma2
gemma2.lora_gemma
Suggested change: gemma2.lora_gemma → gemma2.lora_gemma2
# Sliding-window attention: restrict attention to a band of width sliding_window_size
# around each position; entries outside the band are filled with a very large negative
# value (effectively -inf) so they vanish after the softmax.
sliding_mask = torch.triu(
    all_ones, -1 * self.sliding_window_size + 1
) * torch.tril(all_ones, self.sliding_window_size - 1)
mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)

# Logit soft-capping: squash attention scores smoothly into (-softcapping, +softcapping).
if self.softcapping is not None:
    output = output / self.softcapping
    output = torch.tanh(output)
    output = output * self.softcapping
Can we add code comments explaining the sliding window and the softcapping? (Also one for the magic value in the torch.where line wouldn't hurt.)
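As a quick illustration of what the tanh soft-capping does (toy values, not taken from the model): it squashes raw scores smoothly into the range (-cap, +cap).

import torch

cap = 50.0
scores = torch.tensor([-500.0, -50.0, 0.0, 50.0, 500.0])
capped = cap * torch.tanh(scores / cap)
print(capped)  # approximately [-50.0, -38.1, 0.0, 38.1, 50.0]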
q.mul_(self.scaling)
output = torch.matmul(q, k.transpose(2, 3))
Let's add shape comments here too
x: torch.Tensor,
y: Optional[torch.Tensor] = None,
*,
mask: Optional[_MaskType] = None,
Have you run any of the configs with packed=True (i.e. when mask is a BlockMask)?
A few more comments and questions but overall this is looking great!
Context
What is the purpose of this PR?
This is related to adding gemma2 support #1813
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
pre-commit install
pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.