
Has any thought been given to using LoRA to increase the number of experts (100x) with minimal memory? #95

sixChar opened this issue Dec 22, 2023 · 8 comments


sixChar commented Dec 22, 2023

As I understand the current MoeLayer, a gate calculates the weight to apply to each expert's output, the top-k experts are selected and run on the data, and finally their results are multiplied by their respective weights and summed.

This means you have to store n copies of the layers, one for each expert.

If you instead had a single base set of parameters and defined each expert by a low-rank matrix, you could hold a lot more experts in the same memory.
Calculating the gate weights would be the same, but instead of taking a weighted sum of the outputs you would take a weighted sum of the parameters and then use the summed parameters to compute the result. (This would start to look a lot like Schmidhuber's fast weight programmers.)
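
For concreteness, here is a minimal single-token, single-linear-layer sketch of the difference between the two mixing strategies (toy tensors only, not the actual MoeLayer code):

```python
import torch

dim, hid, k, rank = 512, 2048, 2, 4          # illustrative sizes
x = torch.randn(dim)
gates = torch.softmax(torch.randn(k), dim=0)

# Standard sparse MoE: run each selected expert, then mix the *outputs*.
experts = [torch.randn(dim, hid) for _ in range(k)]
y_moe = sum(g * (x @ w) for g, w in zip(gates, experts))

# Proposed variant: mix the *parameters* (base + weighted LoRA deltas),
# then do a single forward pass.
w_base = torch.randn(dim, hid)
loras = [(torch.randn(dim, rank), torch.randn(rank, hid)) for _ in range(k)]
w_mixed = w_base + sum(g * (a @ b) for g, (a, b) in zip(gates, loras))
y_lora = x @ w_mixed
```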

Using the code below, I could add 100 experts (at rank 4) to a feed-forward module with dim=512 and hid_dim=2048 with only a ~2x increase in the number of parameters. I would expect this ratio to get better as dim and hid_dim get larger.
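
Rough arithmetic behind that claim (one FFN block, counting only the w1/w2/w3 base matrices and the LoRA factors):

```python
dim, hid_dim, rank, num_experts = 512, 2048, 4, 100

base = 3 * dim * hid_dim                          # w1, w2, w3 base matrices
lora = 3 * num_experts * rank * (dim + hid_dim)   # A and B factors per expert
print(base, lora, (base + lora) / base)           # 3145728 3072000 ~1.98
```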

Is there some fatal flaw that makes this approach not worth it?

Here are the possible flaws I could think of, but none of them strikes me as compelling.

I didn't test performance, so it may not work as well. Although LoRA has worked pretty well in a number of places, it may be that MoE relies on high-rank differences between experts.
It could also be that a linear combination of the matrices before the non-linearities is not as powerful as combining the results afterwards.

Additionally, there is a performance hit from mixing the experts: you have to do on the order of 3 * (dim + hid_dim) * rank * num_tokens * num_experts_per_token operations to mix the parameters, plus the extra (dim, rank) x (rank, hid_dim) matrix multiplications and the addition with the base matrix. I haven't really looked at it, but I'm pretty sure this is made up for by the fact that, once the parameters are mixed, you run a single feed-forward on the input rather than several.

Thoughts?

Some rough code to illustrate the idea, based on the FeedForward and MoeLayer modules:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt

class MultiLoraLinear(nn.Module):
    def __init__(self, ins: int, outs: int, num_loras: int, rank: int):
        super().__init__()
        # nn.Linear initializes weights from uniform(-1/sqrt(in_features), +1/sqrt(in_features));
        # (rand - 0.5) * init_scale gives the same range.
        init_scale = 2 / sqrt(ins)
        self.w_base = nn.Parameter((torch.rand(ins, outs) - 0.5) * init_scale)
        self.w_loras_a = nn.Parameter((torch.rand(num_loras, ins, rank) - 0.5) * init_scale)
        self.w_loras_b = nn.Parameter((torch.rand(num_loras, rank, outs) - 0.5) * init_scale)
        self.num_loras = num_loras
        self.ins = ins
        self.outs = outs

    def forward(self, x, expert_weights, expert_indices):
        ## Construct a per-token weight matrix from a weighted sum of LoRA params.
        # Select the LoRA factors for the chosen experts: (tokens, k, ins, rank)
        selected_w_loras_a = self.w_loras_a[expert_indices, :, :]

        # Multiply the selected factors by their gate weights and sum over the k experts.
        w_lora_a = torch.sum(selected_w_loras_a * expert_weights.unsqueeze(-1).unsqueeze(-1), dim=1)
        selected_w_loras_b = self.w_loras_b[expert_indices, :, :]
        w_lora_b = torch.sum(selected_w_loras_b * expert_weights.unsqueeze(-1).unsqueeze(-1), dim=1)

        # Full LoRA delta as lora_a @ lora_b, built separately for each token.
        # b: token/batch, i: num ins, k: rank, o: num outs
        w_lora = torch.einsum("bik,bko->bio", w_lora_a, w_lora_b)
        w = self.w_base + w_lora

        return torch.einsum("bi,bio->bo", x, w)


class MoreMoeFeedForward(nn.Module):
    def __init__(self, gate: nn.Module, dim: int, hid_dim: int, num_experts: int, lora_rank: int, num_experts_per_tok=5):
        super().__init__()
        assert num_experts > 0
        self.gate = gate
        self.num_experts = num_experts
        self.lora_rank = lora_rank
        self.num_experts_per_tok = num_experts_per_tok

        self.w1 = MultiLoraLinear(dim, hid_dim, num_experts, lora_rank)
        self.w2 = MultiLoraLinear(hid_dim, dim, num_experts, lora_rank)
        self.w3 = MultiLoraLinear(dim, hid_dim, num_experts, lora_rank)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mostly copied from MoeLayer
        inputs_squashed = x.view(-1, x.shape[-1])
        gate_logits = self.gate(inputs_squashed)
        weights, indices = torch.topk(gate_logits, self.num_experts_per_tok)
        weights = F.softmax(weights, dim=1, dtype=torch.float).type_as(x)

        # Mostly copied from FeedForward
        res_squashed = self.w2(
            F.silu(self.w1(inputs_squashed, weights, indices))
            * self.w3(inputs_squashed, weights, indices),
            weights,
            indices,
        )
        return res_squashed.view(x.shape)

```

@WillJStone

A relevant publication

https://arxiv.org/abs/2310.18339


sixChar commented Dec 31, 2023

That's basically what I was thinking of, except specific to the feed-forward module and not only for fine-tuning.


sixChar commented Dec 31, 2023

Also, I did a little profiling on CPU with a smaller model: batch size 4, 1024 tokens, and 8 experts (3 used per token).

Initializing the model and running one inference on a random input gives:

- Mistral's original sparse MoE: CPU time 44.692 ms, CPU memory 9.39 GB
- With a slightly more rigorous implementation of LoRA (rank 8) experts, like the code I posted above: CPU time 10.150 ms, CPU memory 1.60 GB

Additionally, I don't think you would need to retrain something like Mistral 8x7B from scratch. You could do some stuff with singular value decomposition to get a good approximation of the base and LoRA matrices.
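
Roughly, I'm imagining something like this sketch: take (say) the mean of the pretrained expert matrices as the base, and a truncated SVD of each expert's deviation as its LoRA factors (the mean-as-base choice and the function name are just placeholders):

```python
import torch

def split_experts(expert_ws: list[torch.Tensor], rank: int):
    # Shared base: here simply the mean of the per-expert matrices.
    w_base = torch.stack(expert_ws).mean(dim=0)
    factors = []
    for w in expert_ws:
        # Truncated SVD of this expert's deviation from the base.
        u, s, vh = torch.linalg.svd(w - w_base, full_matrices=False)
        a = u[:, :rank] * s[:rank]   # (ins, rank)
        b = vh[:rank, :]             # (rank, outs)
        factors.append((a, b))
    return w_base, factors
```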

Edit: For fun I profiled the same model but with 100 experts, using 30 of them per token: CPU time 31.468 ms, CPU memory 2.62 GB.

Edit 2: Thinking about it a bit more, I'm not sure this setup will be as effective, since a weighted sum of the outputs of multiple networks won't necessarily equal the output of a single network whose parameters are the corresponding weighted sum. You might need to run the experts separately, which might take a similar amount of time as the original way, or even more. However, memory usage should still be dramatically improved.
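
A tiny check of that point (toy sizes, random matrices; mixing outputs vs. mixing parameters across a SiLU non-linearity):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 8)
w1, w2 = torch.randn(8, 8), torch.randn(8, 8)

mixed_outputs = 0.5 * F.silu(x @ w1) + 0.5 * F.silu(x @ w2)
mixed_params = F.silu(x @ (0.5 * w1 + 0.5 * w2))
print(torch.allclose(mixed_outputs, mixed_params))  # False in general
```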

@WillJStone

Yeah, my plan was to implement something that (unfortunately) has to call the FFN num_experts_per_token times, each with a different adapter. Slow but low memory and... maybe has some benefits? I guess the test would be determining if you're better off using a larger LoRA rank with the standard method, or doing the "MoLoRA" method with a roughly equivalent number of new params to the one big LoRA.

(PS Happy new years everyone 😄)


sixChar commented Jan 1, 2024

Happy new year!

You're right that it would be slower than the method currently used for sparse mixture of experts, but I don't know if it would be that much slower, since the current method runs num_experts_per_token different full networks. Adding on the calculations for the LoRA version should keep it within 2x slower. It also might be more amenable to cache optimization, since most of the FFN operations use the base parameters.

@WillJStone

Hey, did you ever end up trying to train this? As soon as I turn even a small number of gradients on, it OOMs pretty quickly. I tried with my own implementation, which I developed without looking at yours, and I just tried again with yours (slightly modified) and it still happens.


sixChar commented Jan 24, 2024

So it turns out I'm really f**ing stupid and forgot to add a line to actually run the model on some input when I was profiling.

It is not actually better on memory; it is much worse (something like 6-7 times worse on a small example).

I think it's because you're essentially materializing a new weight matrix for each token. There might be a workaround, but I'm not sure.
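
For a sense of scale, plugging in the sizes from my original example (dim=512, hid_dim=2048, fp32), rather than the smaller model I actually profiled, so this is only illustrative:

```python
tokens = 4 * 1024                     # batch 4, 1024 tokens
per_token_w = 512 * 2048 * 4          # one mixed w1 matrix in bytes (fp32)
print(tokens * per_token_w / 2**30)   # ~16 GiB for w1 alone, before w2/w3
```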

Sorry for wasting your time.

@WillJStone

Nah, no time wasted. I was already playing with this idea before I saw your initial comments. I'm working on implementing the paper I linked above now. The authors released their code but it's a mess
