Status: Closed
Labels: bug (Something isn't working), stale (Issues that haven't received updates)
Description
Describe the bug
The type annotations of UNet2DModel.__init__ appear to be incorrect, at least for these three parameters:
diffusers/src/diffusers/models/unet_2d.py, lines 97 to 99 (commit c4d2823):
    down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
    block_out_channels: Tuple[int] = (224, 448, 672, 896),
I believe the correct annotations are Tuple[str, ...] and Tuple[int, ...], which mean "a tuple of any length whose elements are all str (or int)". The current annotation, Tuple[str], means "a tuple of exactly one element, and that element is a str" (and Tuple[int] likewise for block_out_channels). Even the default values above are four-element tuples, so passing any multi-element tuple to UNet2DModel produces a static type error (for example in Pylance).
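To make the difference between the two annotations concrete, here is a minimal sketch using the standard typing.get_args. The helper matches_tuple_annotation is hypothetical: it is a simplified stand-in for how static checkers read Tuple[X] versus Tuple[X, ...], not Pylance's actual algorithm, and it only handles plain classes such as str and int.

```python
from typing import Any, Tuple, get_args


def matches_tuple_annotation(value: Any, annotation: Any) -> bool:
    """Hypothetical, simplified check of a tuple against a Tuple[...] annotation.

    Handles plain element classes (str, int) only; not Pylance's algorithm.
    """
    if not isinstance(value, tuple):
        return False
    args = get_args(annotation)
    # Tuple[X, ...]: any length, every element must be an X
    if len(args) == 2 and args[1] is Ellipsis:
        return all(isinstance(item, args[0]) for item in value)
    # Tuple[X1, ..., Xn]: exactly n elements, element i must be an Xi
    if len(value) != len(args):
        return False
    return all(isinstance(item, expected) for item, expected in zip(value, args))


# The default values from unet_2d.py, lines 97 to 99
down_block_types = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")
block_out_channels = (224, 448, 672, 896)

print(get_args(Tuple[str]))       # (<class 'str'>,)
print(get_args(Tuple[str, ...]))  # (<class 'str'>, Ellipsis)
print(matches_tuple_annotation(down_block_types, Tuple[str]))         # False
print(matches_tuple_annotation(down_block_types, Tuple[str, ...]))    # True
print(matches_tuple_annotation(block_out_channels, Tuple[int]))       # False
print(matches_tuple_annotation(block_out_channels, Tuple[int, ...]))  # True
```

The trailing Ellipsis in get_args is what distinguishes the homogeneous, variable-length form from the fixed-length one-element form.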
Reproduction
A typical instantiation of UNet2DModel, such as the following, triggers the error:
from diffusers import UNet2DModel

# Snippet from a method; self.config and self.device come from the enclosing class
self.model = UNet2DModel(
    sample_size=self.config.sample_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
).to(self.device)
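As a sketch of the proposed fix, the three parameters can be re-annotated as Tuple[str, ...] and Tuple[int, ...]. The function unet2d_block_config below is hypothetical and is not the diffusers source; its length check is likewise an assumption added for illustration. Python does not enforce annotations at runtime, so the practical benefit is that a static checker such as Pylance should accept the six-element tuples from the reproduction.

```python
from typing import Tuple, get_args, get_type_hints


# Hypothetical stand-in for three parameters of UNet2DModel.__init__,
# re-annotated as proposed; this is not the diffusers implementation.
def unet2d_block_config(
    down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
    block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
) -> int:
    # Assumed invariant for this sketch: one entry per resolution level in each tuple
    if not len(down_block_types) == len(up_block_types) == len(block_out_channels):
        raise ValueError("down_block_types, up_block_types and block_out_channels must have the same length")
    # Return the number of resolution levels
    return len(block_out_channels)


# The six-level configuration from the reproduction matches the corrected annotations
levels = unet2d_block_config(
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
)
print(levels)  # 6

hints = get_type_hints(unet2d_block_config)
print(get_args(hints["down_block_types"]))  # (<class 'str'>, Ellipsis)
```

With Tuple[str, ...], a checker infers tuple[Literal['DownBlock2D'], ...] as assignable because every element is a str, regardless of how many elements there are.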
Logs
Argument of type "tuple[Literal['DownBlock2D'], Literal['DownBlock2D'], Literal['DownBlock2D'], Literal['DownBlock2D'], Literal['AttnDownBlock2D'], Literal['DownBlock2D']]" cannot be assigned to parameter "down_block_types" of type "Tuple[str]" in function "__init__"
"tuple[Literal['DownBlock2D'], Literal['DownBlock2D'], Literal['DownBlock2D'], Literal['DownBlock2D'], Literal['AttnDownBlock2D'], Literal['DownBlock2D']]" is incompatible with "Tuple[str]"
Element size mismatch; expected 1 but received 6 Pylance(reportGeneralTypeIssues)
System Info
- `diffusers` version: 0.20.0
- Platform: Linux-5.15.0-78-generic-x86_64-with-glibc2.35
- Python version: 3.10.11
- PyTorch version (GPU?): 2.0.1 (True)
- Huggingface_hub version: 0.15.1
- Transformers version: 4.29.2
- Accelerate version: 0.19.0
- xFormers version: not installed
Who can help?