
Commit

simple PR fixes
tomeras91 committed Mar 31, 2024
1 parent 56183b4 commit 59d832a
Showing 5 changed files with 12 additions and 28 deletions.
23 changes: 8 additions & 15 deletions docs/source/en/model_doc/jamba.md
@@ -1,4 +1,4 @@
-<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
@@ -49,12 +49,10 @@ You also have to have the model on a CUDA device.
You can run the model without the optimized Mamba kernels, but this is **not** recommended as it will result in significantly higher latencies. To do so, specify `use_mamba_kernels=False` when loading the model.

### Run the model
-Please note that, at the moment, `trust_remote_code=True` is required for running the new Jamba architecture.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

-model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
-trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
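The hunk above mentions the `use_mamba_kernels=False` fallback without showing it in code. As a minimal sketch (an editor's illustration, not part of this diff), assuming the post-change API with no `trust_remote_code` and a generation call like the one the collapsed lines presumably contain:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Fall back to the plain PyTorch Mamba path (no optimized kernels).
# Not recommended: expect noticeably higher latency than the kernel-enabled load.
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", use_mamba_kernels=False)
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors="pt").to(model.device)["input_ids"]
outputs = model.generate(input_ids, max_new_tokens=64)  # token budget chosen purely for illustration
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```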
@@ -73,17 +71,15 @@ The published checkpoint is saved in BF16. In order to load it into RAM in BF16/FP16, you'll need to specify `torch_dtype`:
```python
from transformers import AutoModelForCausalLM
import torch
-model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
-trust_remote_code=True,
-torch_dtype=torch.bfloat16)    # you can also use torch_dtype=torch.float16
+model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16)
+# you can also use torch_dtype=torch.float16
```

When using half precision, you can enable the [FlashAttention2](https://github.com/Dao-AILab/flash-attention) implementation of the Attention blocks. In order to use it, you also need the model on a CUDA device. Since in this precision the model is too big to fit on a single 80GB GPU, you'll also need to parallelize it using [accelerate](https://huggingface.co/docs/accelerate/index):
```python
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
-trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto")
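The FlashAttention-2 example above is truncated by the collapsed lines. As an illustrative follow-up (an editor's sketch, not part of this diff), once `from_pretrained` returns with `device_map="auto"` you can see how accelerate sharded the checkpoint by reading the `hf_device_map` attribute the loader attaches to the model:

```python
# Illustrative only: print which device each module group was placed on.
# Assumes `model` was loaded with device_map="auto" as in the block above.
for module_name, device in model.hf_device_map.items():
    print(f"{module_name} -> {device}")
```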
@@ -96,13 +92,10 @@ model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
-quantization_config = BitsAndBytesConfig(load_in_8bit=True,
-llm_int8_skip_modules=["mamba"])
-model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
-trust_remote_code=True,
-torch_dtype=torch.bfloat16,
-attn_implementation="flash_attention_2",
-quantization_config=quantization_config)
+quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["mamba"])
+model = AutoModelForCausalLM.from_pretrained(
+    "ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", quantization_config=quantization_config
+)
```
</details>

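As a rough sanity check on the 8-bit configuration above (again an editor's sketch rather than part of the diff), the model's `get_memory_footprint()` method reports its approximate in-memory size, which can be compared against a plain BF16 load of the same checkpoint:

```python
# Illustrative only: report the approximate memory footprint of the quantized model.
# Assumes `model` was loaded with the BitsAndBytesConfig shown above.
footprint_gib = model.get_memory_footprint() / 1024**3
print(f"8-bit Jamba footprint: {footprint_gib:.1f} GiB")
```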
5 changes: 1 addition & 4 deletions src/transformers/__init__.py
@@ -515,7 +515,7 @@
"InstructBlipQFormerConfig",
"InstructBlipVisionConfig",
],
"models.jamba": ["JAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "JambaConfig"],
"models.jamba": ["JambaConfig"],
"models.jukebox": [
"JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP",
"JukeboxConfig",
@@ -2461,7 +2461,6 @@
)
_import_structure["models.jamba"].extend(
[
"JAMBA_PRETRAINED_MODEL_ARCHIVE_LIST",
"JambaForCausalLM",
"JambaForSequenceClassification",
"JambaModel",
@@ -7168,8 +7167,6 @@
InstructBlipQFormerModel,
InstructBlipVisionModel,
)

# PyTorch model imports
from .models.jamba import (
-JAMBA_PRETRAINED_MODEL_ARCHIVE_LIST,
JambaForCausalLM,
5 changes: 2 additions & 3 deletions src/transformers/models/jamba/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2020 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@


_import_structure = {
"configuration_jamba": ["JAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "JambaConfig"],
"configuration_jamba": ["JambaConfig"],
}


@@ -28,7 +28,6 @@
pass
else:
_import_structure["modeling_jamba"] = [
"JAMBA_PRETRAINED_MODEL_ARCHIVE_LIST",
"JambaForCausalLM",
"JambaForSequenceClassification",
"JambaModel",
5 changes: 0 additions & 5 deletions src/transformers/models/jamba/modeling_jamba.py
@@ -93,11 +93,6 @@

_CONFIG_FOR_DOC = "JambaConfig"

-JAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = [
-"ai21labs/Jamba-v0.1",
-# See all Jamba models at https://huggingface.co/models?filter=jamba
-]


# Adapted from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
def load_balancing_loss_func(
2 changes: 1 addition & 1 deletion tests/models/jamba/test_modeling_jamba.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
