From cd991d1e1a648cffe894405db02f34059d86809f Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 25 Dec 2024 15:37:49 +0530 Subject: [PATCH 1/5] Fix TorchAO related bugs; revert device_map changes (#10371) * Revert "Add support for sharded models when TorchAO quantization is enabled (#10256)" This reverts commit 41ba8c0bf6b3dc3ebd0fa6b96ecf671fa4171566. * update tests * udpate * update * update * update device map tests * apply review suggestions * update * make style * fix * update docs * update tests * update workflow * update * improve tests * allclose tolerance * Update src/diffusers/models/modeling_utils.py Co-authored-by: Sayak Paul * Update tests/quantization/torchao/test_torchao.py Co-authored-by: Sayak Paul * improve tests * fix * update correct slices --------- Co-authored-by: Sayak Paul --- .github/workflows/nightly_tests.yml | 2 + docs/source/en/quantization/torchao.md | 62 +++ src/diffusers/models/modeling_utils.py | 8 +- .../quantizers/torchao/torchao_quantizer.py | 2 +- tests/quantization/torchao/test_torchao.py | 401 ++++++++++++------ 5 files changed, 350 insertions(+), 125 deletions(-) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index cc0abac6e4ab..9375f760a151 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -359,6 +359,8 @@ jobs: test_location: "bnb" - backend: "gguf" test_location: "gguf" + - backend: "torchao" + test_location: "torchao" runs-on: group: aws-g6e-xlarge-plus container: diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 1f9f99a79a3b..c056876c2f09 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -25,6 +25,7 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] The example below only quantizes the weights to int8. ```python +import torch from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig model_id = "black-forest-labs/FLUX.1-dev" @@ -44,6 +45,10 @@ pipe = FluxPipeline.from_pretrained( ) pipe.to("cuda") +# Without quantization: ~31.447 GB +# With quantization: ~20.40 GB +print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB") + prompt = "A cat holding a sign that says hello world" image = pipe( prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512 @@ -88,6 +93,63 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available. +## Serializing and Deserializing quantized models + +To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method. + +```python +import torch +from diffusers import FluxTransformer2DModel, TorchAoConfig + +quantization_config = TorchAoConfig("int8wo") +transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/Flux.1-Dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, +) +transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False) +``` + +To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method. 
+ +```python +import torch +from diffusers import FluxPipeline, FluxTransformer2DModel + +transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False) +pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16) +pipe.to("cuda") + +prompt = "A cat holding a sign that says hello world" +image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0] +image.save("output.png") +``` + +Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source. + +```python +import torch +from accelerate import init_empty_weights +from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig + +# Serialize the model +transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/Flux.1-Dev", + subfolder="transformer", + quantization_config=TorchAoConfig("uint4wo"), + torch_dtype=torch.bfloat16, +) +transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB") +# ... + +# Load the model +state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu") +with init_empty_weights(): + transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json") +transformer.load_state_dict(state_dict, strict=True, assign=True) +``` + ## Resources - [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index d236ebb83983..d6efcc736487 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -718,10 +718,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P hf_quantizer = None if hf_quantizer is not None: - is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes" - if is_bnb_quantization_method and device_map is not None: + if device_map is not None: raise NotImplementedError( - "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future." + "Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future." 
) hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) @@ -820,7 +819,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P revision=revision, subfolder=subfolder or "", ) - if hf_quantizer is not None and is_bnb_quantization_method: + # TODO: https://github.com/huggingface/diffusers/issues/10013 + if hf_quantizer is not None: model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") is_sharded = False diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 5770e32c909e..a829234afd56 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -132,7 +132,7 @@ def validate_environment(self, *args, **kwargs): def update_torch_dtype(self, torch_dtype): quant_type = self.quantization_config.quant_type - if quant_type.startswith("int"): + if quant_type.startswith("int") or quant_type.startswith("uint"): if torch_dtype is not None and torch_dtype != torch.bfloat16: logger.warning( f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but " diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 0fa9182a3314..3c3f13db9b1c 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -131,8 +131,9 @@ def tearDown(self): gc.collect() torch.cuda.empty_cache() - def get_dummy_components(self, quantization_config: TorchAoConfig): - model_id = "hf-internal-testing/tiny-flux-pipe" + def get_dummy_components( + self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe" + ): transformer = FluxTransformer2DModel.from_pretrained( model_id, subfolder="transformer", @@ -211,8 +212,8 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0): "timestep": timestep, } - def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]): - components = self.get_dummy_components(quantization_config) + def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float], model_id: str): + components = self.get_dummy_components(quantization_config, model_id) pipe = FluxPipeline(**components) pipe.to(device=torch_device) @@ -223,44 +224,45 @@ def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: L self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_quantization(self): - # fmt: off - QUANTIZATION_TYPES_TO_TEST = [ - ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])), - ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])), - ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), - ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), - ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])), - ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), - ] - - if TorchAoConfig._is_cuda_capability_atleast_8_9(): - QUANTIZATION_TYPES_TO_TEST.extend([ - ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])), - ("float8wo_e4m3", 
np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])), - # ===== - # The following lead to an internal torch error: - # RuntimeError: mat2 shape (32x4 must be divisible by 16 - # Skip these for now; TODO(aryan): investigate later - # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - # ===== - # Cutlass fails to initialize for below - # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - # ===== - ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), - ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), - ]) - # fmt: on - - for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: - quant_kwargs = {} - if quantization_name in ["uint4wo", "uint7wo"]: - # The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here - quant_kwargs.update({"group_size": 16}) - quantization_config = TorchAoConfig( - quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs - ) - self._test_quant_type(quantization_config, expected_slice) + for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: + # fmt: off + QUANTIZATION_TYPES_TO_TEST = [ + ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])), + ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])), + ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])), + ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ] + + if TorchAoConfig._is_cuda_capability_atleast_8_9(): + QUANTIZATION_TYPES_TO_TEST.extend([ + ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])), + ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])), + # ===== + # The following lead to an internal torch error: + # RuntimeError: mat2 shape (32x4 must be divisible by 16 + # Skip these for now; TODO(aryan): investigate later + # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ===== + # Cutlass fails to initialize for below + # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ===== + ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ]) + # fmt: on + + for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: + quant_kwargs = {} + if quantization_name in ["uint4wo", "uint7wo"]: + # The dummy flux model that we use has smaller dimensions. 
This imposes some restrictions on group_size here + quant_kwargs.update({"group_size": 16}) + quantization_config = TorchAoConfig( + quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs + ) + self._test_quant_type(quantization_config, expected_slice, model_id) def test_int4wo_quant_bfloat16_conversion(self): """ @@ -280,12 +282,14 @@ def test_int4wo_quant_bfloat16_conversion(self): self.assertEqual(weight.quant_max, 15) def test_device_map(self): + # Note: We were not checking if the weight tensor's were AffineQuantizedTensor's before. If we did + # it would have errored out. Now, we do. So, device_map basically never worked with or without + # sharded checkpoints. This will need to be supported in the future (TODO(aryan)) """ Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps. The custom device map performs cpu/disk offloading as well. Also verifies that the device map is correctly set (in the `hf_device_map` attribute of the model). """ - custom_device_map_dict = { "time_text_embed": torch_device, "context_embedder": torch_device, @@ -297,48 +301,54 @@ def test_device_map(self): } device_maps = ["auto", custom_device_map_dict] - inputs = self.get_dummy_tensor_inputs(torch_device) - expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375]) + # inputs = self.get_dummy_tensor_inputs(torch_device) + # expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375]) for device_map in device_maps: - device_map_to_compare = {"": 0} if device_map == "auto" else device_map - - # Test non-sharded model - with tempfile.TemporaryDirectory() as offload_folder: - quantization_config = TorchAoConfig("int4_weight_only", group_size=64) - quantized_model = FluxTransformer2DModel.from_pretrained( - "hf-internal-testing/tiny-flux-pipe", - subfolder="transformer", - quantization_config=quantization_config, - device_map=device_map, - torch_dtype=torch.bfloat16, - offload_folder=offload_folder, - ) - - self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) - - output = quantized_model(**inputs)[0] - output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) - - # Test sharded model - with tempfile.TemporaryDirectory() as offload_folder: - quantization_config = TorchAoConfig("int4_weight_only", group_size=64) - quantized_model = FluxTransformer2DModel.from_pretrained( - "hf-internal-testing/tiny-flux-sharded", - subfolder="transformer", - quantization_config=quantization_config, - device_map=device_map, - torch_dtype=torch.bfloat16, - offload_folder=offload_folder, - ) - - self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) - - output = quantized_model(**inputs)[0] - output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - - self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + # device_map_to_compare = {"": 0} if device_map == "auto" else device_map + + # Test non-sharded model - should work + with self.assertRaises(NotImplementedError): + with tempfile.TemporaryDirectory() as offload_folder: + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + _ = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + device_map=device_map, + torch_dtype=torch.bfloat16, + 
offload_folder=offload_folder, + ) + + # weight = quantized_model.transformer_blocks[0].ff.net[2].weight + # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) + # self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + + # output = quantized_model(**inputs)[0] + # output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + # Test sharded model - should not work + with self.assertRaises(NotImplementedError): + with tempfile.TemporaryDirectory() as offload_folder: + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + _ = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-sharded", + subfolder="transformer", + quantization_config=quantization_config, + device_map=device_map, + torch_dtype=torch.bfloat16, + offload_folder=offload_folder, + ) + + # weight = quantized_model.transformer_blocks[0].ff.net[2].weight + # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) + # self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + + # output = quantized_model(**inputs)[0] + # output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + + # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_modules_to_not_convert(self): quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"]) @@ -404,43 +414,63 @@ def test_training(self): @nightly def test_torch_compile(self): r"""Test that verifies if torch.compile works with torchao quantization.""" - quantization_config = TorchAoConfig("int8_weight_only") - components = self.get_dummy_components(quantization_config) - pipe = FluxPipeline(**components) - pipe.to(device=torch_device, dtype=torch.bfloat16) + for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: + quantization_config = TorchAoConfig("int8_weight_only") + components = self.get_dummy_components(quantization_config, model_id=model_id) + pipe = FluxPipeline(**components) + pipe.to(device=torch_device) - inputs = self.get_dummy_inputs(torch_device) - normal_output = pipe(**inputs)[0].flatten()[-32:] + inputs = self.get_dummy_inputs(torch_device) + normal_output = pipe(**inputs)[0].flatten()[-32:] - pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False) - inputs = self.get_dummy_inputs(torch_device) - compile_output = pipe(**inputs)[0].flatten()[-32:] + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False) + inputs = self.get_dummy_inputs(torch_device) + compile_output = pipe(**inputs)[0].flatten()[-32:] - # Note: Seems to require higher tolerance - self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3)) + # Note: Seems to require higher tolerance + self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3)) def test_memory_footprint(self): r""" A simple test to check if the model conversion has been done correctly by checking on the memory footprint of the converted model and the class type of the linear layers of the converted models """ - transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"] - transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"] - transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"] - transformer_bf16 = 
self.get_dummy_components(None)["transformer"] - - total_int4wo = get_model_size_in_bytes(transformer_int4wo) - total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32) - total_int8wo = get_model_size_in_bytes(transformer_int8wo) - total_bf16 = get_model_size_in_bytes(transformer_bf16) - - # Latter has smaller group size, so more groups -> more scales and zero points - self.assertTrue(total_int4wo < total_int4wo_gs32) - # int8 quantizes more layers compare to int4 with default group size - self.assertTrue(total_int8wo < total_int4wo) - # int4wo does not quantize too many layers because of default group size, but for the layers it does - # there is additional overhead of scales and zero points - self.assertTrue(total_bf16 < total_int4wo) + for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: + transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"] + transformer_int4wo_gs32 = self.get_dummy_components( + TorchAoConfig("int4wo", group_size=32), model_id=model_id + )["transformer"] + transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"] + transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"] + + # Will not quantized all the layers by default due to the model weights shapes not being divisible by group_size=64 + for block in transformer_int4wo.transformer_blocks: + self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor)) + self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor)) + + # Will quantize all the linear layers except x_embedder + for name, module in transformer_int4wo_gs32.named_modules(): + if isinstance(module, nn.Linear) and name not in ["x_embedder"]: + self.assertTrue(isinstance(module.weight, AffineQuantizedTensor)) + + # Will quantize all the linear layers + for module in transformer_int8wo.modules(): + if isinstance(module, nn.Linear): + self.assertTrue(isinstance(module.weight, AffineQuantizedTensor)) + + total_int4wo = get_model_size_in_bytes(transformer_int4wo) + total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32) + total_int8wo = get_model_size_in_bytes(transformer_int8wo) + total_bf16 = get_model_size_in_bytes(transformer_bf16) + + # TODO: refactor to align with other quantization tests + # Latter has smaller group size, so more groups -> more scales and zero points + self.assertTrue(total_int4wo < total_int4wo_gs32) + # int8 quantizes more layers compare to int4 with default group size + self.assertTrue(total_int8wo < total_int4wo) + # int4wo does not quantize too many layers because of default group size, but for the layers it does + # there is additional overhead of scales and zero points + self.assertTrue(total_bf16 < total_int4wo) def test_wrong_config(self): with self.assertRaises(ValueError): @@ -500,6 +530,8 @@ def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, inputs = self.get_dummy_tensor_inputs(torch_device) output = quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + weight = quantized_model.transformer_blocks[0].ff.net[2].weight + self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))) self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device): @@ -508,8 +540,8 @@ def 
_check_serialization_expected_slice(self, quant_method, quant_method_kwargs, with tempfile.TemporaryDirectory() as tmp_dir: quantized_model.save_pretrained(tmp_dir, safe_serialization=False) loaded_quantized_model = FluxTransformer2DModel.from_pretrained( - tmp_dir, torch_dtype=torch.bfloat16, device_map=torch_device, use_safetensors=False - ) + tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False + ).to(device=torch_device) inputs = self.get_dummy_tensor_inputs(torch_device) output = loaded_quantized_model(**inputs)[0] @@ -563,20 +595,25 @@ def tearDown(self): torch.cuda.empty_cache() def get_dummy_components(self, quantization_config: TorchAoConfig): + # This is just for convenience, so that we can modify it at one place for custom environments and locally testing + cache_dir = None model_id = "black-forest-labs/FLUX.1-dev" transformer = FluxTransformer2DModel.from_pretrained( model_id, subfolder="transformer", quantization_config=quantization_config, torch_dtype=torch.bfloat16, + cache_dir=cache_dir, + ) + text_encoder = CLIPTextModel.from_pretrained( + model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir ) - text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16) text_encoder_2 = T5EncoderModel.from_pretrained( - model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16 + model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir ) - tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") - tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") - vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16) + tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir) + tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir) + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir) scheduler = FlowMatchEulerDiscreteScheduler() return { @@ -611,10 +648,12 @@ def _test_quant_type(self, quantization_config, expected_slice): pipe = FluxPipeline(**components) pipe.enable_model_cpu_offload() + weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight + self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))) + inputs = self.get_dummy_inputs(torch_device) output = pipe(**inputs)[0].flatten() output_slice = np.concatenate((output[:16], output[-16:])) - self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_quantization(self): @@ -627,7 +666,7 @@ def test_quantization(self): if TorchAoConfig._is_cuda_capability_atleast_8_9(): QUANTIZATION_TYPES_TO_TEST.extend([ ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])), - ("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])), + ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 
0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])), ]) # fmt: on @@ -637,3 +676,125 @@ def test_quantization(self): gc.collect() torch.cuda.empty_cache() torch.cuda.synchronize() + + def test_serialization_int8wo(self): + quantization_config = TorchAoConfig("int8wo") + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components) + pipe.enable_model_cpu_offload() + + weight = pipe.transformer.x_embedder.weight + self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0].flatten()[:128] + + with tempfile.TemporaryDirectory() as tmp_dir: + pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False) + pipe.remove_all_hooks() + del pipe.transformer + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + transformer = FluxTransformer2DModel.from_pretrained( + tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False + ) + pipe.transformer = transformer + pipe.enable_model_cpu_offload() + + weight = transformer.x_embedder.weight + self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + + loaded_output = pipe(**inputs)[0].flatten()[:128] + # Seems to require higher tolerance depending on which machine it is being run. + # A difference of 0.06 in normalized pixel space (-1 to 1), corresponds to a difference of + # 0.06 / 2 * 255 = 7.65 in pixel space (0 to 255). On our CI runners, the difference is about 0.04, + # on DGX it is 0.06, and on audace it is 0.037. So, we are using a tolerance of 0.06 here. + self.assertTrue(np.allclose(output, loaded_output, atol=0.06)) + + def test_memory_footprint_int4wo(self): + # The original checkpoints are in bf16 and about 24 GB + expected_memory_in_gb = 6.0 + quantization_config = TorchAoConfig("int4wo") + cache_dir = None + transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + cache_dir=cache_dir, + ) + int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3 + self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb) + + def test_memory_footprint_int8wo(self): + # The original checkpoints are in bf16 and about 24 GB + expected_memory_in_gb = 12.0 + quantization_config = TorchAoConfig("int8wo") + cache_dir = None + transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + cache_dir=cache_dir, + ) + int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3 + self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb) + + +@require_torch +@require_torch_gpu +@require_torchao_version_greater_or_equal("0.7.0") +@slow +@nightly +class SlowTorchAoPreserializedModelTests(unittest.TestCase): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_inputs(self, device: torch.device, seed: int = 0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator().manual_seed(seed) + + inputs = { + "prompt": "an astronaut riding a horse in space", + "height": 512, + "width": 512, + "num_inference_steps": 20, + "output_type": "np", + "generator": generator, + } + + return inputs + + def test_transformer_int8wo(self): + # fmt: off + expected_slice = np.array([0.0566, 0.0781, 0.1426, 0.0488, 0.0684, 0.1504, 
0.0625, 0.0781, 0.1445, 0.0625, 0.0781, 0.1562, 0.0547, 0.0723, 0.1484, 0.0566, 0.5703, 0.8867, 0.7266, 0.5742, 0.875, 0.7148, 0.5586, 0.875, 0.7148, 0.5547, 0.8633, 0.7109, 0.5469, 0.8398, 0.6992, 0.5703]) + # fmt: on + + # This is just for convenience, so that we can modify it at one place for custom environments and locally testing + cache_dir = None + transformer = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer", + torch_dtype=torch.bfloat16, + use_safetensors=False, + cache_dir=cache_dir, + ) + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir + ) + pipe.enable_model_cpu_offload() + + # Verify that all linear layer weights are quantized + for name, module in pipe.transformer.named_modules(): + if isinstance(module, nn.Linear): + self.assertTrue(isinstance(module.weight, AffineQuantizedTensor)) + + # Verify outputs match expected slice + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0].flatten() + output_slice = np.concatenate((output[:16], output[-16:])) + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) From 1b202c5730631417000585e3639539cefc79cbd7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 25 Dec 2024 17:27:16 +0530 Subject: [PATCH 2/5] [LoRA] feat: support `unload_lora_weights()` for Flux Control. (#10206) * feat: support unload_lora_weights() for Flux Control. * tighten test * minor * updates * meta device fixes. --- src/diffusers/loaders/lora_pipeline.py | 56 ++++++++++++++++++++++ tests/lora/test_lora_layers_flux.py | 66 ++++++++++++++++++++++++++ 2 files changed, 122 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index e69681611a4a..351295e938ff 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2286,6 +2286,50 @@ def unload_lora_weights(self): transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) transformer._transformer_norm_layers = None + if getattr(transformer, "_overwritten_params", None) is not None: + overwritten_params = transformer._overwritten_params + module_names = set() + + for param_name in overwritten_params: + if param_name.endswith(".weight"): + module_names.add(param_name.replace(".weight", "")) + + for name, module in transformer.named_modules(): + if isinstance(module, torch.nn.Linear) and name in module_names: + module_weight = module.weight.data + module_bias = module.bias.data if module.bias is not None else None + bias = module_bias is not None + + parent_module_name, _, current_module_name = name.rpartition(".") + parent_module = transformer.get_submodule(parent_module_name) + + current_param_weight = overwritten_params[f"{name}.weight"] + in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0] + with torch.device("meta"): + original_module = torch.nn.Linear( + in_features, + out_features, + bias=bias, + dtype=module_weight.dtype, + ) + + tmp_state_dict = {"weight": current_param_weight} + if module_bias is not None: + tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]}) + original_module.load_state_dict(tmp_state_dict, assign=True, strict=True) + setattr(parent_module, current_module_name, original_module) + + del tmp_state_dict + + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: + attribute_name = 
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] + new_value = int(current_param_weight.shape[1]) + old_value = getattr(transformer.config, attribute_name) + setattr(transformer.config, attribute_name, new_value) + logger.info( + f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." + ) + @classmethod def _maybe_expand_transformer_param_shape_or_error_( cls, @@ -2312,6 +2356,8 @@ def _maybe_expand_transformer_param_shape_or_error_( # Expand transformer parameter shapes if they don't match lora has_param_with_shape_update = False + overwritten_params = {} + is_peft_loaded = getattr(transformer, "peft_config", None) is not None for name, module in transformer.named_modules(): if isinstance(module, torch.nn.Linear): @@ -2386,6 +2432,16 @@ def _maybe_expand_transformer_param_shape_or_error_( f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." ) + # For `unload_lora_weights()`. + # TODO: this could lead to more memory overhead if the number of overwritten params + # are large. Should be revisited later and tackled through a `discard_original_layers` arg. + overwritten_params[f"{current_module_name}.weight"] = module_weight + if module_bias is not None: + overwritten_params[f"{current_module_name}.bias"] = module_bias + + if len(overwritten_params) > 0: + transformer._overwritten_params = overwritten_params + return has_param_with_shape_update @classmethod diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index b22fbaaed69b..0861160de6aa 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -558,6 +558,72 @@ def test_load_regular_lora(self): self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3)) + def test_lora_unload_with_parameter_expanded_shapes(self): + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + # Change the transformer config to mimic a real use case. + num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + self.assertTrue( + transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + ) + + # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. 
+ components["transformer"] = transformer + pipe = FluxPipeline(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + control_image = inputs.pop("control_image") + original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + control_pipe = self.pipeline_class(**components) + out_features, in_features = control_pipe.transformer.x_embedder.weight.shape + rank = 4 + + dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, + } + with CaptureLogger(logger) as cap_logger: + control_pipe.load_lora_weights(lora_state_dict, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + inputs["control_image"] = control_image + lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) + self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) + self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) + + control_pipe.unload_lora_weights() + self.assertTrue( + control_pipe.transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", + ) + loaded_pipe = FluxPipeline.from_pipe(control_pipe) + self.assertTrue( + loaded_pipe.transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}", + ) + inputs.pop("control_image") + unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4)) + self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4)) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) + self.assertTrue(pipe.transformer.config.in_channels == in_features) + @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass From f430a0cf32dda04a297d2e32518caffd82f300dd Mon Sep 17 00:00:00 2001 From: Alan Ponnachan <85491837+AlanPonnachan@users.noreply.github.com> Date: Fri, 27 Dec 2024 13:23:04 +0530 Subject: [PATCH 3/5] Add torch_xla support to pipeline_aura_flow.py (#10365) * Add torch_xla support to pipeline_aura_flow.py * make style --------- Co-authored-by: hlky --- .../pipelines/aura_flow/pipeline_aura_flow.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 8737b219c833..0bb3fb7368d8 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -21,11 +21,18 @@ from ...models import AuraFlowTransformer2DModel, AutoencoderKL from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import 
FlowMatchEulerDiscreteScheduler -from ...utils import logging, replace_example_docstring +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -564,6 +571,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if output_type == "latent": image = latents else: From 83da817f73776aa61086683f55432bf2915d0748 Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Fri, 27 Dec 2024 14:03:11 +0530 Subject: [PATCH 4/5] [Add] torch_xla support to pipeline_sana.py (#10364) [Add] torch_xla support in pipeline_sana.py --- src/diffusers/pipelines/sana/pipeline_sana.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index fe3c9e13aa31..c90dec4d41b3 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -31,6 +31,7 @@ USE_PEFT_BACKEND, is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -46,6 +47,13 @@ from .pipeline_output import SanaPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_bs4_available(): @@ -864,6 +872,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if output_type == "latent": image = latents else: From 55ac1dbdf2e77dcc93b0fa87d638d074219922e4 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 27 Dec 2024 17:58:49 +0000 Subject: [PATCH 5/5] Default values in SD3 pipelines when submodules are not loaded (#10393) SD3 pipelines hasattr --- .../pipeline_stable_diffusion_3_img2img.py | 17 +++++++++++++---- .../pipeline_stable_diffusion_3_inpaint.py | 19 ++++++++++++++----- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index c10401324430..77daf5b0b4e0 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -226,12 +226,21 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16 self.image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels + vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels + ) + self.tokenizer_max_length = ( + 
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 ) - self.tokenizer_max_length = self.tokenizer.model_max_length - self.default_sample_size = self.transformer.config.sample_size self.patch_size = ( self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 ) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index ca32880d0df2..e1cfdb3e6e97 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -225,19 +225,28 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16 self.image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels + vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels ) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, - vae_latent_channels=self.vae.config.latent_channels, + vae_latent_channels=latent_channels, do_normalize=False, do_binarize=True, do_convert_grayscale=True, ) - self.tokenizer_max_length = self.tokenizer.model_max_length - self.default_sample_size = self.transformer.config.sample_size + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) self.patch_size = ( self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 )
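
For reference, here is a minimal, self-contained sketch of the guarded-default pattern that PATCH 5/5 applies to the SD3 pipelines: attributes derived from optional submodules fall back to fixed defaults when the corresponding submodule was not loaded. This is not the actual diffusers source; the `PipelineLike` and `DummyVAE` names are hypothetical and only the pattern (and the 8 / 16 / 77 defaults taken from the diff above) mirrors the change.

```python
# Illustrative sketch of the guarded-default pattern from PATCH 5/5.
# Class names here are hypothetical; only the pattern mirrors the diff.


class DummyVAEConfig:
    block_out_channels = (128, 256, 512)  # 3 levels -> scale factor 2 ** 2 = 4
    latent_channels = 16


class DummyVAE:
    config = DummyVAEConfig()


class PipelineLike:
    def __init__(self, vae=None, tokenizer=None):
        self.vae = vae
        self.tokenizer = tokenizer

        # Only read a submodule's config if the submodule was actually provided;
        # otherwise fall back to a fixed default, as in the SD3 pipeline __init__.
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1)
            if hasattr(self, "vae") and self.vae is not None
            else 8
        )
        self.latent_channels = (
            self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16
        )
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )


# Without submodules the defaults are used; with a VAE they are derived from its config.
assert PipelineLike().vae_scale_factor == 8
assert PipelineLike(vae=DummyVAE()).vae_scale_factor == 4
assert PipelineLike().tokenizer_max_length == 77
```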