
Adding SegGPT #27735

Merged: 122 commits, Feb 26, 2024

Commits (122)
fecf251
First commit
EduardoPach Nov 27, 2023
2fbf69b
Improvements
EduardoPach Nov 27, 2023
155e486
More improvements
EduardoPach Nov 27, 2023
2de229a
Converted original checkpoint to HF checkpoint
EduardoPach Nov 28, 2023
b3d5049
Fix style
EduardoPach Nov 28, 2023
fe53c92
Fixed forward
EduardoPach Nov 29, 2023
051b6c8
More improvements
EduardoPach Nov 30, 2023
70d0290
More improvements
EduardoPach Dec 5, 2023
951ac86
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Dec 7, 2023
14b70a7
Remove asserts
EduardoPach Dec 7, 2023
dce5f4a
Merge branch 'adding-seggpt' of https://github.com/EduardoPach/transf…
EduardoPach Dec 7, 2023
3c8e12a
Remove unnecessary attributes
EduardoPach Dec 7, 2023
449fa80
Changed model name to camel case
EduardoPach Dec 7, 2023
41e409b
Improve forward doc
EduardoPach Dec 7, 2023
3e0a77a
Improve tests
EduardoPach Dec 7, 2023
b039594
More improvements
EduardoPach Dec 7, 2023
5c604d4
Fix copies
EduardoPach Dec 7, 2023
00c8bda
Fix doc
EduardoPach Dec 7, 2023
dfc48dd
Make SegGptImageProcessor more flexible
EduardoPach Dec 7, 2023
4152a68
Added few-shot test
EduardoPach Dec 8, 2023
c04a177
Fix merge
NielsRogge Dec 10, 2023
acc58cc
Fix style
NielsRogge Dec 10, 2023
8bb30d5
Update READMEs and docs
NielsRogge Dec 10, 2023
6ec868b
Update READMEs
NielsRogge Dec 10, 2023
88e8144
Make inputs required
NielsRogge Dec 10, 2023
2a066e9
Add SegGptForImageSegmentation
NielsRogge Dec 11, 2023
09187e0
Make tests pass
EduardoPach Dec 12, 2023
5190205
Rename to out_indicies
EduardoPach Dec 12, 2023
bf0bab9
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Dec 13, 2023
c38de07
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Dec 13, 2023
52de2a7
Fixed naming convention
EduardoPach Dec 13, 2023
bebb958
Copying SegGptMlp from modeling_sam.py
EduardoPach Dec 13, 2023
2c7c311
Some minor improvements
NielsRogge Dec 12, 2023
75b2d90
Remove mlp_ratio
NielsRogge Dec 13, 2023
a612330
Fix docstrings
NielsRogge Dec 14, 2023
74383a8
Fixed docstring match
Jan 7, 2024
932a01f
Objects defined before use
Jan 7, 2024
f54d036
Storing only patch_size and beta for SegGptLoss
Jan 7, 2024
b283608
removed _prepare_inputs method
Jan 7, 2024
0fcfbcf
Removed modified from headers
Jan 7, 2024
64d2a90
Renamed to output_indicies
Jan 7, 2024
559c5be
Removed unnecessary einsums
Jan 7, 2024
45ce96b
Update tests/models/seggpt/test_modeling_seggpt.py
EduardoPach Jan 7, 2024
6f982aa
Update tests/models/seggpt/test_modeling_seggpt.py
EduardoPach Jan 7, 2024
c4d5c00
Update tests/models/seggpt/test_modeling_seggpt.py
EduardoPach Jan 7, 2024
6bc0571
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Jan 7, 2024
33c3f4d
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Jan 7, 2024
a435033
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Jan 7, 2024
798a7d3
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Jan 7, 2024
3b443dc
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Jan 7, 2024
e24c369
Fixing issues
Jan 7, 2024
bd0b552
Raise error as soon as possible
Jan 7, 2024
cca0937
More fixes
Jan 7, 2024
39e2767
Jan 7, 2024
3545672
Fix merge
NielsRogge Jan 14, 2024
7228221
Fix merge
NielsRogge Jan 14, 2024
6133e40
Added palette to SegGptImageProcessor
Jan 25, 2024
cf93d25
Merge branch 'adding-seggpt' of https://github.com/EduardoPach/transf…
Jan 25, 2024
1837324
Fixed typo
Jan 25, 2024
dc3cf80
Fixed shape typo
Jan 25, 2024
fc2304c
Added permute before doing palette to class mapping
Jan 25, 2024
b086384
Fixed style
Jan 26, 2024
0c9ad32
Fixed and added tests
Jan 26, 2024
beab961
Fixed docstrings
Jan 26, 2024
bd35b95
Matching SegFormer API for post_processing_semantic_segmentation
Feb 5, 2024
48d7fc3
Merge remote-tracking branch 'upstream/main' into adding-seggpt
Feb 5, 2024
3bcdf52
Fixed copies
Feb 5, 2024
900fac9
Fixed SegGptImageProcessor to handle both binary and RGB masks
Feb 5, 2024
abfc78a
Updated docstrings of SegGptImageProcessor
Feb 5, 2024
ba3f9cb
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Feb 7, 2024
09cdd0e
Update docs/source/en/model_doc/seggpt.md
EduardoPach Feb 7, 2024
690de06
Update src/transformers/models/seggpt/configuration_seggpt.py
EduardoPach Feb 7, 2024
37fabe2
Update src/transformers/models/seggpt/convert_seggpt_to_hf.py
EduardoPach Feb 7, 2024
a2a45cd
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Feb 7, 2024
59787c2
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 7, 2024
cf3c1da
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Feb 7, 2024
706285f
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Feb 7, 2024
755a7b5
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Feb 7, 2024
c6257f5
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 7, 2024
041c400
Update tests/models/seggpt/test_image_processing_seggpt.py
EduardoPach Feb 7, 2024
3ffecbc
Update tests/models/seggpt/test_modeling_seggpt.py
EduardoPach Feb 7, 2024
2d1d77c
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 7, 2024
6354a60
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 7, 2024
f5e23c2
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 7, 2024
d347473
Object definitions above & fix style
Feb 7, 2024
9e35aa6
Renamed output_indices to intermediate_feature_indices
Feb 7, 2024
f6f068c
Removed unnecessary check on bool_masked_pos
Feb 7, 2024
a0cbe9b
Loss first in the outputs
Feb 7, 2024
f1dc953
Added validation for do_normalize
Feb 7, 2024
b8b1d5e
Improved SegGptImageProcessor and added new tests
Feb 10, 2024
d172ca0
Added comment
Feb 10, 2024
88db53f
Added docstrings to SegGptLoss
Feb 10, 2024
db06f21
Reimplemented ensemble condition logic in SegGptEncoder
Feb 10, 2024
5514650
Merge remote-tracking branch 'upstream/main' into adding-seggpt
Feb 10, 2024
7c4805e
Update src/transformers/models/seggpt/__init__.py
EduardoPach Feb 10, 2024
6ad819b
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 10, 2024
9dab61f
Update src/transformers/models/seggpt/convert_seggpt_to_hf.py
EduardoPach Feb 10, 2024
1b87260
Update src/transformers/models/seggpt/configuration_seggpt.py
EduardoPach Feb 10, 2024
c74ca80
Updated docstrings to use post_process_semantic_segmentation
Feb 10, 2024
e5f2c8c
Merge branch 'adding-seggpt' of https://github.com/EduardoPach/transf…
Feb 10, 2024
71dfbf2
Merge remote-tracking branch 'upstream/main' into adding-seggpt
Feb 12, 2024
af21937
Fixed typo on docstrings
Feb 12, 2024
62a82eb
moved pixel values test to test_image_processing_seggpt
Feb 13, 2024
4d425df
Merge remote-tracking branch 'upstream/main' into adding-seggpt
Feb 15, 2024
460f3fa
Addressed comments
Feb 15, 2024
f62b21e
Update src/transformers/models/seggpt/configuration_seggpt.py
EduardoPach Feb 15, 2024
620381d
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Feb 15, 2024
9999d0b
Update src/transformers/models/seggpt/configuration_seggpt.py
EduardoPach Feb 15, 2024
b73b21e
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 15, 2024
01c9f7e
Updated docstrings for SegGptLoss
Feb 15, 2024
c507557
Merge branch 'adding-seggpt' of https://github.com/EduardoPach/transf…
Feb 15, 2024
3373cbb
Address comments
Feb 15, 2024
f701039
Added SegGpt example to model docs
Feb 15, 2024
0e46681
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 21, 2024
43f2d34
moved patchify and unpatchify
EduardoPach Feb 21, 2024
ebc96f4
Merge remote-tracking branch 'upstream/main' into adding-seggpt
EduardoPach Feb 21, 2024
90f911d
Rename checkpoint
EduardoPach Feb 22, 2024
efb85d6
Renamed intermediate_features to intermediate_hidden_states for consi…
EduardoPach Feb 22, 2024
afeb9f2
Update src/transformers/models/seggpt/configuration_seggpt.py
EduardoPach Feb 22, 2024
7e22958
Replaced post_process_masks for post_process_semantic_segmentation in…
EduardoPach Feb 26, 2024
32dd142
Merge branch 'adding-seggpt' of https://github.com/EduardoPach/transf…
EduardoPach Feb 26, 2024
05f0a85
Merge remote-tracking branch 'upstream/main' into adding-seggpt
EduardoPach Feb 26, 2024
Changes from 1 commit: More improvements
EduardoPach committed Nov 30, 2023
commit 051b6c8353da824c0cfdae7c9934aebd4826de82
2 changes: 0 additions & 2 deletions src/transformers/__init__.py
@@ -2779,7 +2779,6 @@
     _import_structure["models.seggpt"].extend(
         [
             "SEGGPT_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "SegGPTForInstanceSegmentation",
             "SegGPTModel",
             "SegGPTPreTrainedModel",
         ]
@@ -6673,7 +6672,6 @@
     )
     from .models.seggpt import (
         SEGGPT_PRETRAINED_MODEL_ARCHIVE_LIST,
-        SegGPTForInstanceSegmentation,
         SegGPTModel,
         SegGPTPreTrainedModel,
     )
2 changes: 0 additions & 2 deletions src/transformers/models/seggpt/__init__.py
@@ -29,7 +29,6 @@
     _import_structure["modeling_seggpt"] = [
         "SEGGPT_PRETRAINED_MODEL_ARCHIVE_LIST",
         "SegGPTModel",
-        "SegGPTForInstanceSegmentation",
         "SegGPTPreTrainedModel",
     ]

@@ -52,7 +51,6 @@
 else:
     from .modeling_seggpt import (
         SEGGPT_PRETRAINED_MODEL_ARCHIVE_LIST,
-        SegGPTForInstanceSegmentation,
         SegGPTModel,
         SegGPTPreTrainedModel,
     )
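(For context: the `_import_structure` dict edited above feeds transformers' lazy-import machinery. The tail of such an `__init__.py` is not shown in this hunk; a sketch assuming it follows the library-wide pattern:)

```python
import sys
from typing import TYPE_CHECKING

from transformers.utils import _LazyModule

# What this model's __init__.py builds up; the hunk above removes one entry.
_import_structure = {
    "modeling_seggpt": [
        "SEGGPT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "SegGPTModel",
        "SegGPTPreTrainedModel",
    ]
}

if not TYPE_CHECKING:
    # Swap this module for a lazy proxy: names listed in _import_structure are
    # only imported the first time somebody accesses them.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
```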
2 changes: 2 additions & 0 deletions src/transformers/models/seggpt/configuration_seggpt.py
@@ -77,6 +77,7 @@ class SegGPTConfig(PretrainedConfig):
         use_rel_pos (`<fill_type>`, *optional*, defaults to `True`): <fill_docstring>
         merge_index (`<fill_type>`, *optional*, defaults to 2): <fill_docstring>
         encoder_output_indicies (`<fill_type>`, *optional*, defaults to `[5, 11, 17, 23]`): <fill_docstring>
+        beta (`<fill_type>`, *optional*, defaults to 0.01): <fill_docstring>

     Example:

@@ -117,6 +118,7 @@ def __init__(
         use_rel_pos=True,
         merge_index=2,
         encoder_output_indicies=[5, 11, 17, 23],
+        beta=0.01,
         **kwargs,
     ):
         super().__init__(**kwargs)
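(The new `beta` value is wired into `SegGptLoss`; see the later commit "Storing only patch_size and beta for SegGptLoss". As an illustration rather than the PR's actual loss code, `beta` in a smooth-L1 loss is the |error| threshold where the penalty switches from quadratic to linear:)

```python
import torch
import torch.nn.functional as F

# Illustrative only: SegGptLoss itself is not part of this hunk. These tensors
# are hypothetical predicted and ground-truth mask patches.
pred = torch.randn(2, 3, 448, 448)
target = torch.randn(2, 3, 448, 448)

# With beta=0.01, errors below 0.01 are penalized quadratically and larger
# errors linearly, which is less sensitive to outlier pixels than plain L2.
loss = F.smooth_l1_loss(pred, target, beta=0.01)
print(loss.item())
```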
106 changes: 67 additions & 39 deletions src/transformers/models/seggpt/convert_seggpt_to_hf.py
@@ -21,7 +21,7 @@
 import torch
 from PIL import Image

-from transformers import SegGPTConfig, SegGPTForInstanceSegmentation, SegGPTImageProcessor
+from transformers import SegGPTConfig, SegGPTImageProcessor, SegGPTModel
 from transformers.utils import logging

@@ -36,18 +36,18 @@ def create_rename_keys(config):
     # fmt: off

     # rename embedding and its parameters
-    rename_keys.append(("patch_embed.proj.weight", "model.embeddings.patch_embeddings.projection.weight"))
-    rename_keys.append(("patch_embed.proj.bias", "model.embeddings.patch_embeddings.projection.bias"))
-    rename_keys.append(("mask_token", "model.embeddings.mask_token"))
-    rename_keys.append(("segment_token_x", "model.embeddings.segment_token_input"))
-    rename_keys.append(("segment_token_y", "model.embeddings.segment_token_prompt"))
-    rename_keys.append(("type_token_cls", "model.embeddings.type_token_semantic"))
-    rename_keys.append(("type_token_ins", "model.embeddings.type_token_instance"))
-    rename_keys.append(("pos_embed", "model.embeddings.position_embeddings"))
+    rename_keys.append(("patch_embed.proj.weight", "embeddings.patch_embeddings.projection.weight"))
+    rename_keys.append(("patch_embed.proj.bias", "embeddings.patch_embeddings.projection.bias"))
+    rename_keys.append(("mask_token", "embeddings.mask_token"))
+    rename_keys.append(("segment_token_x", "embeddings.segment_token_input"))
+    rename_keys.append(("segment_token_y", "embeddings.segment_token_prompt"))
+    rename_keys.append(("type_token_cls", "embeddings.type_token_semantic"))
+    rename_keys.append(("type_token_ins", "embeddings.type_token_instance"))
+    rename_keys.append(("pos_embed", "embeddings.position_embeddings"))

     # rename decoder and other
-    rename_keys.append(("norm.weight", "model.encoder.layernorm.weight"))
-    rename_keys.append(("norm.bias", "model.encoder.layernorm.bias"))
+    rename_keys.append(("norm.weight", "encoder.layernorm.weight"))
+    rename_keys.append(("norm.bias", "encoder.layernorm.bias"))
     rename_keys.append(("decoder_embed.weight", "decoder.decoder_embed.weight"))
     rename_keys.append(("decoder_embed.bias", "decoder.decoder_embed.bias"))
     rename_keys.append(("decoder_pred.0.weight", "decoder.decoder_pred.conv.weight"))
@@ -59,22 +59,22 @@ def create_rename_keys(config):

     # rename blocks
     for i in range(config.num_hidden_layers):
-        rename_keys.append((f"blocks.{i}.attn.qkv.weight", f"model.encoder.layers.{i}.attention.qkv.weight"))
-        rename_keys.append((f"blocks.{i}.attn.qkv.bias", f"model.encoder.layers.{i}.attention.qkv.bias"))
-        rename_keys.append((f"blocks.{i}.attn.proj.weight", f"model.encoder.layers.{i}.attention.proj.weight"))
-        rename_keys.append((f"blocks.{i}.attn.proj.bias", f"model.encoder.layers.{i}.attention.proj.bias"))
-        rename_keys.append((f"blocks.{i}.attn.rel_pos_h", f"model.encoder.layers.{i}.attention.rel_pos_h"))
-        rename_keys.append((f"blocks.{i}.attn.rel_pos_w", f"model.encoder.layers.{i}.attention.rel_pos_w"))
-
-        rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"model.encoder.layers.{i}.mlp.fc1.weight"))
-        rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"model.encoder.layers.{i}.mlp.fc1.bias"))
-        rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"model.encoder.layers.{i}.mlp.fc2.weight"))
-        rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"model.encoder.layers.{i}.mlp.fc2.bias"))
-
-        rename_keys.append((f"blocks.{i}.norm1.weight", f"model.encoder.layers.{i}.layernorm_before.weight"))
-        rename_keys.append((f"blocks.{i}.norm1.bias", f"model.encoder.layers.{i}.layernorm_before.bias"))
-        rename_keys.append((f"blocks.{i}.norm2.weight", f"model.encoder.layers.{i}.layernorm_after.weight"))
-        rename_keys.append((f"blocks.{i}.norm2.bias", f"model.encoder.layers.{i}.layernorm_after.bias"))
+        rename_keys.append((f"blocks.{i}.attn.qkv.weight", f"encoder.layers.{i}.attention.qkv.weight"))
+        rename_keys.append((f"blocks.{i}.attn.qkv.bias", f"encoder.layers.{i}.attention.qkv.bias"))
+        rename_keys.append((f"blocks.{i}.attn.proj.weight", f"encoder.layers.{i}.attention.proj.weight"))
+        rename_keys.append((f"blocks.{i}.attn.proj.bias", f"encoder.layers.{i}.attention.proj.bias"))
+        rename_keys.append((f"blocks.{i}.attn.rel_pos_h", f"encoder.layers.{i}.attention.rel_pos_h"))
+        rename_keys.append((f"blocks.{i}.attn.rel_pos_w", f"encoder.layers.{i}.attention.rel_pos_w"))
+
+        rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"encoder.layers.{i}.mlp.fc1.weight"))
+        rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"encoder.layers.{i}.mlp.fc1.bias"))
+        rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"encoder.layers.{i}.mlp.fc2.weight"))
+        rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"encoder.layers.{i}.mlp.fc2.bias"))
+
+        rename_keys.append((f"blocks.{i}.norm1.weight", f"encoder.layers.{i}.layernorm_before.weight"))
+        rename_keys.append((f"blocks.{i}.norm1.bias", f"encoder.layers.{i}.layernorm_before.bias"))
+        rename_keys.append((f"blocks.{i}.norm2.weight", f"encoder.layers.{i}.layernorm_after.weight"))
+        rename_keys.append((f"blocks.{i}.norm2.bias", f"encoder.layers.{i}.layernorm_after.bias"))

     # fmt: on

@@ -126,7 +126,7 @@ def convert_seggpt_checkpoint(args):
         rename_key(new_state_dict, src, dest)

     # Load HF model
-    model = SegGPTForInstanceSegmentation(config)
+    model = SegGPTModel(config)
     model.eval()
     missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
     print("Missing keys:", missing_keys)
@@ -135,43 +135,71 @@ def convert_seggpt_checkpoint(args):
     input_img, prompt_img, prompt_mask = prepare_input()
     image_processor = SegGPTImageProcessor()
     inputs = image_processor(images=input_img, prompt_images=prompt_img, prompt_masks=prompt_mask, return_tensors="pt")

+    expected_prompt_pixel_values = torch.tensor(
+        [
+            [[-0.6965, -0.6965, -0.6965], [-0.6965, -0.6965, -0.6965], [-0.6965, -0.6965, -0.6965]],
+            [[1.6583, 1.6583, 1.6583], [1.6583, 1.6583, 1.6583], [1.6583, 1.6583, 1.6583]],
+            [[2.3088, 2.3088, 2.3088], [2.3088, 2.3088, 2.3088], [2.3088, 2.3088, 2.3088]],
+        ]
+    )
+
+    expected_pixel_values = torch.tensor(
+        [
+            [[1.6324, 1.6153, 1.5810], [1.6153, 1.5982, 1.5810], [1.5810, 1.5639, 1.5639]],
+            [[1.2731, 1.2556, 1.2206], [1.2556, 1.2381, 1.2031], [1.2206, 1.2031, 1.1681]],
+            [[1.6465, 1.6465, 1.6465], [1.6465, 1.6465, 1.6465], [1.6291, 1.6291, 1.6291]],
+        ]
+    )
+
+    expected_prompt_masks = torch.tensor(
+        [
+            [[-2.1179, -2.1179, -2.1179], [-2.1179, -2.1179, -2.1179], [-2.1179, -2.1179, -2.1179]],
+            [[-2.0357, -2.0357, -2.0357], [-2.0357, -2.0357, -2.0357], [-2.0357, -2.0357, -2.0357]],
+            [[-1.8044, -1.8044, -1.8044], [-1.8044, -1.8044, -1.8044], [-1.8044, -1.8044, -1.8044]],
+        ]
+    )
+
+    assert torch.allclose(inputs.pixel_values[0, :, :3, :3], expected_pixel_values, atol=1e-4)
+    assert torch.allclose(inputs.prompt_pixel_values[0, :, :3, :3], expected_prompt_pixel_values, atol=1e-4)
+    assert torch.allclose(inputs.prompt_masks[0, :, :3, :3], expected_prompt_masks, atol=1e-4)
+
     torch.manual_seed(2)
     inputs = {k: torch.ones_like(v) for k, v in inputs.items()}
     outputs = model(**inputs)
     print(outputs)

-    expected_dummy_output = torch.tensor(
+    expected_output = torch.tensor(
         [
-            [[0.9410, 1.0075, 0.9631], [0.9487, 1.0162, 0.9713], [0.9502, 1.0162, 0.9684]],
-            [[0.9338, 1.0081, 0.9691], [0.9428, 1.0179, 0.9773], [0.9429, 1.0172, 0.9722]],
-            [[0.9412, 1.0122, 0.9720], [0.9465, 1.0193, 0.9778], [0.9449, 1.0184, 0.9692]],
+            [[-2.1208, -2.1190, -2.1198], [-2.1237, -2.1228, -2.1227], [-2.1232, -2.1226, -2.1228]],
+            [[-2.0405, -2.0396, -2.0403], [-2.0434, -2.0434, -2.0433], [-2.0428, -2.0432, -2.0434]],
+            [[-1.8102, -1.8088, -1.8099], [-1.8131, -1.8126, -1.8129], [-1.8130, -1.8128, -1.8131]],
         ]
     )

[Review thread on the line `expected_output = torch.tensor(`:]
Collaborator: Is this true for all the checkpoints? Why not just load the original model and check the outputs there?
EduardoPach (author): There's only one model checkpoint. I've added the expected results here as a pre-test-writing step. Would you like me to remove these checks?

-    assert torch.allclose(outputs.pred_masks[0, :3, :3, :3], expected_dummy_output)
+    assert torch.allclose(outputs.pred_masks[0, :, :3, :3], expected_output, atol=1e-4)

     print("Looks good!")

     if pytorch_dump_folder_path is not None:
         print(f"Saving model and processor for {model_name} to {pytorch_dump_folder_path}")
         model.save_pretrained(pytorch_dump_folder_path)
-        # processor.save_pretrained(pytorch_dump_folder_path)
+        image_processor.save_pretrained(pytorch_dump_folder_path)

     if push_to_hub:
         print(f"Pushing model and processor for {model_name} to hub")
         model.push_to_hub(f"EduardoPacheco/{model_name}")
-        # processor.push_to_hub(f"EduardoPacheco/{model_name}")
+        image_processor.push_to_hub(f"EduardoPacheco/{model_name}")


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     # Required parameters
     parser.add_argument(
         "--model_name",
-        default="grounding-dino-tiny",
+        default="seggpt-vit-large",
         type=str,
-        choices=["grounding-dino-tiny", "grounding-dino-base"],
-        help="Name of the GroundingDINO model you'd like to convert.",
+        choices=["seggpt-vit-large"],
+        help="Name of the SegGPT model you'd like to convert.",
     )
     parser.add_argument(
         "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
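(Once the script has saved or pushed the weights, they can be reloaded through the same classes; a sketch assuming the hub repo `EduardoPacheco/seggpt-vit-large` created by the `push_to_hub` calls above and a transformers install from this branch:)

```python
from transformers import SegGPTImageProcessor, SegGPTModel

# Hedged sketch: reload the artifacts convert_seggpt_to_hf.py saves/pushes.
# The repo id mirrors the push_to_hub calls above; its existence at this
# point in the PR is an assumption.
model = SegGPTModel.from_pretrained("EduardoPacheco/seggpt-vit-large")
image_processor = SegGPTImageProcessor.from_pretrained("EduardoPacheco/seggpt-vit-large")
```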
12 changes: 6 additions & 6 deletions src/transformers/models/seggpt/image_processing_seggpt.py
@@ -76,7 +76,7 @@ def __init__(
         self,
         do_resize: bool = True,
         size: Optional[Dict[str, int]] = None,
-        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
         do_rescale: bool = True,
         rescale_factor: Union[int, float] = 1 / 255,
         do_normalize: bool = True,
@@ -96,12 +96,12 @@ def __init__(
         self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
         self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD

-    # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize
+    # Modifed from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize
     def resize(
         self,
         image: np.ndarray,
         size: Dict[str, int],
-        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
@@ -114,8 +114,8 @@ def resize(
                 Image to resize.
             size (`Dict[str, int]`):
                 Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
-            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
-                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
             data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format for the output image. If unset, the channel dimension format of the input
                 image is used. Can be one of:
@@ -380,5 +380,5 @@ def preprocess(
             **kwargs,
         )

-        data = {"pixel_values": images, "prompt_pixel_values": prompt_images, "prompt_mask": prompt_masks}
+        data = {"pixel_values": images, "prompt_pixel_values": prompt_images, "prompt_masks": prompt_masks}
         return BatchFeature(data=data, tensor_type=return_tensors)
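(The `prompt_masks` key fixed above matches the call pattern used in the conversion script; a minimal sketch of the round trip, with random arrays standing in for real images and a transformers install from this branch assumed:)

```python
import numpy as np
from transformers import SegGPTImageProcessor

# Placeholder inputs: a query image, a prompt image, and the prompt's mask.
image = np.random.randint(0, 256, (448, 448, 3), dtype=np.uint8)
prompt_image = np.random.randint(0, 256, (448, 448, 3), dtype=np.uint8)
prompt_mask = np.random.randint(0, 256, (448, 448, 3), dtype=np.uint8)

processor = SegGPTImageProcessor()  # defaults: BICUBIC resize + rescale + ImageNet normalize
inputs = processor(images=image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt")

# The three keys assembled at the end of preprocess:
print(sorted(inputs.keys()))  # ['pixel_values', 'prompt_masks', 'prompt_pixel_values']
```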