Cohere Model Release #29622
Commit message:
* fix modeling final nits and add proper test file
* for now leave empty tests
* add integration test
* push new test
Changes to the modeling file:
@@ -106,23 +106,13 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        t = t / self.scaling_factor
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.repeat_interleave(freqs, 2, dim=-1)
        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        if seq_len is not None:
            logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")

    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
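For context, here is a minimal self-contained sketch of the float32 cos/sin computation referenced in the comment above. It is not part of the diff; the function name and toy shapes are illustrative only:

```python
import torch

def rotary_cos_sin(inv_freq: torch.Tensor, position_ids: torch.Tensor, device_type: str = "cpu"):
    # inv_freq: [head_dim // 2], position_ids: [batch, seq_len]
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    # Disable autocast so the outer product stays in float32; bfloat16 loses
    # precision for large position values on long contexts.
    with torch.autocast(device_type=device_type, enabled=False):
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # [batch, seq_len, head_dim // 2]
        # The diff above interleaves the frequencies (repeat_interleave)
        # instead of concatenating two halves as Llama does.
        emb = torch.repeat_interleave(freqs, 2, dim=-1)  # [batch, seq_len, head_dim]
        cos, sin = emb.cos(), emb.sin()
    return cos, sin

# Hypothetical usage with toy shapes
dim = 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
position_ids = torch.arange(16)[None, :]
cos, sin = rotary_cos_sin(inv_freq, position_ids)
```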
@@ -1027,9 +1017,11 @@ def _update_causal_mask(self, attention_mask, input_tensor):
        return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
class CohereForCausalLM(CoherePreTrainedModel):
Review comment: Do you know why the `# Copied from` cannot be used here? It would be very useful for easily maintaining the methods below, such as ...
_tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] | ||||||||
_tied_weights_keys = ["lm_head.weight"] | ||||||||
|
||||||||
# Ignore copy | ||||||||
def __init__(self, config): | ||||||||
        super().__init__(config)
        self.model = CohereModel(config)

@@ -1058,6 +1050,7 @@ def set_decoder(self, decoder):
    def get_decoder(self):
        return self.model

    # Ignore copy
    @add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
Review comment (suggested change): Since here the difference with Llama is that we multiply the lm logits with ...
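The comment above is truncated. Assuming it refers to scaling the LM logits by the model's `logit_scale` configuration value, a minimal sketch of such a scaled head could look like this (class name and default value are illustrative, not taken from the diff):

```python
import torch
import torch.nn as nn

class ScaledLMHead(nn.Module):
    """Toy LM head that scales logits, mirroring the difference from Llama
    described in the review comment (assumed to be `logit_scale`)."""

    def __init__(self, hidden_size: int, vocab_size: int, logit_scale: float = 0.0625):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.logit_scale = logit_scale

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.lm_head(hidden_states)
        # The assumed difference from a plain Llama-style head: multiply the
        # logits by a constant scale before loss computation or sampling.
        return logits * self.logit_scale

# Hypothetical usage
head = ScaledLMHead(hidden_size=32, vocab_size=100)
logits = head(torch.randn(2, 5, 32))
```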
Changes to the test file:
@@ -20,21 +20,27 @@

from transformers import CohereConfig, is_torch_available
from transformers.testing_utils import (
    require_bitsandbytes,
    require_torch,
    require_torch_multi_gpu,
    require_torch_sdpa,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ids_tensor
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import CohereForCausalLM, CohereModel
    from transformers import AutoTokenizer, CohereForCausalLM, CohereModel


# Copied from transformers.tests.models.llama.LlamaModelTester with Llama->Cohere
class CohereModelTester:
Review comment: Can we also add the copied from on the tests as well?
Reply: There are differences in the tests.
    def __init__(
        self,

@@ -109,6 +115,7 @@ def prepare_config_and_inputs(self):

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    # Ignore copy
    def get_config(self):
        return CohereConfig(
            vocab_size=self.vocab_size,

@@ -124,6 +131,7 @@ def get_config(self):
            is_decoder=False,
            initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.pad_token_id,
        )

    def create_and_check_model(
@@ -262,7 +270,7 @@ def prepare_config_and_inputs_for_common(self):


@require_torch
class CohereModelTest(unittest.TestCase):
class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (CohereModel, CohereForCausalLM) if is_torch_available() else ()
    all_generative_model_classes = (CohereForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (

@@ -285,6 +293,14 @@ def setUp(self):
        self.model_tester = CohereModelTester(self)
        self.config_tester = ConfigTester(self, config_class=CohereConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip("TODO @gante fix this for Cohere")
    @parameterized.expand([(1, False), (1, True), (4, False)])
    def test_new_cache_format(self, num_beams, do_sample):
        pass

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)
@@ -295,26 +311,112 @@ def test_model_various_embeddings(self):
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip("TODO @gante fix this")
    @parameterized.expand([(1, False), (1, True), (4, False)])
    def test_new_cache_format(self, num_beams, do_sample):
        pass

    @require_bitsandbytes
    @require_torch_sdpa
    @require_torch_multi_gpu
    @slow
    def test_eager_matches_sdpa_generate(self):
        """
        Overwriting the common test as the test is flaky on tiny models
        """
        max_new_tokens = 30

        model_id = "CohereForAI/c4ai-command-r-v01-4bit"
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        model_sdpa = CohereForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto"
        )
        self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

        model_eager = CohereForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.float16, attn_implementation="eager", device_map="auto"
        )

        self.assertTrue(model_eager.config._attn_implementation == "eager")

        for name, submodule in model_eager.named_modules():
            if "SdpaAttention" in submodule.__class__.__name__:
                raise ValueError("The eager model should not have SDPA attention layers")

        has_sdpa = False
        for name, submodule in model_sdpa.named_modules():
            if "SdpaAttention" in submodule.__class__.__name__:
                has_sdpa = True
                break
        if not has_sdpa:
            raise ValueError("The SDPA model should have SDPA attention layers")

        texts = [
            "hi here's a longer context, getting longer and",
            "Hello this is a very long sentence my friend, very long for real",
            "Today I am in Paris and",
        ]

        for padding_side in ["left", "right"]:
            tokenizer.padding_side = padding_side
            tokenizer.pad_token = tokenizer.eos_token

            inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)

            res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
            res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

            with self.subTest(f"{padding_side}"):
                torch.testing.assert_close(
                    res_eager,
                    res_sdpa,
                    msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
                )

@require_torch
@slow
class CohereIntegrationTest(unittest.TestCase):
    @unittest.skip("Logits are not exactly the same, once we fix the instabilities somehow, will update!")
    @slow
    def test_model_logits(self):
        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
        model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01", device_map="auto")
        out = model(torch.tensor(input_ids).unsqueeze(0))
        # Expected mean on dim = -1
        EXPECTED_MEAN = torch.tensor([[0.5077, -2.5771, -1.1590, -2.6220, -1.7837, -2.4421, -1.3293, -2.2028]])
        torch.testing.assert_close(out[0].mean(-1).cpu(), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
        # slicing logits[0, 0, 0:30]
        EXPECTED_SLICE = torch.tensor([1.8525, 5.0039, 2.7734, 3.6270, 0.9390, -0.4587, 3.4062, 0.9468,
                                       3.7324, 1.2344, 5.3047, 4.7266, 5.9414, 5.5195, 1.8047, 3.5215,
                                       1.5752, 3.7031, 6.2891, 3.4785, 2.0293, 4.2539, 2.8086, 4.7070,
                                       3.6953, 4.0391, 3.9766, 3.3066, 2.9395, 3.3105])  # fmt: skip
        torch.testing.assert_close(out[0][0, 0, :30].cpu(), EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
    @require_torch_multi_gpu
    def test_batched_4bit(self):
        model_id = "CohereForAI/c4ai-command-r-v01-4bit"

        EXPECTED_TEXT = [
            'Hello today I am going to show you how to make a simple and easy card using the new stamp set called "Hello" from the Occasions catalog. This set is so versatile and can be used for many occasions. I used the new In',
            "Hi there, here we are again with another great collection of free fonts. This time we have gathered 10 free fonts that you can download and use in your designs. These fonts are free for personal and commercial use. So",
        ]

        model = CohereForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        tokenizer.pad_token = tokenizer.eos_token

        text = ["Hello today I am going to show you how to", "Hi there, here we are"]
        inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch_device)

        output = model.generate(**inputs, max_new_tokens=40, do_sample=False)
        self.assertEqual(tokenizer.batch_decode(output, skip_special_tokens=True), EXPECTED_TEXT)

    def test_batched_small_model_logits(self):
        # Since the model is very large, we created a random cohere model so that we can do a simple
        # logits check on it.
        model_id = "hf-internal-testing/cohere-random"

        EXPECTED_LOGITS = torch.Tensor(
            [
                [[0.0000, 0.1866, -0.1997], [0.0000, -0.0736, 0.1785], [0.0000, -0.1965, -0.0569]],
                [[0.0000, -0.0302, 0.1488], [0.0000, -0.0402, 0.1351], [0.0000, -0.0341, 0.1116]],
            ]
        ).to(torch_device)

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = CohereForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
            torch_device
        )

        tokenizer.pad_token = tokenizer.eos_token

        text = ["Hello today I am going to show you how to", "Hi there, here we are"]
        inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch_device)

        with torch.no_grad():
            output = model(**inputs)

        logits = output.logits
        self.assertTrue(torch.allclose(EXPECTED_LOGITS, logits[:, :3, :3], rtol=1e-3, atol=1e-3))
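A note on the `hf-internal-testing/cohere-random` checkpoint used above: a tiny randomly initialized model like this can be produced roughly as in the sketch below. The config values and the upload step are illustrative; the actual configuration of the hosted checkpoint is not shown in this diff.

```python
import torch
from transformers import CohereConfig, CohereForCausalLM

# Build a deliberately tiny configuration so the checkpoint stays small and
# the forward pass is cheap; all values here are illustrative only.
config = CohereConfig(
    vocab_size=256,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    max_position_embeddings=128,
)

torch.manual_seed(0)  # fixed seed so the expected logits stay reproducible
model = CohereForCausalLM(config)

# model.push_to_hub("your-namespace/cohere-random")  # hypothetical upload step
```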