
Add NucleusX Model #27259

Open · syncdoth wants to merge 29 commits into main from add_nucleus_x
Conversation

@syncdoth (Author) commented Nov 3, 2023:

What does this PR do?

This PR adds a new model named NucleusX, contributed by Sehyun Choi and NucleusAI. The model is based on the Retentive Network architecture, and the code is largely adapted from this repo, which in turn borrows core implementations from torchscale. We plan to release our paper and weights soon.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

We kindly request the review of this new model from @ArthurZucker and @younesbelkada!

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

This is the configuration class to store the configuration of a [`~NucleusXModel`]. It is used to instantiate a NucleusX model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of the NucleusX-7B [NucleusAI/NucleusX-7B](https://huggingface.co/NucleusAI/NucleusX-7B) architecture.
@syncdoth (Author) replied:

These weights are not released yet; we are planning to release them at this link soon! We have included the link to pass the configuration test, which requires a link to the checkpoint.

@syncdoth (Author) commented Nov 4, 2023:

The current failure in tests_pr_documentation_tests is due to the incorrect repo_id and links (namely NucleusAI/NucleusX-7B and https://huggingface.co/NucleusAI/NucleusX-7B) used in the examples and model docs. These checkpoints are not released yet; we plan to release them soon.

@syncdoth marked this pull request as ready for review on November 4, 2023 16:19
@syncdoth changed the title from [WIP] Add NucleusX Model to Add NucleusX Model on Nov 4, 2023
@syncdoth (Author) commented Nov 4, 2023:

cc: @sippycoder and also @LysandreJik!

@ArthurZucker (Collaborator) commented:

Hey! Thanks for opening the PR, I'll let @Rocketknight1 do a first review as he is more familiar with this kind of model!

@Rocketknight1 (Member) commented:

Hi all! RetNets seem like a really interesting architecture, so I'm quite excited to take a look - I'll try to review this in the next day or two.

@Rocketknight1 (Member) left a review:

Hi all, I just looked through this! Overall, the core modelling code looks very solid, and I couldn't find much to complain about. We normally encourage the use of # Copied from in these PRs, but given that RetNets differ significantly from Transformers, most functions here will be unique to NucleusX.

I also think the test coverage is good, in particular the tests confirming that outputs are equivalent in parallel/recurrent/chunkwise mode.

Before we can merge, though, we need checkpoints to be uploaded. Also, we ideally need an integration test. The purpose of these tests is to ensure that model output for a specific checkpoint remains numerically constant, which is very important to ensure that future updates don't create silent errors. Here is an example of an integration test that you can copy for NucleusX.
If your checkpoints are too large for our CI, we can make a tiny-random-nucleusx model to use for the integration test. An integration test confirming generation output remains constant when do_sample=False would also be helpful!

Overall though, this looks like a really solid PR, and I suspect we shouldn't have much trouble including this in transformers. Thank you for your contribution!
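For concreteness, a minimal integration test along the lines described above might look like the sketch below. The checkpoint name is the placeholder repo from this PR, and the expected logits and text values are placeholders to be recorded once from a trusted run; require_torch, slow, and torch.testing.assert_close are standard transformers/PyTorch testing utilities.

import unittest

import torch

from transformers import AutoTokenizer, NucleusXForCausalLM
from transformers.testing_utils import require_torch, slow


@require_torch
class NucleusXIntegrationTest(unittest.TestCase):
    @slow
    def test_logits(self):
        # Placeholder repo; swap in the released weights or a tiny-random-nucleusx model.
        model = NucleusXForCausalLM.from_pretrained("NucleusAI/NucleusX-7B")
        input_ids = torch.tensor([[1, 306, 4658, 278, 6593, 310, 2834, 338]])
        with torch.no_grad():
            logits = model(input_ids).logits
        # Values recorded once from a trusted run guard against silent regressions.
        expected_slice = torch.tensor([-6.65, -4.12, -4.98])  # placeholder values
        torch.testing.assert_close(logits[0, 0, :3], expected_slice, rtol=1e-4, atol=1e-4)

    @slow
    def test_greedy_generation(self):
        tokenizer = AutoTokenizer.from_pretrained("NucleusAI/NucleusX-7B")
        model = NucleusXForCausalLM.from_pretrained("NucleusAI/NucleusX-7B")
        inputs = tokenizer("Hello my name is", return_tensors="pt")
        # do_sample=False makes the output deterministic across runs.
        out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
        expected = "Hello my name is ..."  # placeholder; record from a trusted run
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), expected)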

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

>>> from transformers import NucleusXModel, NucleusXConfig

>>> # Initializing a NucleusX-7B style configuration
>>> configuration = NucleusXConfig()
@Rocketknight1 (Member) suggested:

Suggested change:
- >>> configuration = NucleusXConfig()
+ >>> configuration = NucleusXConfig(decoder_layers=2)

Also, one more comment! The doctest runner is crashing on this file, and I suspect the reason is that it's running out of memory because you're initializing a 7B model in float32 and so using 28GB of memory, which is a lot for the doctest runner! Maybe change this line to initialize a much smaller model?
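For example (decoder_layers and decoder_retention_heads are argument names taken from elsewhere in this PR; any other sizing arguments would need to match the actual config):

>>> # A tiny configuration keeps the doctest runner well under its memory limit
>>> configuration = NucleusXConfig(decoder_layers=2, decoder_retention_heads=2)
>>> model = NucleusXModel(configuration)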

@syncdoth (Author) replied:

I had been trying for a long time to figure out why the doctest was failing, and this makes total sense!

@syncdoth (Author) commented Nov 9, 2023:

@Rocketknight1 Thanks for reviewing this PR! I have gone through the comments and resolved them. There are also some other updates:

  • Added dtype handling so that tensors created in NucleusXRelPos have the same dtype as the model weights (see the sketch after this comment).
  • Renamed some *layer_norm modules to *rms_norm for consistency.
  • Removed the subln option (Sub-LayerNorm), which is not applicable to our choice of FFN (SwiGLU).

There are other minor changes, which can be found in the commit logs.

As for the weight release, we are working hard to make that happen :) We'll ping here when the weights are ready for public release.

Thanks again!
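As a rough illustration of the dtype change (NucleusXRelPos internals aren't shown in this thread, so the module below is a hypothetical stand-in): tensors are created with the dtype and device taken from an existing weight instead of defaulting to float32.

import torch
from torch import nn


class RelPosSketch(nn.Module):
    """Hypothetical stand-in for NucleusXRelPos, showing dtype-aware tensor creation."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, seq_len: int) -> torch.Tensor:
        # Derive dtype/device from the weights so loading the model in bf16 or
        # fp16 does not silently mix float32 tensors into the computation.
        dtype = self.proj.weight.dtype
        device = self.proj.weight.device
        return torch.arange(seq_len, dtype=dtype, device=device)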

@@ -131,6 +131,7 @@ def get_config(self):
         decoder_layers=self.num_hidden_layers,
         is_decoder=False,
         decoder_retention_heads=self.num_attention_heads,
+        use_cache=False,
@syncdoth (Author) commented Nov 9, 2023:

When use_cache=True, NucleusXMultiScaleRetention.parallel_forward is less efficient, because use_cache=True makes it also compute past_key_values, which incurs another O(T^2) computation.

use_cache=True should be set only when we want a recurrent forward to follow the parallel forward (e.g., during generation we compute the prompt in parallel but generate in recurrent mode), as in the sketch below.
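A hypothetical usage sketch of that pattern (greedy argmax stands in for whatever decoding strategy the caller actually wants):

# Prefill the prompt in parallel with caching enabled, then decode new
# tokens one at a time in the cheaper recurrent mode.
max_new_tokens = 20
outputs = model(input_ids, use_cache=True, forward_mode="parallel")
past_key_values = outputs.past_key_values
next_token = outputs.logits[:, -1:].argmax(dim=-1)

for _ in range(max_new_tokens):
    outputs = model(
        next_token,
        past_key_values=past_key_values,
        use_cache=True,
        forward_mode="recurrent",
    )
    past_key_values = outputs.past_key_values
    next_token = outputs.logits[:, -1:].argmax(dim=-1)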

@Rocketknight1 (Member) replied:

cc @gante to this bit - is our generation code ready to handle networks with multiple forward modes?

@Rocketknight1 (Member) commented:

Also @syncdoth, while we're waiting for the checkpoints, there are some tests failing in this PR that are unrelated. If you pull the latest version from main and rebase your PR, that should fix them.


forward_mode = kwargs.get("forward_mode", "parallel")
if past_key_values is not None:
    # NOTE: when we have past_key_values, using recurrent mode will be faster.
    forward_mode = "recurrent"
@syncdoth (Author) commented Nov 10, 2023:

@Rocketknight1 @gante For your question in the other comment about whether the generation code can handle networks with multiple forward modes, here is my take: when prepare_inputs_for_generation is called, detecting past_key_values means the prompt has already been computed, so the recurrent forward is the better choice. Hence the forward_mode = "recurrent" line in the snippet above.

Note that forward_mode is just a string (used like an enum) that the model takes as a forward input to select the forward mode at each forward step!

@gante (Member) commented Nov 14, 2023:

@Rocketknight1 @syncdoth We can support multiple generation modes, but the implementation depends on a few factors! The audio models are the best examples. For instance:

  1. Whisper wraps generate and accepts additional flags. These flags trigger additional arguments for generate (e.g. a custom logits processor to generate timestamps) or postprocessing.
  2. Bark contains 3 internal models and wraps generate to call generate on each of them in sequence.

Additionally, if model.forward accepts multiple modes, you can also prepare the flags in model.prepare_inputs_for_generation, as written above :)
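A minimal sketch of that last pattern, assuming a NucleusX-style model whose forward accepts a forward_mode string (the exact signature is illustrative, not this PR's actual code):

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
    forward_mode = kwargs.get("forward_mode", "parallel")
    if past_key_values is not None:
        # The prompt was already encoded in parallel; each subsequent step
        # decodes a single token, where recurrent mode is cheaper.
        forward_mode = "recurrent"
        input_ids = input_ids[:, -1:]  # only the newest token is needed
    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "forward_mode": forward_mode,
    }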

@syncdoth (Author) commented:

> Also @syncdoth, while we're waiting for the checkpoints, there are some tests failing in this PR that are unrelated. If you pull the latest version from main and rebase your PR, that should fix them.

This may be a beginner question, but should I rebase onto main and (force-)push, or merge main and push?

@Rocketknight1 (Member) replied:

Probably the easiest way to do it is to pull the latest version of main, then rebase your branch onto main, and then force push.

@fakerybakery commented:

Hi @syncdoth, do you know what happened to Nucleus AI? The website is now down

@syncdoth (Author) replied:

> Hi @syncdoth, do you know what happened to Nucleus AI? The website is now down

This is unrelated to this PR, but there's some maintenance going on with the website. Hang tight :)

@Rocketknight1 (Member) commented:

btw @syncdoth, if you're still getting test failures, try 'Sync upstream' on the main branch of your forked repo; then, on your development machine, pull the latest main branch, switch to the add_nucleus_x branch, rebase, and finally force-push. That should resolve everything!

@syncdoth force-pushed the add_nucleus_x branch 2 times, most recently from 8663ad0 to 0d9ee02, on November 23, 2023 13:46
The github-actions bot commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@Rocketknight1 (Member) commented:

Don't stale, please! This looks quite close to being ready! (cc @syncdoth - let me know if you need any help with the last bit!)

@syncdoth (Author) commented:

We are on the verge of releasing the weights! There's been a bit of a delay in the schedule 🥲

The last bit should be updating the weight links in the docs and writing the integration tests; we are working hard on it!

syncdoth and others added 21 commits December 24, 2023 02:37

  • Removes the accidentally added comma in tests/generation/test_utils.py
  • When loading the model weights in a dtype other than fp32, `.float` statements may cause trouble. This commit makes the tensor creation and float casting aware of the `dtype` of the weights.
  • Reason (for removing subln): since we are using GLU, sub-layernorm is not well-defined.
  • This follows the example of other models, such as LongT5, idefics, llama, etc.
@syncdoth (Author) commented:

Hi @Rocketknight1, I'm seeing test failures related to doc building and to running the NucleusXForCausalLM.forward example. It seems they might be due to .from_pretrained loading a 7B checkpoint and killing the worker, like the earlier example in the configuration. Do you think I should change the example to a smaller one?

@Maykeye commented Dec 25, 2023:

Does it require some tinkering to use generate in a non-parallel mode? (I don't have the RAM to process a 16KB prompt in parallel.)

I dumped the source into the model folder and edited the config to load it as a trust_remote_code=True thingy; parallel mode works fine, as in the test:

In [7]: print(tokenizer.decode(model.generate(**tokenizer("Hello my name is", return_tensors="pt").to("cuda"), max_new_tokens=20, do_sample=False, forward_mode="parallel").ravel()))
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
/home/fella/src/llama/text-generation-webui/models/NucleusAI_Nucleus-X/modeling_nucleus_x.py:370: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  cache = (current_kv, scale, torch.tensor(prev_seqlen + 1, dtype=torch.long))
<s> Hello my name is Tina and I am a 25 year old female. I am a very outgoing person

but recurrent mode does not:

In [8]: print(tokenizer.decode(model.generate(**tokenizer("Hello my name is", return_tensors="pt").to("cuda"), max_new_tokens=20, do_sample=False, forward_mode="recurrent").ravel()))
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
/home/fella/src/llama/text-generation-webui/models/NucleusAI_Nucleus-X/modeling_nucleus_x.py:370: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  cache = (current_kv, scale, torch.tensor(prev_seqlen + 1, dtype=torch.long))
<s> Hello my name is the most of the world.
The first thing I noticed was the size of the room. It

(Even if I set the recurrent forward mode in config.json, a 16KB prompt fails to pass through model.generate unless I pass forward_mode='recurrent'.)

@Rocketknight1 (Member) commented:

Hi @syncdoth, sorry for the Christmas delay! You're correct, though - the issue is almost certainly caused by the docstring trying to load a model too big for the test runner. Is there any smaller checkpoint we can use? You could also try torch_dtype=torch.bfloat16.
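For reference, loading in bfloat16 via the standard torch_dtype argument roughly halves the memory footprint compared to float32 (about 14 GB instead of 28 GB for a 7B model):

import torch
from transformers import NucleusXForCausalLM

model = NucleusXForCausalLM.from_pretrained(
    "NucleusAI/NucleusX-7B",  # placeholder checkpoint from this PR
    torch_dtype=torch.bfloat16,
)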

@syncdoth (Author) commented Feb 2, 2024:

Haha, please don't stale this! We are still working hard to put out the model. We are working on a small model to satisfy the PR requirement, but it has unfortunately been a lower priority :( We will try to finish this by mid-February!

@huggingface deleted a comment from the github-actions bot on Feb 2, 2024
@ArthurZucker (Collaborator) commented:

No worries 🤗

@LysandreJik added the WIP label on Feb 14, 2024
Labels: WIP (for long-outstanding Issues/PRs that are work in progress)

8 participants