Cohere Model Release #29622

Merged: 27 commits, Mar 15, 2024
Changes from 1 commit.

Commits (27)
6e73900  Cohere Model Release (#1) · saurabhdash2512 · Mar 12, 2024
67c4a9b  Remove unnecessary files and code (#2) · saurabhdash2512 · Mar 12, 2024
0bc7cf9  Delete cohere-model directory (#3) · saurabhdash2512 · Mar 12, 2024
f964504  Make Fix (#5) · saurabhdash2512 · Mar 13, 2024
9c867c5  Pr fixes (#6) · ahmetustun · Mar 13, 2024
04d96fd  Merge branch 'huggingface:main' into main · saurabhdash2512 · Mar 13, 2024
115f198  Tokenizer test (#8) · ahmetustun · Mar 13, 2024
cacb8ae  Adding Docs and other minor changes (#7) · saurabhdash2512 · Mar 13, 2024
d959443  Merge branch 'huggingface:main' into main · saurabhdash2512 · Mar 13, 2024
82a0f3b  Add modeling tests (#9) · saurabhdash2512 · Mar 13, 2024
ef6ed3d  Smol Fix (#11) · saurabhdash2512 · Mar 13, 2024
aa2f878  Merge branch 'huggingface:main' into main · saurabhdash2512 · Mar 13, 2024
c86b184  tokenization tests are fixed · ahmetustun · Mar 13, 2024
09cd02f  Merge branch 'main' into tokenization_tests · ahmetustun · Mar 13, 2024
e7567ca  format fixes · ahmetustun · Mar 13, 2024
cf86bba  Merge pull request #12 from saurabhdash2512/tokenization_tests · ahmetustun · Mar 13, 2024
62822c8  fix pr doc tests · ahmetustun · Mar 13, 2024
43ae26e  fix pr doc tests · ahmetustun · Mar 13, 2024
aeac596  Merge branch 'tokenization_tests' · ahmetustun · Mar 13, 2024
d1dd3e1  fix pr doc tests · ahmetustun · Mar 13, 2024
6faf117  fix pr style check · ahmetustun · Mar 13, 2024
c841cc7  small changes in cohere.md · saurabhdash2512 · Mar 14, 2024
966ec9c  FIX: Address final comments for transformers integration (#13) · younesbelkada · Mar 14, 2024
aeb5908  Merge branch 'huggingface:main' into main · saurabhdash2512 · Mar 14, 2024
bb7f728  fix modeling cohere (#14) · younesbelkada · Mar 15, 2024
24a2227  Update chat templates to use the new API (#15) · Rocketknight1 · Mar 15, 2024
24a746e  Merge branch 'huggingface:main' into main · saurabhdash2512 · Mar 15, 2024
FIX: Address final comments for transformers integration (#13)
* fix modeling final nits and add proper test file

* for now leave empty tests

* add integration test

* push new test
younesbelkada authored Mar 14, 2024
commit 966ec9c6c787687dc67509b2be0b1b4023e5bc51
19 changes: 6 additions & 13 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -106,23 +106,13 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
emb = torch.repeat_interleave(freqs, 2, dim=-1)
self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)

@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
if seq_len is not None:
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")

def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()

# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
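
For context: the comment above is about computing the rotary angles in float32 even when the model itself runs in bfloat16, since the position * inv_freq products drift at long context lengths in 16-bit precision. The sketch below is a minimal, self-contained illustration of that pattern (the helper name, shapes, and autocast handling are illustrative, not a quote of the hidden remainder of this forward); it reuses the interleaved layout visible in the removed cached path above.

import torch

def rotary_cos_sin(inv_freq: torch.Tensor, position_ids: torch.Tensor, out_dtype=torch.bfloat16):
    # inv_freq: [head_dim // 2], position_ids: [batch, seq_len]
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    # Keep the angle computation in float32 by disabling autocast locally.
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.autocast(device_type=device_type, enabled=False):
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # [batch, seq_len, head_dim // 2]
        emb = torch.repeat_interleave(freqs, 2, dim=-1)  # interleave pairs, as in the cached buffers removed above
        cos, sin = emb.cos(), emb.sin()
    return cos.to(out_dtype), sin.to(out_dtype)
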
@@ -1027,9 +1017,11 @@ def _update_causal_mask(self, attention_mask, input_tensor):
return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
class CohereForCausalLM(CoherePreTrainedModel):
Review comment (Contributor). Suggested change:
class CohereForCausalLM(CoherePreTrainedModel):
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
class CohereForCausalLM(CoherePreTrainedModel):

Review comment (Contributor):
Do you know why the copied from cannot be used here? It would be very useful for easily maintaining the methods below, such as _prepare_inputs_for_generation; otherwise, you can also try to put a copied from statement on the _prepare_inputs_for_generation method.
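
For illustration (not part of this diff): when the class-level marker cannot apply, the same mechanism also works at method level, so the copy-consistency tooling keeps only that method in sync with Llama while the rest of the class diverges. A hypothetical placement, assuming the surrounding module already defines CoherePreTrainedModel:

class CohereForCausalLM(CoherePreTrainedModel):
    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        ...
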

_tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]

# Ignore copy
def __init__(self, config):
Review comment (Contributor). Suggested change:
def __init__(self, config):
# Ignore copy
def __init__(self, config):

super().__init__(config)
self.model = CohereModel(config)
@@ -1058,6 +1050,7 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.model

# Ignore copy
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
Review comment (Contributor). Suggested change:
def forward(
# Ignore copy
def forward(

Since the difference with Llama here is that we multiply the lm logits with logits_scale.
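
For context: the Cohere head multiplies the language-modeling logits by a configurable scale (config.logit_scale) before returning them, which is the divergence from Llama mentioned above. A minimal sketch of just that step, with illustrative names:

import torch
from torch import nn

def scale_lm_logits(hidden_states: torch.Tensor, lm_head: nn.Linear, logit_scale: float) -> torch.Tensor:
    # Project to the vocabulary, then apply the Cohere-specific scaling factor.
    logits = lm_head(hidden_states)        # [batch, seq_len, vocab_size]
    return (logits * logit_scale).float()  # upcast to float32, as causal-LM heads typically do
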

1 change: 1 addition & 0 deletions src/transformers/utils/fx.py
@@ -135,6 +135,7 @@ def _generate_supported_model_class_names(
"hubert",
"layoutlm",
"llama",
"cohere",
"lxmert",
"m2m_100",
"marian",
146 changes: 124 additions & 22 deletions tests/models/cohere/test_modeling_cohere.py
@@ -20,21 +20,27 @@

from transformers import CohereConfig, is_torch_available
from transformers.testing_utils import (
require_bitsandbytes,
require_torch,
require_torch_multi_gpu,
require_torch_sdpa,
slow,
torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ids_tensor
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
import torch

from transformers import CohereForCausalLM, CohereModel
from transformers import AutoTokenizer, CohereForCausalLM, CohereModel


# Copied from transformers.tests.models.llama.LlamaModelTester with Llama->Cohere
class CohereModelTester:
Review comment (Contributor). Suggested change:
class CohereModelTester:
# Copied from transformers.tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Cohere
class CohereModelTester:

Can we also add copied from on the tests?

Reply:
There are differences in the tests.
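
For illustration (not part of this diff): a class-level Copied from can still be combined with per-method Ignore copy markers, as the markers added further down in this diff do, so only the diverging methods are excluded from the consistency check. A hypothetical sketch of that layout:

# Copied from transformers.tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Cohere
class CohereModelTester:
    # Unmarked methods must stay identical to the Llama tester.

    # Ignore copy
    def get_config(self):
        # Cohere-specific: reuses pad_token_id as eos_token_id, unlike the Llama tester.
        ...
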

def __init__(
self,
@@ -109,6 +115,7 @@ def prepare_config_and_inputs(self):

return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

# Ignore copy
def get_config(self):
return CohereConfig(
vocab_size=self.vocab_size,
@@ -124,6 +131,7 @@ def get_config(self):
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
eos_token_id=self.pad_token_id,
)

def create_and_check_model(
@@ -262,7 +270,7 @@ def prepare_config_and_inputs_for_common(self):


@require_torch
class CohereModelTest(unittest.TestCase):
class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (CohereModel, CohereForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (CohereForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
@@ -285,6 +293,14 @@ def setUp(self):
self.model_tester = CohereModelTester(self)
self.config_tester = ConfigTester(self, config_class=CohereConfig, hidden_size=37)

def test_config(self):
self.config_tester.run_common_tests()

@unittest.skip("TODO @gante fix this for Cohere")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass

def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@@ -295,26 +311,112 @@ def test_model_various_embeddings(self):
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)

@unittest.skip("TODO @gante fix this")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_bitsandbytes
@require_torch_sdpa
@require_torch_multi_gpu
@slow
def test_eager_matches_sdpa_generate(self):
"""
Overwriting the common test as the test is flaky on tiny models
"""
max_new_tokens = 30

model_id = "CohereForAI/c4ai-command-r-v01-4bit"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model_sdpa = CohereForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto"
)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

model_eager = CohereForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, attn_implementation="eager", device_map="auto"
)

self.assertTrue(model_eager.config._attn_implementation == "eager")

for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")

has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")

texts = [
"hi here's a longer context, getting longer and",
"Hello this is a very long sentence my friend, very long for real",
"Today I am in Paris and",
]

for padding_side in ["left", "right"]:
tokenizer.padding_side = padding_side
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)

res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

with self.subTest(f"{padding_side}"):
torch.testing.assert_close(
res_eager,
res_sdpa,
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
)


@require_torch
@slow
class CohereIntegrationTest(unittest.TestCase):
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
@slow
def test_model_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01", device_map="auto")
out = model(torch.tensor(input_ids).unsqueeze(0))
# # Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[0.5077, -2.5771, -1.1590, -2.6220, -1.7837, -2.4421, -1.3293, -2.2028]])
torch.testing.assert_close(out[0].mean(-1).cpu(), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([ 1.8525, 5.0039, 2.7734, 3.6270, 0.9390, -0.4587, 3.4062, 0.9468, \
3.7324, 1.2344, 5.3047, 4.7266, 5.9414, 5.5195, 1.8047, 3.5215, \
1.5752, 3.7031, 6.2891, 3.4785, 2.0293, 4.2539, 2.8086, 4.7070, \
3.6953, 4.0391, 3.9766, 3.3066, 2.9395, 3.3105]) # fmt: skip
torch.testing.assert_close(out[0][0, 0, :30].cpu(), EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
@require_torch_multi_gpu
def test_batched_4bit(self):
model_id = "CohereForAI/c4ai-command-r-v01-4bit"

EXPECTED_TEXT = [
'Hello today I am going to show you how to make a simple and easy card using the new stamp set called "Hello" from the Occasions catalog. This set is so versatile and can be used for many occasions. I used the new In',
"Hi there, here we are again with another great collection of free fonts. This time we have gathered 10 free fonts that you can download and use in your designs. These fonts are free for personal and commercial use. So",
]

model = CohereForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

tokenizer.pad_token = tokenizer.eos_token

text = ["Hello today I am going to show you how to", "Hi there, here we are"]
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch_device)

output = model.generate(**inputs, max_new_tokens=40, do_sample=False)
self.assertEqual(tokenizer.batch_decode(output, skip_special_tokens=True), EXPECTED_TEXT)

def test_batched_small_model_logits(self):
# Since the model is very large, we created a random cohere model so that we can do a simple
# logits check on it.
model_id = "hf-internal-testing/cohere-random"

EXPECTED_LOGITS = torch.Tensor(
[
[[0.0000, 0.1866, -0.1997], [0.0000, -0.0736, 0.1785], [0.0000, -0.1965, -0.0569]],
[[0.0000, -0.0302, 0.1488], [0.0000, -0.0402, 0.1351], [0.0000, -0.0341, 0.1116]],
]
).to(torch_device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = CohereForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
torch_device
)

tokenizer.pad_token = tokenizer.eos_token

text = ["Hello today I am going to show you how to", "Hi there, here we are"]
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch_device)

with torch.no_grad():
output = model(**inputs)

logits = output.logits
self.assertTrue(torch.allclose(EXPECTED_LOGITS, logits[:, :3, :3], rtol=1e-3, atol=1e-3))