Has any thought been given to using LoRA to increase the number of experts (100x) with minimal memory? #95
Comments
A relevant publication
That's basically what I was thinking of, except specific to the feed-forward module and not only for fine-tuning.
Also, I did a little profiling on CPU with a smaller model: batch size 4, 1024 tokens, and 8 experts (with 3 used per token). Initializing the model and running one inference on a random input gives:
Additionally, I don't think you would need to retrain something like Mistral 8x7B from scratch. You could do some stuff with singular value decomposition to get a good approximation of the base and LoRA matrices. Edit: For fun I profiled the same model but with 100 experts, using 30 of them per token:
Edit 2: Thinking a bit more, I'm not sure this setup will be as effective, since the weighted sum of multiple networks' outputs won't necessarily be the same as the output of a single network whose parameters are a linear combination of theirs. You might need to run the experts separately, which might take a similar amount of time or even more than the original way. However, memory usage should still be dramatically improved.
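A toy illustration of that point, using a single silu layer (illustrative only):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4)
w1, w2 = torch.randn(4, 4), torch.randn(4, 4)

# Weighted sum of two experts' outputs...
mix_outputs = 0.5 * F.silu(x @ w1.T) + 0.5 * F.silu(x @ w2.T)
# ...vs the output of one network built from the weighted sum of their weights.
mix_weights = F.silu(x @ (0.5 * w1 + 0.5 * w2).T)

print(torch.allclose(mix_outputs, mix_weights))  # False: the non-linearity breaks the equivalence
```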
Yeah, my plan was to implement something that (unfortunately) has to call the FFN (PS: Happy new year everyone 😄)
Happy new year! You're right that it would be slower than the currently used method for sparse mixture of experts, but I don't know if it would be that much slower, since the current method already runs multiple feed-forward passes per token anyway.
Hey, did you ever end up trying to train this? As soon as I turn on even a small number of gradients, it OOMs pretty quickly. I tried with my own implementation, which I developed without looking at yours, and I just tried again with yours (slightly modified) and it still happens.
So it turns out I'm really f**ing stupid and forgot to add a line to actually run the model on some input when I was profiling. It is not actually better on memory; it is much worse (like 6-7 times worse on a small example). I think it's because you're essentially making a new weight matrix for each token. There might be a workaround, but I'm not sure. Sorry for wasting your time.
Nah, no time wasted. I was already playing with this idea before I saw your initial comments. I'm working on implementing the paper I linked above now. The authors released their code, but it's a mess.
As I understand the current MoeLayer, a gate calculates the weight to apply to each expert's output, the top k experts are selected and run on the data, and finally their results are multiplied by the respective weights and summed.
This means you have to store n copies of the layers, one for each expert.
If instead you had a single base set of parameters, and each expert was defined by a low-rank update, you could hold a lot more experts in the same amount of memory.
Calculating the gate weights would be the same, but instead of taking a weighted sum of the outputs you could take a weighted sum of the parameters and then use the summed parameters to calculate the result. (This would start to look a lot like Schmidhuber's fast weight programmers.)
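For reference, the current output-mixing behaviour described above boils down to roughly this (a simplified sketch, not the exact repo code):

```python
import torch
import torch.nn as nn

class SimpleMoeLayer(nn.Module):
    """Simplified sketch: gate -> top-k experts -> weighted sum of their outputs."""

    def __init__(self, experts, gate, num_experts_per_tok):
        super().__init__()
        self.experts = nn.ModuleList(experts)   # n full copies of the FFN parameters
        self.gate = gate                        # e.g. nn.Linear(dim, n_experts, bias=False)
        self.num_experts_per_tok = num_experts_per_tok

    def forward(self, x):
        # x: (num_tokens, dim)
        logits = self.gate(x)                                           # (num_tokens, n_experts)
        weights, idx = torch.topk(logits, self.num_experts_per_tok, dim=-1)
        weights = torch.softmax(weights, dim=-1)
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            tok, slot = torch.where(idx == e)                           # tokens routed to expert e
            if tok.numel() > 0:
                out[tok] += weights[tok, slot, None] * expert(x[tok])   # weighted sum of outputs
        return out
```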
Using the code at the end of this post, I could add 100 experts (at rank 4) to a feed-forward module with dim=512 and hid_dim=2048 with only a ~2x increase in the number of parameters. I would expect this ratio to get better as dim and hid_dim get larger.
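That ~2x works out from a quick back-of-the-envelope count (assuming a SwiGLU-style FeedForward with three weight matrices and one LoRA pair per matrix per expert; the gate's parameters are ignored as comparatively tiny):

```python
dim, hid_dim, rank, num_experts = 512, 2048, 4, 100

# Base FeedForward: w1 (hid_dim x dim), w2 (dim x hid_dim), w3 (hid_dim x dim)
base_params = 3 * dim * hid_dim                            # 3,145,728

# Each expert adds a LoRA pair per matrix: rank * (in_dim + out_dim) parameters.
lora_params = num_experts * 3 * rank * (dim + hid_dim)     # 3,072,000

print((base_params + lora_params) / base_params)           # ~1.98, i.e. roughly 2x in total
```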
Is there some fatal flaw that makes this approach not worth it?
Here are the possible flaws I could think of, but none of them strike me as compelling:
I didn't test performance, so it may not work as well. Although LoRA has worked pretty well in a number of places, it may be that MoE relies on high-rank differences between the experts.
It could also be that a linear combination of the matrices prior to the non-linearities is not as powerful as combining the results.
Additionally, there is a performance hit from the mixing of experts: you have to do on the order of
3 * (dim + hid_dim) * rank * num_tokens * num_experts_per_token
operations to mix the parameters, plus the extra (dim, rank) * (rank, hid_dim) matrix multiplications and the addition with the base matrix. I haven't really looked at it closely, but I'm pretty sure this is made up for by the fact that once the models are mixed you are running a single feed-forward on the input rather than multiple.
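To put rough numbers on that trade-off (using the dimensions from above and, say, 3 experts per token; the extra low-rank matmuls and the base-matrix addition are not counted here):

```python
dim, hid_dim, rank = 512, 2048, 4
num_tokens, k = 1024, 3          # example values: k = num_experts_per_token

# Rough multiply count for mixing the low-rank parameters, per the formula above:
mix_ops = 3 * (dim + hid_dim) * rank * num_tokens * k      # ~94M

# One dense pass through a three-matrix feed-forward, for all tokens:
single_ffn_ops = 3 * dim * hid_dim * num_tokens            # ~3.2B

# The output-mixing MoE runs k such passes, so mixing replaces roughly:
saved_ops = (k - 1) * single_ffn_ops                       # ~6.4B

print(f"{mix_ops:.2e} {single_ffn_ops:.2e} {saved_ops:.2e}")
```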
Thoughts?
Some rough code to illustrate the idea based off of the FeedForward and MoeLayer modules:
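(A minimal sketch along those lines, assuming a SwiGLU-style FeedForward with three weight matrices and a MoeLayer-style gate; the class name, initialisations, and defaults here are illustrative.)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRAMoeFeedForward(nn.Module):
    """Sketch: one shared SwiGLU FeedForward plus num_experts rank-`rank` LoRA deltas.
    The gate's top-k weights mix the *parameters* (the LoRA deltas) rather than the
    expert outputs, so each token gets a single feed-forward pass through its own
    mixed weights."""

    def __init__(self, dim=512, hid_dim=2048, num_experts=100, rank=4, k=30):
        super().__init__()
        self.k = k
        # Shared base parameters, SwiGLU style: w2(silu(w1 x) * w3 x)
        self.w1 = nn.Parameter(torch.randn(hid_dim, dim) * 0.02)
        self.w2 = nn.Parameter(torch.randn(dim, hid_dim) * 0.02)
        self.w3 = nn.Parameter(torch.randn(hid_dim, dim) * 0.02)
        # Per-expert LoRA factors for each of the three matrices.
        # B is zero-initialised so every expert starts identical to the base (standard LoRA init).
        self.a1 = nn.Parameter(torch.randn(num_experts, rank, dim) * 0.02)
        self.b1 = nn.Parameter(torch.zeros(num_experts, hid_dim, rank))
        self.a2 = nn.Parameter(torch.randn(num_experts, rank, hid_dim) * 0.02)
        self.b2 = nn.Parameter(torch.zeros(num_experts, dim, rank))
        self.a3 = nn.Parameter(torch.randn(num_experts, rank, dim) * 0.02)
        self.b3 = nn.Parameter(torch.zeros(num_experts, hid_dim, rank))
        self.gate = nn.Linear(dim, num_experts, bias=False)

    def _mixed(self, w, a, b, mix):
        # mix: (num_tokens, num_experts), zero except for each token's top-k gate weights.
        # Weighted sum of the per-expert deltas b_e @ a_e, added to the base weight,
        # giving one weight matrix per token: (num_tokens, out_dim, in_dim).
        delta = torch.einsum("te,eor,eri->toi", mix, b, a)
        return w.unsqueeze(0) + delta

    def forward(self, x):
        # x: (num_tokens, dim)
        logits = self.gate(x)                                        # (num_tokens, num_experts)
        weights, idx = torch.topk(logits, self.k, dim=-1)
        weights = torch.softmax(weights, dim=-1)
        mix = torch.zeros_like(logits).scatter_(-1, idx, weights)    # sparse gate weights

        w1 = self._mixed(self.w1, self.a1, self.b1, mix)             # (num_tokens, hid_dim, dim)
        w2 = self._mixed(self.w2, self.a2, self.b2, mix)             # (num_tokens, dim, hid_dim)
        w3 = self._mixed(self.w3, self.a3, self.b3, mix)             # (num_tokens, hid_dim, dim)

        h = F.silu(torch.einsum("thd,td->th", w1, x)) * torch.einsum("thd,td->th", w3, x)
        return torch.einsum("tdh,th->td", w2, h)
```

A quick smoke test is `LoRAMoeFeedForward()(torch.randn(8, 512))`, which should return an `(8, 512)` tensor. Note that `_mixed` materialises one full weight matrix per token, which is where the per-token memory cost discussed in the comments above comes from.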