Add Perceiver IO #14487

Merged
merged 147 commits into from
Dec 8, 2021
147 commits
beef8c1
First draft
NielsRogge Aug 2, 2021
28f9541
Style and remove mlm
NielsRogge Sep 6, 2021
7f70799
Make forward pass work
NielsRogge Sep 6, 2021
7574fc0
More improvements
NielsRogge Sep 6, 2021
77d55ec
More improvements
NielsRogge Sep 7, 2021
bdccd62
Fix bug
NielsRogge Sep 7, 2021
7b7dcd2
More improvements
NielsRogge Sep 7, 2021
25d7725
More improvements
NielsRogge Sep 7, 2021
4a804b6
Add PerceiverTokenizer first draft
NielsRogge Sep 8, 2021
9a84428
Improve conversion script
NielsRogge Sep 8, 2021
65e4edd
More improvements
NielsRogge Sep 8, 2021
649c66a
Make conversion script work for the encoder
NielsRogge Sep 8, 2021
df1c0c9
Make conversion script work with local pickle files
NielsRogge Sep 8, 2021
6a8a981
Style & quality, fix-copies
NielsRogge Sep 8, 2021
79b3f9d
Add dummy input to conversion script
NielsRogge Sep 8, 2021
6d1fb56
Add absolute position embeddings to TextPreProcessor
NielsRogge Sep 8, 2021
9ef09dc
Make forward pass of encoder work
NielsRogge Sep 9, 2021
8e15a42
More improvements
NielsRogge Sep 10, 2021
8852bd6
Move text preprocessor to separate script
NielsRogge Sep 10, 2021
e003753
More improvements
NielsRogge Sep 10, 2021
cfe4d01
More improvements
NielsRogge Sep 10, 2021
2eb4869
Add post processor
NielsRogge Sep 10, 2021
091903e
Make MLM model work
NielsRogge Sep 10, 2021
4f6c31d
Style
NielsRogge Sep 10, 2021
edaf54d
Add PerceiverForMaskedLM
NielsRogge Sep 10, 2021
5a1dea3
Add PerceiverImagePreprocessor
NielsRogge Sep 13, 2021
af33282
Make style
NielsRogge Sep 13, 2021
63b556a
Make PerceiverForImageClassification work
NielsRogge Sep 13, 2021
54d5335
More improvements
NielsRogge Sep 14, 2021
853268e
More improvements
NielsRogge Sep 14, 2021
d579251
Use tokenizer in conversion script
NielsRogge Sep 14, 2021
e8a8772
Use PerceiverForMaskedLM in conversion script
NielsRogge Sep 14, 2021
f8293b9
Define custom PerceiverModelOutput
NielsRogge Sep 14, 2021
3a62362
Improve PerceiverAttention to make it work for both MLM and image cla…
NielsRogge Sep 14, 2021
7795f6d
More improvements
NielsRogge Sep 14, 2021
2c3342f
More improvements
NielsRogge Sep 15, 2021
3151607
More improvements to the conversion script
NielsRogge Sep 15, 2021
a2e6b0e
Make conversion script work for both MLM and image classification
NielsRogge Sep 15, 2021
c1dbe7c
Add PerceiverFeatureExtractor
NielsRogge Sep 15, 2021
e6d9122
More improvements
NielsRogge Sep 15, 2021
cfd32c6
Style and quality
NielsRogge Sep 15, 2021
07b090f
Add center cropping
NielsRogge Sep 15, 2021
4cd722c
Fix bug
NielsRogge Sep 15, 2021
4ed297e
Small fix
NielsRogge Sep 15, 2021
8d4b748
Add print statement
NielsRogge Sep 15, 2021
2bb92b7
Fix bug in image preprocessor
NielsRogge Sep 15, 2021
4248229
Fix bug with conversion script
NielsRogge Sep 15, 2021
a7f75a2
Make output position embeddings an nn.Parameter layer instead of nn.E…
NielsRogge Sep 15, 2021
4592338
Comment out print statements
NielsRogge Sep 16, 2021
dd91215
Add position encoding classes
NielsRogge Sep 16, 2021
ac82fce
More improvements
NielsRogge Sep 16, 2021
b369c09
Use position_encoding_kwargs
NielsRogge Sep 17, 2021
7d1863f
Add PerceiverForImageClassificationFourier
NielsRogge Sep 17, 2021
e77c6b4
Make style & quality
NielsRogge Sep 17, 2021
0a7c3f0
Add PerceiverForImageClassificationConvProcessing
NielsRogge Sep 17, 2021
d3bcf09
Style & quality
NielsRogge Sep 17, 2021
0e4241c
Add flow model
NielsRogge Sep 18, 2021
92c7c62
Move processors to modeling file
NielsRogge Sep 20, 2021
9933942
Make position encodings modular
NielsRogge Sep 20, 2021
00d2ce3
Make basic decoder use modular position encodings
NielsRogge Sep 20, 2021
f1276f8
Add PerceiverForOpticalFlow to conversion script
NielsRogge Sep 20, 2021
15ded27
Add AudioPreprocessor
NielsRogge Sep 21, 2021
1347c20
Make it possible for the basic decoder to use Fourier position embedd…
NielsRogge Sep 21, 2021
8bb1289
Add PerceiverForMultimodalAutoencoding
NielsRogge Sep 21, 2021
8c5d100
Improve model for optical flow
NielsRogge Sep 22, 2021
5dbea95
Improve _build_network_inputs method
NielsRogge Sep 22, 2021
5472500
Add print statement
NielsRogge Sep 22, 2021
fea12e6
Fix device issue
NielsRogge Sep 22, 2021
3daed24
Fix device of Fourier embeddings
NielsRogge Sep 23, 2021
a45c064
Add print statements for debugging
NielsRogge Sep 23, 2021
1e7b1c9
Add another print statement
NielsRogge Sep 23, 2021
8c0f886
Add another print statement
NielsRogge Sep 23, 2021
32cca82
Add another print statement
NielsRogge Sep 23, 2021
f1c3720
Add another print statement
NielsRogge Sep 23, 2021
275a59f
Improve PerceiverAudioPreprocessor
NielsRogge Sep 24, 2021
aedb68e
Improve conversion script for multimodal modal
NielsRogge Sep 24, 2021
adc1205
More improvements
NielsRogge Sep 24, 2021
89da95d
More improvements
NielsRogge Sep 25, 2021
a7f4870
Improve multimodal model
NielsRogge Sep 27, 2021
54021d3
Make forward pass multimodal model work
NielsRogge Sep 28, 2021
327d16c
More improvements
NielsRogge Sep 29, 2021
f3a2d0c
Improve tests
NielsRogge Oct 6, 2021
1f34526
Fix some more tests
NielsRogge Oct 6, 2021
7c4cbbc
Add output dataclasses
NielsRogge Oct 6, 2021
2a4dab2
Make more tests pass
NielsRogge Oct 7, 2021
1205dd9
Add print statements for debuggin
NielsRogge Oct 7, 2021
4408a69
Add tests for image classification
NielsRogge Oct 7, 2021
1a60c6a
Add PerceiverClassifierOutput
NielsRogge Oct 7, 2021
0a1bfcd
More improvements
NielsRogge Oct 7, 2021
27f7190
Make more tests pass for the optical flow model
NielsRogge Oct 7, 2021
6815bf7
Make style & quality
NielsRogge Oct 7, 2021
d7fedc7
Small improvements
NielsRogge Oct 7, 2021
06839cb
Don't support training for optical flow model for now
NielsRogge Oct 11, 2021
5acb88c
Fix _prepare_for_class for tests
NielsRogge Oct 11, 2021
db7b6bb
Make more tests pass, add some docs
NielsRogge Oct 12, 2021
0264043
Add multimodal model to tests
NielsRogge Oct 12, 2021
107c971
Minor fixes
NielsRogge Nov 3, 2021
ed7d7ea
Fix tests
NielsRogge Nov 4, 2021
f62a6f5
Improve conversion script
NielsRogge Nov 4, 2021
d32808b
Make fixup
NielsRogge Nov 4, 2021
08b67de
Remove pos_dim argument
NielsRogge Nov 4, 2021
e7f8329
Fix device issue
NielsRogge Nov 4, 2021
0a93591
Potential fix for OOM
NielsRogge Nov 4, 2021
1091cfe
Revert previous commit
NielsRogge Nov 4, 2021
4c10a9d
Fix test_initialization
NielsRogge Nov 5, 2021
06c7b06
Add print statements for debugging
NielsRogge Nov 5, 2021
adfda8f
Fix print statement
NielsRogge Nov 5, 2021
927dd92
Add print statement
NielsRogge Nov 5, 2021
786f57f
Add print statement
NielsRogge Nov 5, 2021
bde8cf3
Add print statement
NielsRogge Nov 5, 2021
d832391
Add print statement
NielsRogge Nov 8, 2021
8aa3228
Add print statement
NielsRogge Nov 8, 2021
5a84a3e
Add print statement
NielsRogge Nov 8, 2021
8887f98
Remove need for output_shape
NielsRogge Nov 8, 2021
f9800c5
Comment out output_shape
NielsRogge Nov 8, 2021
134bfc4
Remove unnecessary code
NielsRogge Nov 8, 2021
d5187fb
Improve docs
NielsRogge Nov 10, 2021
e9003fb
Fix make fixup
NielsRogge Nov 19, 2021
d965bca
Remove PerceiverTextProcessor from init
NielsRogge Nov 19, 2021
42630e7
Improve docs
NielsRogge Nov 19, 2021
29037ba
Small improvement
NielsRogge Nov 22, 2021
4a2b81a
Apply first batch of suggestions from code review
NielsRogge Nov 30, 2021
3235318
Apply more suggestions from code review
NielsRogge Nov 30, 2021
22becd9
Update docstrings
NielsRogge Nov 30, 2021
dc95e00
Define dicts beforehand for readability
NielsRogge Nov 30, 2021
31ae669
Rename task to architecture in conversion script, include PerceiverMo…
NielsRogge Dec 1, 2021
fa41b1a
Add print statements for debugging
NielsRogge Dec 1, 2021
a3f16f2
Fix tests on GPU
NielsRogge Dec 1, 2021
afcb875
Remove preprocessors, postprocessors and decoders from main init
NielsRogge Dec 1, 2021
c5e3af7
Add integration test
NielsRogge Dec 1, 2021
dc68fed
Fix docs
NielsRogge Dec 1, 2021
ffc6fde
Replace einops by torch
NielsRogge Dec 2, 2021
83a6776
Update for new docs frontend
NielsRogge Dec 2, 2021
46c8e04
Rename PerceiverForImageClassification
NielsRogge Dec 2, 2021
a358e38
Improve docs
NielsRogge Dec 2, 2021
c5ae758
Improve docs
NielsRogge Dec 2, 2021
48503c0
Improve docs of PerceiverModel
NielsRogge Dec 2, 2021
ec0e016
Fix some more tests
NielsRogge Dec 3, 2021
da79d8a
Improve center_crop
NielsRogge Dec 3, 2021
2a3c57c
Add PerceiverForSequenceClassification
NielsRogge Dec 3, 2021
60eefd7
Small improvements
NielsRogge Dec 6, 2021
b36ba76
Fix tests
NielsRogge Dec 6, 2021
e8cf21a
Add integration test for optical flow model
NielsRogge Dec 7, 2021
e084c05
Clean up
NielsRogge Dec 7, 2021
d1c0245
Add tests for tokenizer
NielsRogge Dec 7, 2021
520f132
Fix tokenizer by adding special tokens properly
NielsRogge Dec 8, 2021
cf534be
Fix CI
NielsRogge Dec 8, 2021
Add PerceiverForImageClassificationConvProcessing
NielsRogge committed Dec 2, 2021
commit 0a7c3f0bb95d6db4d4c217cbf76bfa885efd4d84
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1139,6 +1139,7 @@
             "PerceiverClassificationDecoder",
             "PerceiverForImageClassification",
             "PerceiverForImageClassificationFourier",
+            "PerceiverForImageClassificationConvProcessing",
             "PerceiverForMaskedLM",
             "PerceiverImagePreprocessor",
             "PerceiverLayer",
@@ -3010,6 +3011,7 @@
             PerceiverClassificationDecoder,
             PerceiverForImageClassification,
             PerceiverForImageClassificationFourier,
+            PerceiverForImageClassificationConvProcessing,
             PerceiverForMaskedLM,
             PerceiverImagePreprocessor,
             PerceiverLayer,
2 changes: 2 additions & 0 deletions src/transformers/models/perceiver/__init__.py
@@ -35,6 +35,7 @@
         "PerceiverClassificationDecoder",
         "PerceiverForImageClassification",
         "PerceiverForImageClassificationFourier",
+        "PerceiverForImageClassificationConvProcessing",
         "PerceiverForMaskedLM",
         "PerceiverLayer",
         "PerceiverModel",
@@ -62,6 +63,7 @@
             PerceiverClassificationDecoder,
             PerceiverForImageClassification,
             PerceiverForImageClassificationFourier,
+            PerceiverForImageClassificationConvProcessing,
             PerceiverForMaskedLM,
             PerceiverLayer,
             PerceiverModel,
@@ -32,6 +32,7 @@
     PerceiverFeatureExtractor,
     PerceiverForImageClassification,
     PerceiverForImageClassificationFourier,
+    PerceiverForImageClassificationConvProcessing,
     PerceiverForMaskedLM,
     PerceiverTokenizer,
 )
@@ -61,7 +62,7 @@ def rename_keys(state_dict):
             "trainable_position_encoding/pos_embs", "input_preprocessor.position_embeddings.weight"
         )

-        # rename image preprocessor embeddings (for image classification model)
+        # rename image preprocessor embeddings (for image classification model with learned position embeddings)
         name = name.replace("image_preprocessor/~/conv2_d/w", "input_preprocessor.convnet_1x1.weight")
         name = name.replace("image_preprocessor/~/conv2_d/b", "input_preprocessor.convnet_1x1.bias")
         name = name.replace(
@@ -77,6 +78,15 @@ def rename_keys(state_dict):
             "input_preprocessor.positions_projection.bias",
         )

+        # rename image preprocessor embeddings (for image classification model with conv processing)
+        if "counter" in name or "hidden" in name:
+            continue
+        name = name.replace("image_preprocessor/~/conv2_d_downsample/~/conv/w", "input_preprocessor.convnet.conv.weight")
+        name = name.replace("image_preprocessor/~/conv2_d_downsample/~/batchnorm/offset", "input_preprocessor.convnet.batchnorm.bias")
+        name = name.replace("image_preprocessor/~/conv2_d_downsample/~/batchnorm/scale", "input_preprocessor.convnet.batchnorm.weight")
+        name = name.replace("image_preprocessor/~/conv2_d_downsample/~/batchnorm/~/mean_ema/average", "input_preprocessor.convnet.batchnorm.running_mean")
+        name = name.replace("image_preprocessor/~/conv2_d_downsample/~/batchnorm/~/var_ema/average", "input_preprocessor.convnet.batchnorm.running_var")
+
         ## DECODERS ##

         # rename prefix of decoders
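The renaming added in this hunk is plain substring replacement on Haiku parameter paths. A minimal sketch of the idea follows. The helper `rename_conv_key` and the `RENAMES` table are hypothetical names introduced for illustration; the actual script edits `state_dict` in place and handles many more prefixes.

```python
from typing import Optional

# Common prefix of the Haiku parameter paths touched by this hunk.
PREFIX = "image_preprocessor/~/conv2_d_downsample/~/"

# Haiku suffix -> PyTorch name, copied from the replace() calls in the hunk.
RENAMES = {
    PREFIX + "conv/w": "input_preprocessor.convnet.conv.weight",
    PREFIX + "batchnorm/offset": "input_preprocessor.convnet.batchnorm.bias",
    PREFIX + "batchnorm/scale": "input_preprocessor.convnet.batchnorm.weight",
    PREFIX + "batchnorm/~/mean_ema/average": "input_preprocessor.convnet.batchnorm.running_mean",
    PREFIX + "batchnorm/~/var_ema/average": "input_preprocessor.convnet.batchnorm.running_var",
}


def rename_conv_key(name: str) -> Optional[str]:
    """Return the PyTorch-style name, or None if the entry is skipped."""
    # The hunk skips every entry whose name contains "counter" or "hidden".
    if "counter" in name or "hidden" in name:
        return None
    for old, new in RENAMES.items():
        name = name.replace(old, new)
    return name


print(rename_conv_key(PREFIX + "batchnorm/scale"))
```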
@@ -167,6 +177,10 @@ def rename_keys(state_dict):
         if name[-6:] == "weight" and "embeddings" not in name:
             param = np.transpose(param)

+        # if batchnorm, we need to squeeze it
+        if "batchnorm" in name:
+            param = np.squeeze(param)
+
         state_dict["perceiver." + name] = torch.from_numpy(param)


@@ -180,10 +194,11 @@ def convert_perceiver_checkpoint(pickle_file, pytorch_dump_folder_path, task="ML
     with open(pickle_file, "rb") as f:
         checkpoint = pickle.loads(f.read())

+    state = None
     if isinstance(checkpoint, dict):
-        if task not in ["image_classification", "image_classification_fourier"]:
+        if task not in ["image_classification", "image_classification_fourier", "image_classification_conv"]:
             raise ValueError("Make sure to set task to image classification")
-        # the image classification checkpoint with conv_preprocessing also has batchnorm state
+        # the image classification_conv checkpoint also has batchnorm states (running_mean and running_var)
         params = checkpoint["params"]
         state = checkpoint["state"]
     else:
@@ -195,6 +210,12 @@ def convert_perceiver_checkpoint(pickle_file, pytorch_dump_folder_path, task="ML
         for param_name, param in parameters.items():
             state_dict[scope_name + "/" + param_name] = param

+    if state is not None:
+        # add state variables
+        for scope_name, parameters in hk.data_structures.to_mutable_dict(state).items():
+            for param_name, param in parameters.items():
+                state_dict[scope_name + "/" + param_name] = param
+
     # rename keys
     rename_keys(state_dict)
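
The added block merges the Haiku `state` tree (batchnorm running statistics) into the same flat `scope/param` namespace as the parameters. A rough sketch of that flattening, using plain nested dicts as stand-ins for the Haiku structures; the `flatten` helper and the example values are made up for illustration:

```python
from typing import Any, Dict, Optional


def flatten(nested: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Flatten {scope: {param: value}} into {"scope/param": value}."""
    flat = {}
    for scope_name, parameters in nested.items():
        for param_name, param in parameters.items():
            flat[scope_name + "/" + param_name] = param
    return flat


# Hypothetical stand-ins for checkpoint["params"] and checkpoint["state"].
params = {"conv": {"w": 1.0}}
state: Optional[Dict[str, Dict[str, Any]]] = {
    "batchnorm/~/mean_ema": {"average": 0.5, "counter": 3},
}

state_dict = flatten(params)
if state is not None:
    # State entries share the flat namespace with the parameters.
    state_dict.update(flatten(state))

print(sorted(state_dict))
```

Because both trees land in one dictionary, `rename_keys` can treat running statistics exactly like weights, which is why the `counter` and `hidden` entries have to be filtered out there.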

@@ -226,6 +247,9 @@ def convert_perceiver_checkpoint(pickle_file, pytorch_dump_folder_path, task="ML
         elif task == "image_classification_fourier":
             config.d_model = 261
             model = PerceiverForImageClassificationFourier(config)
+        elif task == "image_classification_conv":
+            config.d_model = 322
+            model = PerceiverForImageClassificationConvProcessing(config)
         else:
             raise ValueError(f"Task {task} not supported")
     else:
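
One plausible reading of the two `d_model` values, sketched as arithmetic. It rests on assumptions not stated in this diff: that the Fourier position features have size `num_dims * num_bands * 2 + num_dims` (sine and cosine per band, plus the raw coordinates when `concat_pos=True`), that the Fourier variant also uses 64 bands and concatenates 3 raw RGB channels, and that the conv variant concatenates the 64 conv output channels. The function name is hypothetical.

```python
def fourier_feature_size(num_dims: int, num_bands: int,
                         concat_pos: bool = True, sine_only: bool = False) -> int:
    """Assumed per-position size of a Fourier position encoding."""
    size = num_dims * num_bands * (1 if sine_only else 2)
    if concat_pos:
        size += num_dims  # raw (x, y) coordinates are appended
    return size


# 2D images, num_bands=64, concat_pos=True, sine_only=False (as in the new class).
pos_size = fourier_feature_size(num_dims=2, num_bands=64)

conv_d_model = 64 + pos_size     # assumed: 64 conv channels + position features
fourier_d_model = 3 + pos_size   # assumed: 3 RGB channels + position features

print(pos_size, conv_d_model, fourier_d_model)
```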
@@ -244,7 +268,7 @@ def convert_perceiver_checkpoint(pickle_file, pytorch_dump_folder_path, task="ML
         encoding.input_ids[0, 51:60] = tokenizer.mask_token_id
         inputs = encoding.input_ids
         input_mask = encoding.attention_mask
-    elif task in ["image_classification", "image_classification_fourier"]:
+    elif task in ["image_classification", "image_classification_fourier", "image_classification_conv"]:
         feature_extractor = PerceiverFeatureExtractor()
         image = prepare_img()
         encoding = feature_extractor(image, return_tensors="pt")
@@ -272,7 +296,7 @@ def convert_perceiver_checkpoint(pickle_file, pytorch_dump_folder_path, task="ML
         print("Predicted string:")
         print(tokenizer.decode(masked_tokens_predictions))

-    elif task in ["image_classification", "image_classification_fourier"]:
+    elif task in ["image_classification", "image_classification_fourier", "image_classification_conv"]:
         print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])

     # Finally, save files
74 changes: 74 additions & 0 deletions src/transformers/models/perceiver/modeling_perceiver.py
@@ -998,6 +998,80 @@ def forward(
         )


+# @add_start_docstrings("""Example use of Perceiver for image classification. """, PERCEIVER_START_DOCSTRING)
+class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.perceiver = PerceiverModel(
+            config,
+            input_preprocessor=PerceiverImagePreprocessor(
+                config,
+                prep_type="conv",
+                spatial_downsample=1,
+                position_encoding_type="fourier",
+                concat_pos=True,
+                max_resolution=(56, 56),
+                num_bands=64,
+                sine_only=False,
+            ),
+            decoder=PerceiverClassificationDecoder(config, num_channels=config.d_latents, use_query_residual=True),
+        )
+
+        self.init_weights()
+
+    # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    # @add_code_sample_docstrings(
+    #     tokenizer_class=_TOKENIZER_FOR_DOC,
+    #     checkpoint=_CHECKPOINT_FOR_DOC,
+    #     output_type=SequenceClassifierOutput,
+    #     config_class=_CONFIG_FOR_DOC,
+    # )
+    def forward(
+        self,
+        inputs=None,
+        attention_mask=None,
+        head_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        labels=None,
+        return_dict=None,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.perceiver(
+            inputs=inputs,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        logits = outputs.logits
+
+        loss = None
+        if labels is not None:
+            if self.num_labels == 1:
+                # We are doing regression
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1), labels.view(-1))
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
 class PerceiverAbstractDecoder(nn.Module, metaclass=abc.ABCMeta):
     """Perceiver abstract decoder."""

86 changes: 81 additions & 5 deletions src/transformers/models/perceiver/processing_perceiver.py
@@ -17,12 +17,69 @@
 """
 import abc
 from typing import Optional
+import math
+from functools import reduce
+from operator import __add__

 import numpy as np
 import torch
 import torch.nn as nn


+# ------------------------------------------------------------
+# -------------------  Up/down-sampling  ---------------------
+# ------------------------------------------------------------
+
+
+class Conv2dSamePadding(nn.Conv2d):
+    """Conv2d layer with padding="same" support.
+    Source: https://gist.github.com/sumanmichael/4de9dee93f972d47c80c4ade8e149ea6
+    """
+
+    def __init__(self,*args,**kwargs):
+        super(Conv2dSamePadding, self).__init__(*args, **kwargs)
+        self.zero_pad_2d = nn.ZeroPad2d(reduce(__add__,
+            [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in self.kernel_size[::-1]]))
+
+    def forward(self, input):
+        return self._conv_forward(self.zero_pad_2d(input), self.weight, self.bias)
+
+
+class Conv2DDownsample(nn.Module):
+    """Downsamples 4x by applying a 2D convolution and doing max pooling."""
+
+    def __init__(
+        self,
+        num_layers: int = 1,
+        in_channels: int = 3,
+        out_channels: int = 64,
+        use_batchnorm: bool = True,
+    ):
+        """Constructs a Conv2DDownsample model.
+        Args:
+          in_channels: The number of input channels.
+          out_channels: The number of conv output channels.
+          use_batchnorm: Whether to use batchnorm.
+        """
+        super().__init__()
+
+        self.conv = Conv2dSamePadding(in_channels=in_channels, out_channels=out_channels,
+                                      kernel_size=7,
+                                      stride=2,
+                                      bias=False)
+        self.batchnorm = nn.BatchNorm2d(num_features=out_channels) if use_batchnorm else nn.Identity()
+        self.relu = nn.ReLU()
+        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)
+
+    def forward(self, inputs):
+        out = inputs
+        out = self.conv(inputs)
+        out = self.batchnorm(out)
+        out = self.relu(out)
+        out = self.max_pool(out)
+        return out
+
+
 def generate_fourier_features(pos, num_bands, max_resolution=(224, 224), concat_pos=True, sine_only=False):
     """
     Generate a Fourier frequency position encoding with linear spacing.
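
To make the padding expression in `Conv2dSamePadding` and the resulting spatial sizes concrete, here is a small pure-Python sketch that avoids torch. The helper names `same_padding` and `out_size` are hypothetical, and the 224-pixel input side is an assumption (it is not stated in this hunk). The size formula is the standard one for strided windows without dilation.

```python
from functools import reduce
from operator import add
from typing import Tuple


def same_padding(kernel_size: Tuple[int, int]) -> Tuple[int, ...]:
    """(left, right, top, bottom) padding, using the expression from the hunk."""
    return reduce(add, [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in kernel_size[::-1]])


def out_size(n: int, kernel: int, stride: int, pad_total: int = 0) -> int:
    """Output length of a strided conv or pooling window (no dilation)."""
    return (n + pad_total - kernel) // stride + 1


pad = same_padding((7, 7))  # padding for the 7x7 convolution

# Assumed 224-pixel side: 7x7 conv with stride 2, then 3x3 max pool with stride 2.
after_conv = out_size(224, kernel=7, stride=2, pad_total=pad[0] + pad[1])
after_pool = out_size(after_conv, kernel=3, stride=2)

print(pad, after_conv, after_pool)
```

Under these assumptions the unpadded max pool turns a side of 112 into 55 rather than 56, so the "4x" in the `Conv2DDownsample` docstring is approximate for this input size.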
@@ -240,18 +297,32 @@ def __init__(
         self.conv_after_patching = conv_after_patching
         self.out_channels = out_channels

-        if self.prep_type in ["conv", "patches"]:
-            raise NotImplementedError(f"Preparation type {prep_type} is not yet supported")
+        if self.prep_type == 'conv':
+            # Downsampling with conv is currently restricted
+            convnet_num_layers = math.log(spatial_downsample, 4)
+            convnet_num_layers_is_int = (convnet_num_layers == np.round(convnet_num_layers))
+            if not convnet_num_layers_is_int or temporal_downsample != 1:
+                raise ValueError('Only powers of 4 expected for spatial '
+                                 'and 1 expected for temporal '
+                                 'downsampling with conv.')
+            self.convnet = Conv2DDownsample(
+                num_layers=int(convnet_num_layers),
+                out_channels=out_channels,
+                use_batchnorm=conv2d_use_batchnorm)
+
         elif self.prep_type == "conv1x1":
-            assert temporal_downsample == 1, "conv1x1 does not downsample in time."
+            if temporal_downsample != 1:
+                raise ValueError("Conv1x1 does not downsample in time.")
             self.convnet_1x1 = nn.Conv2d(
                 in_channels=in_channels,
                 out_channels=out_channels,
                 kernel_size=[1, 1],
                 # spatial_downsample is unconstrained for 1x1 convolutions.
                 stride=(spatial_downsample, spatial_downsample),
             )

+        elif self.prep_type == "patches":
+            raise NotImplementedError(f"Preparation type {prep_type} is not yet supported")
+
         if self.position_encoding_type == "trainable":
             self.position_embeddings = PerceiverTrainablePositionEncoding(**position_encoding_kwargs)
         elif self.position_encoding_type == "fourier":
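
The `conv` branch above checks that `spatial_downsample` is a power of 4 by comparing a floating-point logarithm with its rounded value. As an illustration only (this is not what the commit does), the same constraint can be checked with integer arithmetic, which sidesteps any rounding concerns. The function name `conv_num_layers` is hypothetical.

```python
def conv_num_layers(spatial_downsample: int) -> int:
    """Return log4(spatial_downsample), raising if it is not a power of 4."""
    if spatial_downsample < 1:
        raise ValueError("spatial_downsample must be a positive integer.")
    layers = 0
    remaining = spatial_downsample
    # Divide out factors of 4; a power of 4 ends at exactly 1.
    while remaining % 4 == 0:
        remaining //= 4
        layers += 1
    if remaining != 1:
        raise ValueError("Only powers of 4 expected for spatial downsampling with conv.")
    return layers


for factor in (1, 4, 16):
    print(factor, conv_num_layers(factor))
```

Note that `spatial_downsample=1`, the value passed by `PerceiverForImageClassificationConvProcessing` in this commit, maps to zero layers under either check.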
@@ -329,8 +400,13 @@ def _build_network_inputs(self, inputs: torch.Tensor, pos: torch.Tensor, network
         return inputs_with_pos, inputs

     def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
-        if self.prep_type in ["conv", "patches"]:
+        if self.prep_type == "patches":
             raise NotImplementedError("TODO")
+        elif self.prep_type == "conv":
+            # Convnet image featurization.
+            # Downsamples spatially by a factor of 4
+            inputs = self.convnet(inputs)
+
         elif self.prep_type == "conv1x1":
             # map inputs to self.out_channels
             inputs = self.convnet_1x1(inputs)