integration testing
syncdoth committed Dec 23, 2023
1 parent c162de3 · commit b624d05
Showing 1 changed file with 95 additions and 13 deletions.
tests/models/nucleus_x/test_modeling_nucleus_x.py (108 changes: 95 additions & 13 deletions)
@@ -22,10 +22,11 @@
 """ Testing suite for the PyTorch NucleusX model. """


+import gc
 import unittest

-from transformers import NucleusXConfig, is_torch_available
-from transformers.testing_utils import require_torch, torch_device
+from transformers import AutoTokenizer, NucleusXConfig, is_torch_available
+from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device

 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -115,7 +116,15 @@ def prepare_config_and_inputs(self):

         config = self.get_config()

-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+        return (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        )

     def get_config(self):
         return NucleusXConfig(
@@ -135,14 +144,24 @@ def get_config(self):
         )

     def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
     ):
         model = NucleusXModel(config=config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask)
         result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape,
+            (self.batch_size, self.seq_length, self.hidden_size),
+        )

     def create_and_check_model_as_decoder(
         self,
@@ -172,7 +191,10 @@ def create_and_check_model_as_decoder(
             encoder_hidden_states=encoder_hidden_states,
         )
         result = model(input_ids, attention_mask=input_mask)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape,
+            (self.batch_size, self.seq_length, self.hidden_size),
+        )

     # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->NucleusX
     def create_and_check_for_causal_lm(
@@ -312,7 +334,10 @@ def test_nucleus_x_sequence_classification_model(self):
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+        self.assertEqual(
+            result.logits.shape,
+            (self.model_tester.batch_size, self.model_tester.num_labels),
+        )

     def test_nucleus_x_sequence_classification_model_for_single_label(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -325,7 +350,10 @@ def test_nucleus_x_sequence_classification_model_for_single_label(self):
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+        self.assertEqual(
+            result.logits.shape,
+            (self.model_tester.batch_size, self.model_tester.num_labels),
+        )

     def test_nucleus_x_sequence_classification_model_for_multi_label(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -334,17 +362,24 @@
         input_ids = input_dict["input_ids"]
         attention_mask = input_ids.ne(1).to(torch_device)
         sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
+            [self.model_tester.batch_size, config.num_labels],
+            self.model_tester.type_sequence_label_size,
         ).to(torch.float)
         model = NucleusXForSequenceClassification(config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+        self.assertEqual(
+            result.logits.shape,
+            (self.model_tester.batch_size, self.model_tester.num_labels),
+        )

     def test_nucleus_x_parallel_recurrent(self):
         for _ in range(10):
-            config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            (
+                config,
+                input_dict,
+            ) = self.model_tester.prepare_config_and_inputs_for_common()
             input_ids = input_dict["input_ids"]
             model = NucleusXForCausalLM(config)
             model.to(torch_device)
@@ -357,7 +392,10 @@ def test_nucleus_x_parallel_recurrent(self):
             past_kv = None
             for i in range(input_ids.shape[1]):
                 rnn_out = model(
-                    input_ids[:, : i + 1], forward_mode="recurrent", past_key_values=past_kv, use_cache=True
+                    input_ids[:, : i + 1],
+                    forward_mode="recurrent",
+                    past_key_values=past_kv,
+                    use_cache=True,
                 )
                 rnn_logits.append(rnn_out.logits)
                 past_kv = rnn_out.past_key_values
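
The parallel/recurrent test above asserts that one parallel forward pass and token-by-token recurrent decoding produce matching logits. Below is a minimal sketch of that decoding pattern outside the test harness; it assumes this branch's NucleusX classes are importable from transformers and that a plain model(input_ids) call takes the parallel path, while the checkpoint name comes from the integration tests added further down.

```python
# Minimal sketch: recurrent (token-by-token) decoding vs. one parallel pass.
# forward_mode and the cache-threading pattern mirror
# test_nucleus_x_parallel_recurrent above; everything else is an assumption.
import torch

from transformers import NucleusXForCausalLM  # available on this branch only

model = NucleusXForCausalLM.from_pretrained("NucleusAI/Nucleus-X")
model.eval()

input_ids = torch.tensor([[1, 306, 4658, 278]])  # any prompt token ids
past_kv = None
step_logits = []
with torch.no_grad():
    for i in range(input_ids.shape[1]):
        out = model(
            input_ids[:, : i + 1],     # growing prefix, exactly as in the test
            forward_mode="recurrent",  # reuse the cached recurrent state
            past_key_values=past_kv,
            use_cache=True,
        )
        step_logits.append(out.logits)
        past_kv = out.past_key_values

    # The default forward (assumed to be the parallel path) should agree:
    parallel_logits = model(input_ids).logits
```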
@@ -379,7 +417,10 @@ def test_nucleus_x_parallel_recurrent(self):
     def test_nucleus_x_parallel_chunkwise(self):
         logit_success = False
         for _ in range(10):
-            config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            (
+                config,
+                input_dict,
+            ) = self.model_tester.prepare_config_and_inputs_for_common()
             config.groupnorm_eps = 1e-16
             input_ids = input_dict["input_ids"]
             model = NucleusXForCausalLM(config)
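
A note on the modes these two tests exercise: RetNet-style models support three mathematically equivalent computation paths, a fully parallel pass, a token-by-token recurrence, and a chunkwise form that runs parallel within fixed-size chunks and recurrent across chunk boundaries. Setting `config.groupnorm_eps = 1e-16` presumably keeps the group normalization inside retention from masking small numerical differences when the chunkwise logits are compared against the parallel ones, and the `logit_success` flag suggests the comparison only needs to hold for one of the ten random configurations.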
@@ -395,3 +436,44 @@ def test_nucleus_x_parallel_chunkwise(self):
     @unittest.skip("NucleusX uses dictionary style KV cache, which is a non standard format")
     def test_past_key_values_format(self):
         pass
+
+
+@require_torch
+class NucleusXIntegrationTest(unittest.TestCase):
+    @slow
+    def test_model_7b_logits(self):
+        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
+        model = NucleusXForCausalLM.from_pretrained("NucleusAI/Nucleus-X", device_map="auto")
+        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+        with torch.no_grad():
+            out = model(input_ids).logits.cpu()
+        # Expected mean on dim = -1
+        EXPECTED_MEAN = torch.tensor([[2.0853, 1.3300, 2.5745, 1.5178, 3.1419, 1.7202, 2.8694, 1.7182]])
+        torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
+        # slicing logits[0, 0, 0:30]
+        EXPECTED_SLICE = torch.tensor([-1.1098, 9.8746, 4.9376, -0.8818, -1.0313, -0.5429, -0.6475, -0.5848, -1.5682, -0.8666, -0.6847, -0.6623, 5.7973, 11.5313, -1.2270, -1.3406, -1.2220, -0.4933, -0.8791, -1.2740, -0.7890, -0.6629, -0.8539, -0.9674, -0.7105, -0.7538, -0.7392, -1.1273, -1.2821, -1.1191])  # fmt: skip
+        print(out[0, 0, :30])
+        torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4)
+
+        del model
+        backend_empty_cache(torch_device)
+        gc.collect()
+
+    @slow
+    def test_model_7b_generation(self):
+        EXPECTED_TEXT_COMPLETION = (
+            """Hello my name is Tina and I am a 25 year old female. I am a very outgoing person"""
+        )
+        prompt = "Hello my name is"
+        tokenizer = AutoTokenizer.from_pretrained("NucleusAI/Nucleus-X", use_fast=False)
+        model = NucleusXForCausalLM.from_pretrained("NucleusAI/Nucleus-X", device_map="auto")
+        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
+
+        # greedy generation outputs
+        generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
+        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
+
+        del model
+        backend_empty_cache(torch_device)
+        gc.collect()
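
The `test_past_key_values_format` skip earlier in the diff is the flip side of the cache these tests thread around: the common test expects `past_key_values` to be a per-layer tuple of `(key, value)` tensors, while NucleusX carries a dictionary per layer. Here is a sketch of the difference; the dictionary key names are hypothetical (they follow the RetNet reference implementation, not code shown in this diff).

```python
# Contrast between the standard cache format checked by the skipped
# test_past_key_values_format and a dictionary-style recurrent cache.
import torch

batch, heads, seq_len, head_dim = 1, 8, 5, 64
num_layers = 2

# Standard transformers format: a tuple of per-layer (key, value) pairs,
# each of shape (batch, num_heads, seq_len, head_dim), growing with seq_len.
standard_cache = tuple(
    (
        torch.zeros(batch, heads, seq_len, head_dim),
        torch.zeros(batch, heads, seq_len, head_dim),
    )
    for _ in range(num_layers)
)

# Dictionary-style cache: instead of growing key/value tensors, a RetNet-style
# layer keeps a fixed-size recurrent state. Key names below are hypothetical,
# borrowed from the RetNet reference code, not from this diff.
dict_style_cache = [
    {
        "prev_key_value": torch.zeros(batch, heads, head_dim, head_dim),  # running state
        "scale": torch.zeros(batch, heads, 1, 1),  # decay normalizer carried across steps
    }
    for _ in range(num_layers)
]
```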

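Both integration tests are decorated with `@slow`, so the default CI run skips them; setting the `RUN_SLOW=1` environment variable (for example `RUN_SLOW=1 pytest tests/models/nucleus_x/test_modeling_nucleus_x.py -k NucleusXIntegrationTest`) runs them against the real "NucleusAI/Nucleus-X" checkpoint. The `del model` / `backend_empty_cache(torch_device)` / `gc.collect()` epilogue frees the 7B weights between tests so both can run back-to-back on a single device.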