
Exact command to reproduce the curve in MOD is Vibe? #10

Open
jzhang38 opened this issue Apr 13, 2024 · 5 comments

Comments

jzhang38 commented Apr 13, 2024

Hi Joey,

Thank you for such wonderful open-source work!

Could you share the exact command to reproduce the curve in your MOD is Vibe blog? For example, did you use DDP and how many GPUs?

joey00072 (Owner) commented Apr 13, 2024

I trained a 300M model on a single A6000 (from Paperspace Gradient) with bf16-mixed precision.

The training script is at experiments/mixture_of_depth/train_mod.py; you can look into it and change the sizes, or keep the defaults to reproduce.

1. First install this repo: git clone it and run pip install -e .
2. Then pretokenize the dataset with python examples/prepare-dataset.py (open this file and change the dataset to minipile).
3. Then run train_mod.py.

I am using Lightning Fabric, so multi-node training should be pretty easy, but I trained on a single 48 GB A6000.

I'll add a README with details under experiments/mixture_of_depth/ tonight or tomorrow 😅

jzhang38 (Author) commented Apr 14, 2024

> I'll add a README with details under experiments/mixture_of_depth/ tonight or tomorrow

Thank you so much!

Yeah I've pretty much read your code related to MoD.

One concern is that the dataset is implemented as an iterator object, so I am not sure whether Lightning Fabric would handle it correctly in a multi-GPU setup, since we would need a distributed sampler.
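
Just to make the concern concrete, here is a rough sketch of the kind of manual sharding I mean (names like ShardedTokenStream are hypothetical, not code from this repo); as far as I know, Fabric's setup_dataloaders does not inject a DistributedSampler when the underlying dataset is an IterableDataset:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset
from lightning.fabric import Fabric


class ShardedTokenStream(IterableDataset):
    """Hypothetical iterable dataset that shards pre-tokenized blocks by rank."""

    def __init__(self, token_blocks, rank: int, world_size: int):
        self.token_blocks = token_blocks
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        # Each process yields every world_size-th block, offset by its rank,
        # so GPUs see disjoint slices of the stream.
        for i, block in enumerate(self.token_blocks):
            if i % self.world_size == self.rank:
                yield block


if __name__ == "__main__":
    fabric = Fabric(accelerator="auto", devices="auto")
    fabric.launch()

    blocks = [torch.randint(0, 50_000, (1024,)) for _ in range(8)]  # dummy data
    dataset = ShardedTokenStream(blocks, fabric.global_rank, fabric.world_size)
    loader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=2))

    for batch in loader:
        print(fabric.global_rank, batch.shape)
```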

jzhang38 (Author) commented:

[Screenshot: Figure 7 from the MoD paper]

Looking at Figure 7 from the paper, I think they multiply the skipped tokens by the router weights as well.

joey00072 (Owner) commented:

We take the softmax over a long sequence length, so most values end up close to zero.
If we multiplied all tokens by the router weights, the pass-through tokens would become really tiny, like 1e-5 or so.
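
A toy illustration of the scale (not the actual router code from this repo):

```python
import torch

seq_len = 2048
router_logits = torch.randn(seq_len)            # stand-in for per-token router scores
weights = torch.softmax(router_logits, dim=-1)  # softmax over the full sequence

print(weights.mean().item())  # exactly 1 / seq_len ≈ 4.9e-4
print(weights.max().item())   # typically only a few times 1e-3
```

So scaling every token, including the pass-through ones, by these weights would shrink them to near zero.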

WuNein commented Apr 28, 2024

@joey00072
One more thing about MoD.

filtered_x = torch.gather(input=x, dim=1, index=indices_expanded) # -> batch, capacity, dim

Since MoD only routes a small fraction of the tokens through attention, I have concerns about model performance in some extreme cases, such as very short inputs.

top_k = int(seq_len * self.capacity_factor)  # maybe I should use math.ceil

I think top_k should have a minimal value, like 10, or seq_len when seq_len < 10.
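
Something like this sketch is what I have in mind (hypothetical, not code from this repo):

```python
import math

def capped_top_k(seq_len: int, capacity_factor: float, min_tokens: int = 10) -> int:
    # math.ceil instead of int() so the capacity never rounds down to zero,
    # plus a floor of min_tokens (or seq_len itself for very short inputs).
    top_k = math.ceil(seq_len * capacity_factor)
    top_k = max(top_k, min(min_tokens, seq_len))
    return min(top_k, seq_len)  # never route more tokens than exist

# e.g. capped_top_k(6, 0.125) -> 6; capped_top_k(2048, 0.125) -> 256
```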

And for long text, this behaves somewhat like sparse attention or a sliding window, which I think is acceptable.

I'd like to hear your thoughts.
