Impl simple mamba model #1480

Merged (22 commits into main, Feb 8, 2024)
Conversation

@drbh (Collaborator) commented on Jan 25, 2024

This draft PR is a work-in-progress implementation of the Mamba model. It currently loads weights and produces correct logits after a single forward pass.

The PR still needs to integrate the model so that it generates tokens as expected, and to apply optimizations that avoid unnecessary copies and redundant operations at runtime.
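For context, the discretized selective-state-space recurrence at the heart of Mamba can be sketched as a sequential scan. The NumPy loop below is only an illustration of the data flow (shapes and names are illustrative, and this is not the fused CUDA kernel the real implementation uses):

```python
import numpy as np

def selective_scan(x, A_bar, B_bar, C):
    """Naive sequential SSM scan: h_t = A_bar_t * h_{t-1} + B_bar_t * x_t, y_t = C_t @ h_t.

    x:     (L,)    input sequence for one channel
    A_bar: (L, N)  per-step discretized state transition (diagonal)
    B_bar: (L, N)  per-step discretized input projection
    C:     (L, N)  per-step output projection
    """
    L, N = A_bar.shape
    h = np.zeros(N)
    y = np.empty(L)
    for t in range(L):
        h = A_bar[t] * h + B_bar[t] * x[t]  # state update (elementwise, since A is diagonal)
        y[t] = C[t] @ h                     # readout
    return y
```

The production path fuses this scan into a single kernel; the Python loop above only shows what that kernel computes.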

Helpful resources

- [Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)](https://arxiv.org/abs/2312.00752)
- https://github.com/johnma2006/mamba-minimal
- https://github.com/huggingface/candle/blob/main/candle-examples/examples/mamba-minimal/model.rs
- huggingface/transformers#28094

Note: this dev work currently targets `state-spaces/mamba-130m`, so please use that model for testing. Additionally, when starting the router, prefill needs to be limited: `cargo run -- --max-batch-prefill-tokens 768 --max-input-length 768`

Update / Current State

Integration tests have been added, and basic functionality such as model loading is supported.

```bash
cd integration-tests
pytest -vv models/test_fused_kernel_mamba.py
```

- [x] add tests
- [x] load model
- [x] make simple request
- [ ] resolve warmup issue
- [ ] resolve output issues

Fetching models tested during dev:

```bash
text-generation-server download-weights state-spaces/mamba-130m
text-generation-server download-weights state-spaces/mamba-1.4b
text-generation-server download-weights state-spaces/mamba-2.8b
```

The server can be run:

```bash
cd server
MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve state-spaces/mamba-2.8b
```

router

```bash
cargo run
```

make a request

```bash
curl -s localhost:3000/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json' | jq
```

response

```json
{
  "generated_text": "\n\nDeep learning is a machine learning technique that uses a deep neural network to learn from data."
}
```
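The same request can also be issued from Python with only the standard library (a sketch; it assumes the router is listening on `localhost:3000` as above):

```python
import json
from urllib import request

def build_generate_payload(inputs: str, max_new_tokens: int = 20) -> bytes:
    """Encode the /generate request body shown in the curl example."""
    return json.dumps(
        {"inputs": inputs, "parameters": {"max_new_tokens": max_new_tokens}}
    ).encode("utf-8")

def generate(inputs: str, url: str = "http://localhost:3000/generate") -> str:
    req = request.Request(
        url,
        data=build_generate_payload(inputs),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    # Requires the router to be running; returns only the generated text.
    with request.urlopen(req) as resp:
        return json.load(resp)["generated_text"]
```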

@drbh drbh mentioned this pull request Jan 29, 2024
@RonanKMcGovern commented:

Thanks, this will be a great addition as we see more Mamba architectures.

@drbh drbh marked this pull request as ready for review February 7, 2024 04:35
@Narsil (Collaborator) left a comment:
LGTM

@Narsil Narsil merged commit bd405e0 into main Feb 8, 2024
7 checks passed
@Narsil Narsil deleted the impl-simple-mamba-model branch February 8, 2024 09:19
kdamaszk pushed a commit to kdamaszk/tgi-gaudi that referenced this pull request Apr 29, 2024
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>