Skip to content

Commit

Permalink
skip some gpt_neox tests that require 80G RAM (huggingface#17923)
Browse files Browse the repository at this point in the history
* skip some gpt_neox tests that require 80G RAM

* remove tests

* fix quality

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
  • Loading branch information
ydshieh and ydshieh authored Jul 1, 2022
1 parent 49cd736 commit 14fb8a6
Showing 1 changed file with 1 addition and 28 deletions.
29 changes: 1 addition & 28 deletions tests/models/gpt_neox/test_modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import unittest

from transformers import GPTNeoXConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import require_torch, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
Expand All @@ -28,7 +28,6 @@
import torch

from transformers import GPTNeoXForCausalLM, GPTNeoXModel
from transformers.models.gpt_neox.modeling_gpt_neox import GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST


class GPTNeoXModelTester:
Expand Down Expand Up @@ -229,29 +228,3 @@ def test_model_for_causal_lm(self):
@unittest.skip(reason="Feed forward chunking is not implemented")
def test_feed_forward_chunking(self):
pass

@slow
def test_model_from_pretrained(self):
for model_name in GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = GPTNeoXModel.from_pretrained(model_name)
self.assertIsNotNone(model)


@require_torch
class GPTNeoXModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_masked_lm(self):
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b")
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]

vocab_size = model.config.vocab_size

expected_shape = torch.Size((1, 6, vocab_size))
self.assertEqual(output.shape, expected_shape)

expected_slice = torch.tensor(
[[[33.5938, 2.3789, 34.0312], [63.4688, 4.8164, 63.3438], [66.8750, 5.2422, 63.0625]]]
)

self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

0 comments on commit 14fb8a6

Please sign in to comment.