
Incorrect Parameter Type on UNet2DModel #4806

@fpgaminer


Describe the bug

The type annotations on UNet2DModel.__init__ appear to be incorrect, at least for these three parameters:

down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels: Tuple[int] = (224, 448, 672, 896),

I believe the correct types are Tuple[str, ...] and Tuple[int, ...], which mean "a tuple of any length whose elements are all of type str (or int)". The current annotation, Tuple[str], means "a tuple of exactly one element, and that element is of type str". As a result, a static type checker such as Pylance reports an error whenever UNet2DModel is given a tuple with more than one element, as in the reproduction below.
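To make the distinction concrete, here is a small runtime sketch. `typing.get_args` shows how each annotation is parameterized: `Tuple[str]` carries a single fixed slot, while `Tuple[str, ...]` carries the element type followed by `Ellipsis`. The `conforms` helper is hypothetical and only approximates what a static checker like Pylance does; it is not part of diffusers or `typing`.

```python
from typing import Tuple, get_args

OneStr = Tuple[str]        # exactly one element, which must be a str
ManyStr = Tuple[str, ...]  # any number of elements, each a str


def conforms(value: tuple, annotation) -> bool:
    """Hypothetical, rough runtime check of a tuple against a Tuple[...] annotation."""
    args = get_args(annotation)
    if len(args) == 2 and args[1] is Ellipsis:
        # Variadic form Tuple[T, ...]: any length, every element must be a T.
        return all(isinstance(v, args[0]) for v in value)
    # Fixed form Tuple[T1, ..., Tn]: length must match, and each slot its own type.
    return len(value) == len(args) and all(isinstance(v, t) for v, t in zip(value, args))


# The default down_block_types quoted above: four strings.
blocks = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")
print(conforms(blocks, OneStr))   # False: 4 elements, but Tuple[str] allows exactly 1
print(conforms(blocks, ManyStr))  # True
```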

Reproduction

A typical instantiation of UNet2DModel, such as this one, triggers the error:

self.model = UNet2DModel(
    sample_size=self.config.sample_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
).to(self.device)
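For comparison, here is a minimal sketch of what the corrected annotations could look like, assuming only these three parameters change. `PatchedUNet2DSignature` is a hypothetical stand-in, not a diffusers class. It mirrors just these parameters, using the defaults quoted above, and builds no model. Called with the six-level configuration from the reproduction, a checker such as Pylance should accept the arguments, because each tuple literal matches `Tuple[str, ...]` or `Tuple[int, ...]` regardless of its length.

```python
from typing import Tuple, get_type_hints


class PatchedUNet2DSignature:
    """Hypothetical stand-in mirroring only the three UNet2DModel parameters in this
    issue, annotated as variadic tuples. It stores the arguments and builds no model."""

    def __init__(
        self,
        down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
        up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
    ) -> None:
        self.down_block_types = down_block_types
        self.up_block_types = up_block_types
        self.block_out_channels = block_out_channels


# Same six-level configuration as the reproduction.
cfg = PatchedUNet2DSignature(
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
    block_out_channels=(128, 128, 256, 256, 512, 512),
)

# The resolved annotations are variadic, so 6-element tuples fit them.
hints = get_type_hints(PatchedUNet2DSignature.__init__)
print(hints["down_block_types"] == Tuple[str, ...])  # True
```

On Python 3.9 and later, the builtin form `tuple[str, ...]` expresses the same variadic type.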

Logs

Argument of type "tuple[Literal['DownBlock2D'], Literal['DownBlock2D'], Literal['DownBlock2D'], Literal['DownBlock2D'], Literal['AttnDownBlock2D'], Literal['DownBlock2D']]" cannot be assigned to parameter "down_block_types" of type "Tuple[str]" in function "__init__"
  "tuple[Literal['DownBlock2D'], Literal['DownBlock2D'], Literal['DownBlock2D'], Literal['DownBlock2D'], Literal['AttnDownBlock2D'], Literal['DownBlock2D']]" is incompatible with "Tuple[str]"
    Element size mismatch; expected 1 but received 6 (Pylance: reportGeneralTypeIssues)

System Info

- `diffusers` version: 0.20.0
- Platform: Linux-5.15.0-78-generic-x86_64-with-glibc2.35
- Python version: 3.10.11
- PyTorch version (GPU?): 2.0.1 (True)
- Huggingface_hub version: 0.15.1
- Transformers version: 4.29.2
- Accelerate version: 0.19.0
- xFormers version: not installed

Who can help?

@patrickvonplaten @sayakpaul


Labels: bug (Something isn't working), stale (Issues that haven't received updates)
