This repository was archived by the owner on Jun 24, 2024. It is now read-only.

Support for RWKV #75

Open
@philpax

So this is a pretty immense task and I'd start with #45, but...

RWKV is an RNN with Transformer-level LLM performance, which can also be directly trained like a GPT transformer (parallelizable). And it's 100% attention-free. You only need the hidden state at position t to compute the state at position t+1. You can use the "GPT" mode to quickly compute the hidden state for the "RNN" mode.

So it's combining the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding (using the final hidden state).
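In other words, RNN-mode inference is just a left-to-right scan: each step consumes one token plus the previous recurrent state and produces logits plus the next state. A rough sketch of that loop, where the `rwkv_forward` callback is hypothetical and stands in for a real single-token forward pass:

```python
def generate_state(rwkv_forward, tokens, state=None):
    # `rwkv_forward` is a hypothetical single-token forward pass:
    # (token_id, state) -> (logits, new_state). Only the previous state is
    # needed, so the work per token is constant regardless of context length.
    logits = None
    for token in tokens:
        logits, state = rwkv_forward(token, state)
    # The final state can double as a "free" sentence embedding.
    return logits, state
```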

It's entirely open-source, so not legally burdened like LLaMA, and (from what I've seen) is more powerful than BLOOM at the same parameter count.

I asked the RWKV Discord which implementation would be worth looking at, and this is what I was told:

- RWKV-LM/RWKV-v4neo/src/model.py is the implementation that's actually used to train the large models; it's CUDA-only and has tons of features you probably don't need.
- rwkv_pip_package only implements inference, but is a good implementation and worth a look; it recently got a lot more complex due to supporting more and more strategies and including various optimizations.
- ChatRWKV/src/model_run is an older version, but I haven't played with it, so I'm not sure how good it is. Might be worth a look since it's basically an older version of the one in rwkv_pip_package.
- RWKV_in_150_lines.py I still haven't fully checked out, but I know it doesn't support GPT mode, so that may or may not make it less useful.
- Also worth a look is RWKV-v4neo/src/model_run.py, which is a small inference-only implementation capable of loading the large RWKV checkpoints. I'm not sure if it has GPT mode, though.

So it sounds like rwkv_pip_package is the way to go as source material:

https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py

The following articles are very useful for understanding how RWKV works:

An interesting detail from the latter is the following:

The largest number a 16-bit floating point number (float16) can represent is 65 504, anything above that overflows, which is bad. Most of the code has no problems with this, partially because the Layer Normalizations keep values in a reasonable range. However, the RWKV attention contains exponentially large numbers (exp(bonus + k)). In practice, the RWKV attention is implemented in a way where we factor out an exponential factor from num and den to keep everything within float16 range. See for example the time_mixing function in RWKV in 150 lines.
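As a reference for how that factoring works, here is a minimal NumPy sketch of one step of the WKV recurrence, roughly following the structure of the ChatRWKV/rwkv_pip_package code; the names `time_first` (the bonus) and `time_decay` (the negative per-channel decay) are taken from those implementations, but treat this as an illustration rather than the exact code:

```python
import numpy as np

def wkv_step(k, v, state, time_first, time_decay):
    # state = (aa, bb, pp): numerator, denominator, and the shared exponent
    # factored out of both, so exp() only ever sees values <= 0 and stays
    # within float16 range.
    aa, bb, pp = state

    # Output for the current token.
    ww = time_first + k
    p = np.maximum(pp, ww)
    e1 = np.exp(pp - p)
    e2 = np.exp(ww - p)
    wkv = (e1 * aa + e2 * v) / (e1 * bb + e2)

    # Decay the state and fold in the current token for the next step.
    ww = pp + time_decay
    p = np.maximum(ww, k)
    e1 = np.exp(ww - p)
    e2 = np.exp(k - p)
    return wkv, (e1 * aa + e2 * v, e1 * bb + e2, p)
```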

The need to keep these exponentials within range may pose issues for the GGML 4-bit quantisation format, which is already non-optimal; we would likely want GPTQ quantisation instead.
