
feat: add gemma2b variants #1835

Open
wants to merge 7 commits into base: main

Conversation

Contributor

@Optimox commented Oct 15, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

This is related to adding gemma2 support #1813

Changelog

What are the changes made in this PR?
*

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Oct 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1835

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 2c216de with merge base 57ab583:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Oct 15, 2024
@Optimox mentioned this pull request on Oct 15, 2024
Contributor

@ebsmothers left a comment

Thank you for adding this! Just took a quick and very non-exhaustive first pass to leave a few comments, will get back to it with a full review later today.

torchtune/modules/attention.py (outdated)
torchtune/modules/attention.py (outdated)
recipes/configs/gemma2/27B_full.yaml (outdated)
torchtune/models/gemma2/_component_builders.py (outdated)
@joecummings mentioned this pull request on Oct 15, 2024
logger = logging.getLogger(__name__)


class Gemma2Attention(nn.Module):
Contributor

Since we support flex attention, which supports soft capping, would it make sense to just force gemma2 users to use flex attention instead of implementing this module?

Contributor

Flex Attention is only supported on A100 or better, right? I don't think we can make the assumption that our users will have that.

Contributor Author

@Optimox commented Oct 22, 2024

Hello everyone,

I just pushed a new commit which includes all changes discussed with @ebsmothers.
I also implemented a flex attention version but I could not make it work properly.

The default implementation (not using FlexAttention) seems to be working (I only launched the single lora pipeline, please see the attached logs log_gemma2-2b-single-lora_1729498141.txt).

I would appreciate some help on the FlexAttention implementation. Here is where I am struggling.

If I run the following code on my A6000 GPU with torch 2.5:

import torch

from torch.nn.attention.flex_attention import (
    create_block_mask,
    flex_attention)


WINDOW_SIZE = None
CAPPING = 50.0
SCALE = 12.0


def get_gemma2_flex_score_mask(sliding_window_size, softcapping, query_pre_attn_scalar):
    
    def sliding_window_causal_mask(b, h, q_idx, kv_idx):
        """Causal mask and sliding window as proposed here:
        https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
        """
        causal_mask = q_idx >= kv_idx
        if sliding_window_size is None:
            # if no sliding window return causal mask
            return causal_mask
        else:
            windowed_mask = q_idx - kv_idx <= sliding_window_size

            return causal_mask & windowed_mask
    
    def soft_capping_with_scaling(score, b, h, q_idx, kv_idx):
        if query_pre_attn_scalar is None:
            # usual scaling included in FlexAttention
            score = score / softcapping
            score = torch.tanh(score) #tanh_approx(score)
            return score * softcapping
        else:
            score = score / softcapping * query_pre_attn_scalar**-0.5
            score = torch.tanh(score) #tanh_approx(score)
            return score * softcapping
    
    return sliding_window_causal_mask, soft_capping_with_scaling

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention, dynamic=False)

B = 4
H = 8
S = 117
D = 256

mask_mod, score_mod = get_gemma2_flex_score_mask(WINDOW_SIZE, CAPPING, SCALE)

query = torch.randn(
        B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True
    )
key = torch.randn(
    B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True
)
value = torch.randn(
    B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True
)
gradOut = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)


block_mask = create_block_mask(
    mask_mod=mask_mod,
    B=1,
    H=1,
    Q_LEN=S,
    KV_LEN=S,
    device=query.device,
)

out = flex_attention(
    query, key, value, score_mod=score_mod, block_mask=block_mask
)
print(out.shape)

The code runs fine if I don't compile flex attention (by commenting out flex_attention = torch.compile(flex_attention, dynamic=False)), but it raises this error otherwise:
BackendCompilerFailed: backend='inductor' raised: OutOfResources: out of resource: shared memory, Required: 114688, Hardware limit: 101376. Reducing block sizes or num_stages may help.

So I disabled compilation and the code runs, but very slowly (48 s per iteration vs. 1-2 s with the non-flex implementation).

Maybe you could help me understand what is going on?
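
For reference, here is a minimal sketch of one possible workaround: passing smaller Triton block sizes through flex_attention's kernel_options argument. The key names and values below are assumptions for illustration, not settings taken from this thread.

# Hedged sketch: ask the compiled kernel to use smaller tiles so it fits in the
# ~100 KB of shared memory reported in the error above. Keys and values are assumed.
out = flex_attention(
    query,
    key,
    value,
    score_mod=score_mod,
    block_mask=block_mask,
    kernel_options={"BLOCK_M": 64, "BLOCK_N": 64},
)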

Contributor

Tagging @RdoubleA and @felipemello1 for their thoughts.

Just checking: which size Gemma-2 model are you testing with?

Contributor Author

The logs I shared are from Gemma 2 2B; the code snippet is independent of the Gemma architecture, it's just a toy example.
I am currently running the QLoRA single-device pipeline with 9B (without flex attention); I'll share the logs tomorrow (I'll also push changes to the recipe, as there are typos in the output path, etc.).

Contributor

cc @drisspg @yanboliang in regards to this comment above, is it possible to update the default kernel options in flex for better support of A6000/3090/4090s?

Contributor Author

With the given kernel options, flex attention can be compiled and the code runs (9b lora single device training). However, the code is terribly slow (29 tokens per second) and the loss turns to nan after one batch:

Step 1 | loss:87.6104507446289 lr:2.0000000000000003e-06 tokens_per_second_per_gpu:21.504354449638843 
Step 2 | loss:nan lr:4.000000000000001e-06 tokens_per_second_per_gpu:29.156293709460176 

I don't understand what I am doing wrong. The only obvious optimisation I see is to create the block mask once and share it across layers, whereas I am currently recreating the same block mask for every layer (line 593 in gemma2/_attention.py). Nevertheless, I do not think that is the current bottleneck.

Wouldn't it be better to go with the simpler implementation for now and switch to FlexAttention once it works on more GPUs? Or at least leave the choice of implementation to the end user and default to the classical one?

Contributor

Yeah I see what you're saying.. I repro'd this on my end too so it is not a function of any custom kernel configs you're using. Let me look into this a bit more but in the meantime it seems like we shouldn't enable the flex version until we figure this out

Contributor Author

I have updated the code to keep the flex attention implementation but disable it for now, until we have found a solution.


Hey sorry just catching up:

So 2.5 should not require a multiple of 128 for sequence length. It is unfortunately pretty common for consumer GPUs to hit the shared-memory issue. I have a PR (pytorch/pytorch#137959) to drop the default block sizes, but I still need to debug the failing test.

As for the slowness: the plain tanh instruction is expected to be very slow compared to the inline assembly variant: https://github.com/pytorch-labs/attention-gym/blob/36f8bd5ded5b3469f7892099590bb2405cc8f744/attn_gym/mods/softcapping.py#L92.

It is actually quite hard to know generically what block sizes should be used, since the amount of shared memory depends on the captured buffers. I am working on a better solution, but that is unfortunately going to take some time.
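
A rough sketch of the wrapper pattern involved (names here are assumptions, not code from this PR): the score_mod calls tanh through a custom op, and the actual speedup comes from registering an inductor lowering to the approximate hardware tanh, which the linked softcapping.py provides and which is not reproduced here.

import torch
from torch import Tensor


# Hypothetical wrapper: expose tanh as a custom op so a faster lowering
# (e.g. the inline-assembly tanh.approx variant from attention-gym) can
# replace the default, slow tanh inside a compiled score_mod.
@torch.library.custom_op("approx::tanh", mutates_args=())
def tanh_approx(x: Tensor) -> Tensor:
    return torch.tanh(x)


@tanh_approx.register_fake
def _(x: Tensor) -> Tensor:
    return torch.empty_like(x)


def soft_capping(score, b, h, q_idx, kv_idx):
    # 50.0 is an illustrative cap: softcapping * tanh(score / softcapping)
    return 50.0 * tanh_approx(score / 50.0)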

Contributor Author

@Optimox commented Oct 24, 2024

I have pushed changes to the 9B and 27B recipes (typos in the folder names).
I also ran the single-device LoRA recipe for Gemma 2 9B and everything ran OK (with flex attention disabled). Nevertheless, the loss seems better with the 2B model; maybe it's just because the larger model overfits more quickly.
logs_gemma2-9b-lora-single_1729686341.txt

# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full
Contributor

Did some quick math: I guess this will take at least 216 GB total memory (54 GB params + 54 GB gradients + 108 GB optimizer states for AdamW), which means to run on 4 devices we'd need people to be using A100s. I wonder whether we can use an 8-bit optimizer + optimizer-in-backward to get down to a more reasonable peak VRAM here.
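
A back-of-the-envelope check of those numbers, assuming bf16 (2 bytes per element) for parameters, gradients, and both AdamW moment tensors, and ignoring activations and sharding overhead:

# Rough estimate only; dtype assumptions are mine, not stated in the thread.
params = 27e9  # Gemma 2 27B

params_gb = params * 2 / 1e9     # ~54 GB
grads_gb = params * 2 / 1e9      # ~54 GB
adamw_gb = 2 * params * 2 / 1e9  # ~108 GB (exp_avg + exp_avg_sq)

print(params_gb + grads_gb + adamw_gb)  # ~216 GB total across all ranks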

Contributor

does 8bit work with distributed?

Contributor

Oh yeah duh.. there may be some issues with bitsandbytes optimizers on that front. I just tried out ao low-precision optimizers and it seems to work (though haven't resumed from intermediate checkpoint). Also there may be a compile dep there. Anyways if it's too much hassle we can consider it separately, don't wanna increase the scope of this already substantial PR more than necessary

Contributor Author

What should I do here? Should I change something, or should users be expected to adjust parameters for their hardware?

Contributor

Sorry missed this comment before now. I think it's fine to leave this as you have it and revisit these details in a later PR


checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma-2b/
Contributor

Suggested change
checkpoint_dir: /tmp/gemma-2b/
checkpoint_dir: /tmp/gemma-2-2b/

Contributor Author

Done!

if query_pre_attn_scalar is not None:
    self.scaling = query_pre_attn_scalar**-0.5
else:
    self.scaling = self.head_dim**-0.5
Contributor

I think you need to add self.cache_enabled=False here (then set it to True at the end of setup_cache), otherwise this will error out. But this is kind of a gotcha, it's not obvious that you need this. cc @SalmanMohammadi we should think about how to make this more obvious to someone adding their own attention layer

Collaborator

@SalmanMohammadi commented Oct 25, 2024

Hmm, I added a comment to indicate why it's in the init (maybe @Optimox forked before then?)

        # this flag indicates whether to update the kv-cache during forward
        # passes. when disabled, we can have the cache setup but still
        # perform normal forward passes
        self.cache_enabled = False

Could we be clearer here? I agree we could use a comment in setup_caches explaining that you actually need to do this if you'd like to use the caches you've just set up.
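
A minimal sketch of the handshake being discussed (the class and method bodies below are placeholders, not torchtune's actual attention module):

import torch
import torch.nn as nn


class SketchAttention(nn.Module):
    """Illustrative placeholder only; shows the cache_enabled flag pattern."""

    def __init__(self) -> None:
        super().__init__()
        self.kv_cache = None
        # This flag indicates whether to update the kv-cache during forward
        # passes; with it disabled we can have the cache set up but still run
        # normal forward passes.
        self.cache_enabled = False

    def setup_cache(self, batch_size: int, dtype: torch.dtype, max_seq_len: int) -> None:
        # ... allocate self.kv_cache here (e.g. a torchtune KVCache) ...
        self.cache_enabled = True  # flip only once the cache actually exists

    def forward(self, x):
        if self.kv_cache is not None and self.cache_enabled:
            # ... write the new k/v projections into the cache ...
            pass
        return x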

Contributor Author

Yes, I think I forked before this change; I will make the change tomorrow, thank you!

Contributor Author

Done!

k = self.k_norm(k)

# Update key-value cache
if self.kv_cache is not None:
Collaborator

Suggested change
if self.kv_cache is not None:
if self.kv_cache is not None and self.cache_enabled:

This should complement the cache_enabled handling earlier to match the other attention module.

Contributor Author

Done!

_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 8
Contributor

Are we confident this'll fit on a single device?

Contributor Author

Changed batch size to 2 and gradient accumulation to 8. What is the expected GPU? Is there a CI that runs everything? Otherwise I guess each user should be responsible for tuning the batch size to fit their own GPU, no?

Collaborator

Generally we try to ship configs which we know will work on some common hardware configuration (see examples here: https://github.com/pytorch/torchtune?tab=readme-ov-file#memory-and-training-speed), so users can maintain the expectation that they can get started without any painful OOMs. Then they are free to play with the configs. We should make sure this config works with e.g. 1x A100 - let me know if you need a hand here.

Contributor Author

@SalmanMohammadi I do not have easy access to an A100; I would appreciate it if someone could run the code for the 27B model and let me know what batch size I should set.

Collaborator

I'll have a quick look when we're ready to land. We can also reasonably mirror the batch size from the config of another similarly sized model already in the codebase.

# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/2B_full
Contributor

Maybe it's just me but when I try to run these distributed recipes I am hitting AssertionError: FSDP requires named DeviceMesh dims for ND parallelism. It looks to me like we are actually entering _init_sharded_param with a DTensor (see here), which does not happen with our other recipes. Need to figure out why this would be happening

Contributor

Ah I think I cracked the case. See here

Contributor Author

Big mistake, thank you for catching that!

Contributor Author

The logs after this fix look much better than previously for the 9b single lora pipeline!
log_gemma2-2b-single-lora_1729937021.txt

"""
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)

mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
Contributor

This needs to be inside the for loop
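
A quick illustrative sketch of the issue (hypothetical builder; nn.Sequential stands in for gemma_mlp): constructing the module outside the loop would make every layer share one set of weights, so the call belongs inside the per-layer loop.

import torch.nn as nn


def build_layers(num_layers: int, embed_dim: int, intermediate_dim: int) -> nn.ModuleList:
    layers = nn.ModuleList()
    for _ in range(num_layers):
        # Build a fresh MLP per layer; hoisting this call above the loop would
        # reuse the same parameters in every layer.
        mlp = nn.Sequential(
            nn.Linear(embed_dim, intermediate_dim),
            nn.GELU(),
            nn.Linear(intermediate_dim, embed_dim),
        )
        layers.append(mlp)
    return layers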

Contributor Author

Done!

path: /tmp/gemma-2-27b/tokenizer.model

# Dataset
dataset:
Collaborator

Sorry to potentially be a pain in the ass here. We have a parallel PR (#1872) which is helping standardize our configs and better expose the features we have. This means we always have packed: False under dataset, and log_peak_memory_stats: True and compile: False below, in every one of our configs.

Would it be annoying to ask if we could update these in the same way while we're here, please?

Contributor Author

Done, I have updated all the configs to match the other PR!

@@ -27,6 +27,4 @@
"lora_gemma_7b",
"qlora_gemma_2b",
"qlora_gemma_7b",
"gemma_hf_to_tune",
Collaborator

Good catch : )

flex_causal_sliding_window,
flex_tanh_soft_capping_with_scaling,
)
logger = logging.getLogger(__name__)
Contributor

nit: Why is this style of logger proliferating? We should be calling get_logger from our utils.

Contributor Author

It's just a copy-paste on my side; let me know if you want me to change that in this PR.

Contributor

You can just change to torchtune.utils.get_logger, but no strong preference here. Either way we should clean up other usages in a follow-up
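
For example (assuming the helper accepts an optional level string, as it is used in the recipes):

from torchtune.utils import get_logger

logger = get_logger("DEBUG")  # instead of logging.getLogger(__name__)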

Contributor

@ebsmothers commented
Hi @Optimox sorry for the delay here. Given that the flex attention version is still not working properly, how do you feel about pulling it out of this PR? Then we can revisit in a follow-up. For context we are going to be cutting a release soon (targeting code freeze tomorrow) so don't want to block getting this in on something that we can address in a follow-up. Let me know if this makes sense to you.

Contributor Author

@Optimox commented Oct 29, 2024

@ebsmothers yes, no problem! What is the best way of handling this? Adding a new commit deleting the flex attention part of this branch, or creating a new PR without the flex attention part?

Contributor

@ebsmothers commented
@Optimox honestly whatever is easiest for you. I imagine just a commit deleting the flex code would be simplest, but feel free to do whatever makes sense to you!

Contributor Author

@Optimox commented Oct 30, 2024

@ebsmothers I have removed the flex attention implementation from the code; let me know if there are any other changes to make!

@Optimox changed the title from "(WIP)feat: add gemma2b variants" to "feat: add gemma2b variants" on Oct 30, 2024
Comment on lines +14 to +18
Gemma 2 and the original Gemma implementations use different normalization layers under the same name, so it is mandatory to differentiate their state dicts in order to map the weights correctly.
They are essentially the same except for the "model.layers.{}.post_attention_layernorm.weight" key.
See discussion here: https://github.com/pytorch/torchtune/pull/1835#discussion_r1803410251
Contributor

Thanks for documenting this

:nosignatures:

gemma2.gemma2
gemma2.lora_gemma
Contributor

Suggested change
gemma2.lora_gemma
gemma2.lora_gemma2

Comment on lines +302 to +310
sliding_mask = torch.triu(
    all_ones, -1 * self.sliding_window_size + 1
) * torch.tril(all_ones, self.sliding_window_size - 1)
mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)

if self.softcapping is not None:
    output = output / self.softcapping
    output = torch.tanh(output)
    output = output * self.softcapping
Contributor

Can we add code comments explaining sliding window and the softcapping? (Also one for the magic value in the torch.where line wouldn't hurt)
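
Something along these lines, perhaps (the comments are a suggested reading of the code; the magic constant is interpreted as the large negative fill value used by the Gemma reference implementation):

# Sliding-window attention: on top of the causal mask, each query may only
# attend to keys within `sliding_window_size` positions of itself.
sliding_mask = torch.triu(
    all_ones, -1 * self.sliding_window_size + 1
) * torch.tril(all_ones, self.sliding_window_size - 1)
# Fill positions outside the window with a very large negative constant
# (effectively -inf, matching the Gemma reference implementation) so they
# contribute nothing after the softmax.
mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)

if self.softcapping is not None:
    # Logit soft-capping: softcapping * tanh(score / softcapping) squashes the
    # attention scores into (-softcapping, softcapping) before the softmax.
    output = output / self.softcapping
    output = torch.tanh(output)
    output = output * self.softcapping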

Comment on lines +287 to +288
q.mul_(self.scaling)
output = torch.matmul(q, k.transpose(2, 3))
Contributor

Let's add shape comments here too
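
For instance (the dimension names are assumptions following the usual batch / heads / sequence / head-dim layout):

# q: [b, n_h, s_q, h_d], k: [b, n_h, s_kv, h_d]
q.mul_(self.scaling)
# output: [b, n_h, s_q, s_kv], raw attention scores before masking and softmax
output = torch.matmul(q, k.transpose(2, 3))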

x: torch.Tensor,
y: Optional[torch.Tensor] = None,
*,
mask: Optional[_MaskType] = None,
Contributor

Have you run any of the configs with packed=True (i.e. when mask is a BlockMask)?

Contributor

@ebsmothers left a comment

A few more comments and questions but overall this is looking great!
