Add Perceiver IO #14487

Merged on Dec 8, 2021 (147 commits)

Commits
beef8c1
First draft
NielsRogge Aug 2, 2021
28f9541
Style and remove mlm
NielsRogge Sep 6, 2021
7f70799
Make forward pass work
NielsRogge Sep 6, 2021
7574fc0
More improvements
NielsRogge Sep 6, 2021
77d55ec
More improvements
NielsRogge Sep 7, 2021
bdccd62
Fix bug
NielsRogge Sep 7, 2021
7b7dcd2
More improvements
NielsRogge Sep 7, 2021
25d7725
More improvements
NielsRogge Sep 7, 2021
4a804b6
Add PerceiverTokenizer first draft
NielsRogge Sep 8, 2021
9a84428
Improve conversion script
NielsRogge Sep 8, 2021
65e4edd
More improvements
NielsRogge Sep 8, 2021
649c66a
Make conversion script work for the encoder
NielsRogge Sep 8, 2021
df1c0c9
Make conversion script work with local pickle files
NielsRogge Sep 8, 2021
6a8a981
Style & quality, fix-copies
NielsRogge Sep 8, 2021
79b3f9d
Add dummy input to conversion script
NielsRogge Sep 8, 2021
6d1fb56
Add absolute position embeddings to TextPreProcessor
NielsRogge Sep 8, 2021
9ef09dc
Make forward pass of encoder work
NielsRogge Sep 9, 2021
8e15a42
More improvements
NielsRogge Sep 10, 2021
8852bd6
Move text preprocessor to separate script
NielsRogge Sep 10, 2021
e003753
More improvements
NielsRogge Sep 10, 2021
cfe4d01
More improvements
NielsRogge Sep 10, 2021
2eb4869
Add post processor
NielsRogge Sep 10, 2021
091903e
Make MLM model work
NielsRogge Sep 10, 2021
4f6c31d
Style
NielsRogge Sep 10, 2021
edaf54d
Add PerceiverForMaskedLM
NielsRogge Sep 10, 2021
5a1dea3
Add PerceiverImagePreprocessor
NielsRogge Sep 13, 2021
af33282
Make style
NielsRogge Sep 13, 2021
63b556a
Make PerceiverForImageClassification work
NielsRogge Sep 13, 2021
54d5335
More improvements
NielsRogge Sep 14, 2021
853268e
More improvements
NielsRogge Sep 14, 2021
d579251
Use tokenizer in conversion script
NielsRogge Sep 14, 2021
e8a8772
Use PerceiverForMaskedLM in conversion script
NielsRogge Sep 14, 2021
f8293b9
Define custom PerceiverModelOutput
NielsRogge Sep 14, 2021
3a62362
Improve PerceiverAttention to make it work for both MLM and image cla…
NielsRogge Sep 14, 2021
7795f6d
More improvements
NielsRogge Sep 14, 2021
2c3342f
More improvements
NielsRogge Sep 15, 2021
3151607
More improvements to the conversion script
NielsRogge Sep 15, 2021
a2e6b0e
Make conversion script work for both MLM and image classification
NielsRogge Sep 15, 2021
c1dbe7c
Add PerceiverFeatureExtractor
NielsRogge Sep 15, 2021
e6d9122
More improvements
NielsRogge Sep 15, 2021
cfd32c6
Style and quality
NielsRogge Sep 15, 2021
07b090f
Add center cropping
NielsRogge Sep 15, 2021
4cd722c
Fix bug
NielsRogge Sep 15, 2021
4ed297e
Small fix
NielsRogge Sep 15, 2021
8d4b748
Add print statement
NielsRogge Sep 15, 2021
2bb92b7
Fix bug in image preprocessor
NielsRogge Sep 15, 2021
4248229
Fix bug with conversion script
NielsRogge Sep 15, 2021
a7f75a2
Make output position embeddings an nn.Parameter layer instead of nn.E…
NielsRogge Sep 15, 2021
4592338
Comment out print statements
NielsRogge Sep 16, 2021
dd91215
Add position encoding classes
NielsRogge Sep 16, 2021
ac82fce
More improvements
NielsRogge Sep 16, 2021
b369c09
Use position_encoding_kwargs
NielsRogge Sep 17, 2021
7d1863f
Add PerceiverForImageClassificationFourier
NielsRogge Sep 17, 2021
e77c6b4
Make style & quality
NielsRogge Sep 17, 2021
0a7c3f0
Add PerceiverForImageClassificationConvProcessing
NielsRogge Sep 17, 2021
d3bcf09
Style & quality
NielsRogge Sep 17, 2021
0e4241c
Add flow model
NielsRogge Sep 18, 2021
92c7c62
Move processors to modeling file
NielsRogge Sep 20, 2021
9933942
Make position encodings modular
NielsRogge Sep 20, 2021
00d2ce3
Make basic decoder use modular position encodings
NielsRogge Sep 20, 2021
f1276f8
Add PerceiverForOpticalFlow to conversion script
NielsRogge Sep 20, 2021
15ded27
Add AudioPreprocessor
NielsRogge Sep 21, 2021
1347c20
Make it possible for the basic decoder to use Fourier position embedd…
NielsRogge Sep 21, 2021
8bb1289
Add PerceiverForMultimodalAutoencoding
NielsRogge Sep 21, 2021
8c5d100
Improve model for optical flow
NielsRogge Sep 22, 2021
5dbea95
Improve _build_network_inputs method
NielsRogge Sep 22, 2021
5472500
Add print statement
NielsRogge Sep 22, 2021
fea12e6
Fix device issue
NielsRogge Sep 22, 2021
3daed24
Fix device of Fourier embeddings
NielsRogge Sep 23, 2021
a45c064
Add print statements for debugging
NielsRogge Sep 23, 2021
1e7b1c9
Add another print statement
NielsRogge Sep 23, 2021
8c0f886
Add another print statement
NielsRogge Sep 23, 2021
32cca82
Add another print statement
NielsRogge Sep 23, 2021
f1c3720
Add another print statement
NielsRogge Sep 23, 2021
275a59f
Improve PerceiverAudioPreprocessor
NielsRogge Sep 24, 2021
aedb68e
Improve conversion script for multimodal model
NielsRogge Sep 24, 2021
adc1205
More improvements
NielsRogge Sep 24, 2021
89da95d
More improvements
NielsRogge Sep 25, 2021
a7f4870
Improve multimodal model
NielsRogge Sep 27, 2021
54021d3
Make forward pass multimodal model work
NielsRogge Sep 28, 2021
327d16c
More improvements
NielsRogge Sep 29, 2021
f3a2d0c
Improve tests
NielsRogge Oct 6, 2021
1f34526
Fix some more tests
NielsRogge Oct 6, 2021
7c4cbbc
Add output dataclasses
NielsRogge Oct 6, 2021
2a4dab2
Make more tests pass
NielsRogge Oct 7, 2021
1205dd9
Add print statements for debugging
NielsRogge Oct 7, 2021
4408a69
Add tests for image classification
NielsRogge Oct 7, 2021
1a60c6a
Add PerceiverClassifierOutput
NielsRogge Oct 7, 2021
0a1bfcd
More improvements
NielsRogge Oct 7, 2021
27f7190
Make more tests pass for the optical flow model
NielsRogge Oct 7, 2021
6815bf7
Make style & quality
NielsRogge Oct 7, 2021
d7fedc7
Small improvements
NielsRogge Oct 7, 2021
06839cb
Don't support training for optical flow model for now
NielsRogge Oct 11, 2021
5acb88c
Fix _prepare_for_class for tests
NielsRogge Oct 11, 2021
db7b6bb
Make more tests pass, add some docs
NielsRogge Oct 12, 2021
0264043
Add multimodal model to tests
NielsRogge Oct 12, 2021
107c971
Minor fixes
NielsRogge Nov 3, 2021
ed7d7ea
Fix tests
NielsRogge Nov 4, 2021
f62a6f5
Improve conversion script
NielsRogge Nov 4, 2021
d32808b
Make fixup
NielsRogge Nov 4, 2021
08b67de
Remove pos_dim argument
NielsRogge Nov 4, 2021
e7f8329
Fix device issue
NielsRogge Nov 4, 2021
0a93591
Potential fix for OOM
NielsRogge Nov 4, 2021
1091cfe
Revert previous commit
NielsRogge Nov 4, 2021
4c10a9d
Fix test_initialization
NielsRogge Nov 5, 2021
06c7b06
Add print statements for debugging
NielsRogge Nov 5, 2021
adfda8f
Fix print statement
NielsRogge Nov 5, 2021
927dd92
Add print statement
NielsRogge Nov 5, 2021
786f57f
Add print statement
NielsRogge Nov 5, 2021
bde8cf3
Add print statement
NielsRogge Nov 5, 2021
d832391
Add print statement
NielsRogge Nov 8, 2021
8aa3228
Add print statement
NielsRogge Nov 8, 2021
5a84a3e
Add print statement
NielsRogge Nov 8, 2021
8887f98
Remove need for output_shape
NielsRogge Nov 8, 2021
f9800c5
Comment out output_shape
NielsRogge Nov 8, 2021
134bfc4
Remove unnecessary code
NielsRogge Nov 8, 2021
d5187fb
Improve docs
NielsRogge Nov 10, 2021
e9003fb
Fix make fixup
NielsRogge Nov 19, 2021
d965bca
Remove PerceiverTextProcessor from init
NielsRogge Nov 19, 2021
42630e7
Improve docs
NielsRogge Nov 19, 2021
29037ba
Small improvement
NielsRogge Nov 22, 2021
4a2b81a
Apply first batch of suggestions from code review
NielsRogge Nov 30, 2021
3235318
Apply more suggestions from code review
NielsRogge Nov 30, 2021
22becd9
Update docstrings
NielsRogge Nov 30, 2021
dc95e00
Define dicts beforehand for readability
NielsRogge Nov 30, 2021
31ae669
Rename task to architecture in conversion script, include PerceiverMo…
NielsRogge Dec 1, 2021
fa41b1a
Add print statements for debugging
NielsRogge Dec 1, 2021
a3f16f2
Fix tests on GPU
NielsRogge Dec 1, 2021
afcb875
Remove preprocessors, postprocessors and decoders from main init
NielsRogge Dec 1, 2021
c5e3af7
Add integration test
NielsRogge Dec 1, 2021
dc68fed
Fix docs
NielsRogge Dec 1, 2021
ffc6fde
Replace einops by torch
NielsRogge Dec 2, 2021
83a6776
Update for new docs frontend
NielsRogge Dec 2, 2021
46c8e04
Rename PerceiverForImageClassification
NielsRogge Dec 2, 2021
a358e38
Improve docs
NielsRogge Dec 2, 2021
c5ae758
Improve docs
NielsRogge Dec 2, 2021
48503c0
Improve docs of PerceiverModel
NielsRogge Dec 2, 2021
ec0e016
Fix some more tests
NielsRogge Dec 3, 2021
da79d8a
Improve center_crop
NielsRogge Dec 3, 2021
2a3c57c
Add PerceiverForSequenceClassification
NielsRogge Dec 3, 2021
60eefd7
Small improvements
NielsRogge Dec 6, 2021
b36ba76
Fix tests
NielsRogge Dec 6, 2021
e8cf21a
Add integration test for optical flow model
NielsRogge Dec 7, 2021
e084c05
Clean up
NielsRogge Dec 7, 2021
d1c0245
Add tests for tokenizer
NielsRogge Dec 7, 2021
520f132
Fix tokenizer by adding special tokens properly
NielsRogge Dec 8, 2021
cf534be
Fix CI
NielsRogge Dec 8, 2021

Add PerceiverForSequenceClassification
NielsRogge committed Dec 3, 2021
commit 2a3c57c3b3f9b6802079cd83673af077a34a4dbb
7 changes: 7 additions & 0 deletions docs/source/model_doc/perceiver.rst
@@ -192,6 +192,13 @@ PerceiverForMaskedLM
    :members: forward


PerceiverForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.PerceiverForSequenceClassification
    :members: forward


PerceiverForImageClassificationLearned
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1140,6 +1140,7 @@
"PerceiverForMaskedLM",
"PerceiverForMultimodalAutoencoding",
"PerceiverForOpticalFlow",
"PerceiverForSequenceClassification",
"PerceiverLayer",
"PerceiverModel",
"PerceiverPreTrainedModel",
@@ -3008,6 +3009,7 @@
PerceiverForMaskedLM,
PerceiverForMultimodalAutoencoding,
PerceiverForOpticalFlow,
PerceiverForSequenceClassification,
PerceiverLayer,
PerceiverModel,
PerceiverPreTrainedModel,
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -347,6 +347,7 @@
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("perceiver", "PerceiverForSequenceClassification"),
("qdqbert", "QDQBertForSequenceClassification"),
("fnet", "FNetForSequenceClassification"),
("gptj", "GPTJForSequenceClassification"),
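The entry added in this hunk maps the model type string `"perceiver"` to a class name. A minimal sketch of that name lookup, using only the four entries visible above. `class_name_for` is a hypothetical helper, not part of `transformers`; the real auto classes go on to import the named class, which this diff does not show.

```python
from collections import OrderedDict

# Only the entries visible in the hunk above; the real table is longer.
MAPPING_NAMES = OrderedDict(
    [
        ("perceiver", "PerceiverForSequenceClassification"),
        ("qdqbert", "QDQBertForSequenceClassification"),
        ("fnet", "FNetForSequenceClassification"),
        ("gptj", "GPTJForSequenceClassification"),
    ]
)


def class_name_for(model_type):
    """Hypothetical helper: return the class name registered for a model type."""
    try:
        return MAPPING_NAMES[model_type]
    except KeyError:
        raise ValueError(
            f"Unrecognized model type {model_type!r} for sequence classification"
        ) from None


print(class_name_for("perceiver"))
```

Keeping the table as plain strings means the lookup itself needs no model imports; resolving the string to an actual class is a separate step.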
2 changes: 2 additions & 0 deletions src/transformers/models/perceiver/__init__.py
@@ -37,6 +37,7 @@
"PerceiverForMaskedLM",
"PerceiverForMultimodalAutoencoding",
"PerceiverForOpticalFlow",
"PerceiverForSequenceClassification",
"PerceiverLayer",
"PerceiverModel",
"PerceiverPreTrainedModel",
@@ -59,6 +60,7 @@
PerceiverForMaskedLM,
PerceiverForMultimodalAutoencoding,
PerceiverForOpticalFlow,
PerceiverForSequenceClassification,
PerceiverLayer,
PerceiverModel,
PerceiverPreTrainedModel,
109 changes: 109 additions & 0 deletions src/transformers/models/perceiver/modeling_perceiver.py
@@ -956,6 +956,115 @@ def forward(
)


@add_start_docstrings("""Example use of Perceiver for text classification. """, PERCEIVER_START_DOCSTRING)
class PerceiverForSequenceClassification(PerceiverPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        trainable_position_encoding_kwargs_decoder = dict(num_channels=config.d_latents, index_dims=1)

        self.num_labels = config.num_labels
        self.perceiver = PerceiverModel(
            config,
            input_preprocessor=PerceiverTextPreprocessor(config),
            decoder=PerceiverClassificationDecoder(
                config,
                num_channels=config.d_latents,
                trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
                use_query_residual=True,
            ),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=PerceiverClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        inputs=None,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        labels=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss).
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples::

            >>> from transformers import PerceiverTokenizer, PerceiverForSequenceClassification

            >>> tokenizer = PerceiverTokenizer.from_pretrained('deepmind/language-perceiver')
            >>> model = PerceiverForSequenceClassification.from_pretrained('deepmind/language-perceiver')

            >>> text = "hello world"
            >>> inputs = tokenizer(text, return_tensors="pt").input_ids
            >>> outputs = model(inputs=inputs)
            >>> logits = outputs.logits
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.perceiver(
            inputs=inputs,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs.logits if return_dict else outputs[0]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return PerceiverClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


@add_start_docstrings(
"""
Example use of Perceiver for image classification, for tasks such as ImageNet.
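The loss selection in the new `forward` depends on `config.problem_type`, which is inferred from `num_labels` and the label dtype when it is unset. The sketch below isolates that branch order in plain Python so it can be read without PyTorch. All names are hypothetical: `infer_problem_type` stands in for the inline `if`/`elif` chain, `labels_are_integer` stands in for the `torch.long`/`torch.int` dtype check, and `LOSS_BY_PROBLEM_TYPE` only records which loss class the diff uses for each case.

```python
def infer_problem_type(num_labels, labels_are_integer, problem_type=None):
    """Follow the branch order from the diff when no problem type is configured."""
    if problem_type is not None:
        # An explicitly configured problem type is never overridden.
        return problem_type
    if num_labels == 1:
        return "regression"
    if num_labels > 1 and labels_are_integer:
        return "single_label_classification"
    return "multi_label_classification"


# Loss class names as used in the diff, kept as strings so no torch import is needed.
LOSS_BY_PROBLEM_TYPE = {
    "regression": "MSELoss",
    "single_label_classification": "CrossEntropyLoss",
    "multi_label_classification": "BCEWithLogitsLoss",
}

for num_labels, is_int in [(1, False), (3, True), (3, False)]:
    kind = infer_problem_type(num_labels, is_int)
    print(num_labels, is_int, kind, LOSS_BY_PROBLEM_TYPE[kind])
```

One difference worth noting: the method in the diff stores the inferred value on `self.config`, so the first batch with labels fixes the problem type for later calls, whereas this sketch is stateless.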
12 changes: 12 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -3769,6 +3769,18 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class PerceiverForSequenceClassification:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    def forward(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class PerceiverLayer:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
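The dummy class above exists so that the name can be imported without PyTorch installed, with an error raised only when the class is actually used. A minimal sketch of that pattern follows. The `requires_backends` here is a hypothetical stand-in written for illustration (the library's own implementation is not part of this diff), and `PlaceholderModel` and the backend name are made up.

```python
import importlib.util


def requires_backends(obj, backends):
    """Hypothetical stand-in: raise ImportError if any named backend is not importable."""
    # Works for both classes (have __name__) and instances (use the class name).
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(f"{name} requires the following backend(s): {', '.join(missing)}")


class PlaceholderModel:
    """Importable everywhere; fails only on use when the backend is absent."""

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["backend_that_is_not_installed"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["backend_that_is_not_installed"])


try:
    PlaceholderModel()
except ImportError as err:
    message = str(err)
    print(message)
```

Deferring the failure to call time keeps a plain import of the package working in environments without the backend, while still giving a clear error at the point of use.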
108 changes: 82 additions & 26 deletions tests/test_modeling_perceiver.py
@@ -25,6 +25,7 @@
from transformers import PerceiverConfig
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.models.auto import get_values
from transformers.models.perceiver.modeling_perceiver import PerceiverForSequenceClassification
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
@@ -137,7 +138,7 @@ def prepare_config_and_inputs(self, model_class=None):
if model_class is None or model_class.__name__ == "PerceiverModel":
inputs = floats_tensor([self.batch_size, self.seq_length, config.d_model], self.vocab_size)
return config, inputs, input_mask, sequence_labels, token_labels
elif model_class.__name__ == "PerceiverForMaskedLM":
elif model_class.__name__ in ["PerceiverForMaskedLM", "PerceiverForSequenceClassification"]:
inputs = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
# input mask is only relevant for text inputs
if self.use_input_mask:
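The change in this hunk routes `PerceiverForSequenceClassification` through the same token-id branch as `PerceiverForMaskedLM`. The sketch below restates only the two branches visible in the hunk as a pure function returning the kind and shape of the dummy input. `input_spec_for` and `TEXT_MODEL_NAMES` are hypothetical names, and the branches for the image, optical flow and multimodal models are not visible in this diff, so the sketch does not guess at them.

```python
TEXT_MODEL_NAMES = ["PerceiverForMaskedLM", "PerceiverForSequenceClassification"]


def input_spec_for(model_class_name, batch_size, seq_length, d_model):
    """Return (kind, shape) of the dummy input for the branches shown in the hunk."""
    if model_class_name is None or model_class_name == "PerceiverModel":
        # The bare model receives already-embedded float inputs.
        return ("floats", (batch_size, seq_length, d_model))
    if model_class_name in TEXT_MODEL_NAMES:
        # Text models receive token ids; the preprocessor embeds them.
        return ("ids", (batch_size, seq_length))
    # Remaining branches are elided in the diff and intentionally not reproduced.
    raise NotImplementedError(f"No input spec sketched for {model_class_name}")


print(input_spec_for("PerceiverForSequenceClassification", 2, 7, 16))
```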
@@ -171,33 +172,33 @@ def prepare_config_and_inputs(self, model_class=None):

return config, inputs, input_mask, sequence_labels, token_labels

def prepare_config_and_inputs_masked_lm(self):
inputs = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
# def prepare_config_and_inputs_masked_lm(self):
# inputs = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
# input_mask = None
# if self.use_input_mask:
# input_mask = random_attention_mask([self.batch_size, self.seq_length])

token_labels = None
if self.use_labels:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
# token_labels = None
# if self.use_labels:
# token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)

config = self.get_config()
# config = self.get_config()

return config, inputs, input_mask, token_labels
# return config, inputs, input_mask, token_labels

def prepare_config_and_inputs_image_classification(self):
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
# def prepare_config_and_inputs_classification(self):
# inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

input_mask = None
# input_mask = None

image_labels = None
if self.use_labels:
image_labels = ids_tensor([self.batch_size], self.num_labels)
# classification_labels = None
# if self.use_labels:
# classification_labels = ids_tensor([self.batch_size], self.num_labels)

config = self.get_config()
# config = self.get_config()

return config, inputs, input_mask, image_labels
# return config, inputs, input_mask, classification_labels

def get_config(self):
return PerceiverConfig(
@@ -220,21 +221,56 @@ def get_config(self):
num_labels=self.num_labels,
)

def create_and_check_for_masked_lm(self, config, inputs, input_mask, token_labels):
def create_and_check_for_masked_lm(self, config, inputs, input_mask, sequence_labels, token_labels):
model = PerceiverForMaskedLM(config=config)
model.to(torch_device)
model.eval()
result = model(inputs, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

def create_and_check_for_image_classification(self, config, inputs, input_mask, image_labels):
def create_and_check_for_sequence_classification(self, config, inputs, input_mask, sequence_labels, token_labels):
# set num_labels
config.num_labels = self.num_labels
model = PerceiverForSequenceClassification(config=config)
model.to(torch_device)
model.eval()
result = model(inputs, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

def create_and_check_for_image_classification_learned(
self, config, inputs, input_mask, sequence_labels, token_labels
):
# set d_model and num_labels
config.d_model = 512
config.num_labels = self.num_labels
model = PerceiverForImageClassificationLearned(config=config)
model.to(torch_device)
model.eval()
result = model(inputs, attention_mask=input_mask, labels=image_labels)
result = model(inputs, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

def create_and_check_for_image_classification_fourier(
self, config, inputs, input_mask, sequence_labels, token_labels
):
# set d_model and num_labels
config.d_model = 261
config.num_labels = self.num_labels
model = PerceiverForImageClassificationFourier(config=config)
model.to(torch_device)
model.eval()
result = model(inputs, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

def create_and_check_for_image_classification_conv(
self, config, inputs, input_mask, sequence_labels, token_labels
):
# set d_model and num_labels
config.d_model = 322
config.num_labels = self.num_labels
model = PerceiverForImageClassificationConvProcessing(config=config)
model.to(torch_device)
model.eval()
result = model(inputs, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

def prepare_config_and_inputs_for_common(self):
@@ -263,6 +299,7 @@ class PerceiverModelTest(ModelTesterMixin, unittest.TestCase):
PerceiverForImageClassificationFourier,
PerceiverForOpticalFlow,
PerceiverForMultimodalAutoencoding,
PerceiverForSequenceClassification,
)
if is_torch_available()
else ()
@@ -309,12 +346,30 @@ def test_config(self):
self.config_tester.check_config_can_be_init_without_params()

def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_masked_lm()
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class=PerceiverForMaskedLM)
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_image_classification()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class=PerceiverForSequenceClassification)
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

def test_for_image_classification_learned(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
model_class=PerceiverForImageClassificationLearned
)
self.model_tester.create_and_check_for_image_classification_learned(*config_and_inputs)

def test_for_image_classification_fourier(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
model_class=PerceiverForImageClassificationFourier
)
self.model_tester.create_and_check_for_image_classification_fourier(*config_and_inputs)

def test_for_image_classification_conv(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
model_class=PerceiverForImageClassificationConvProcessing
)
self.model_tester.create_and_check_for_image_classification_conv(*config_and_inputs)

def test_model_common_attributes(self):
for model_class in self.all_model_classes:
@@ -676,6 +731,7 @@ def test_correct_missing_keys(self):
if model_class in [
PerceiverForOpticalFlow,
PerceiverForMultimodalAutoencoding,
*get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
]:
continue