[tests] Changes to the torch.compile() CI and tests #11508

Open · wants to merge 10 commits into base: main
2 changes: 1 addition & 1 deletion .github/workflows/benchmark.yml
@@ -23,7 +23,7 @@ jobs:
     runs-on:
       group: aws-g6-4xlarge-plus
     container:
-      image: diffusers/diffusers-pytorch-compile-cuda
+      image: diffusers/diffusers-pytorch-cuda
       options: --shm-size "16gb" --ipc host --gpus 0
     steps:
       - name: Checkout diffusers
8 changes: 7 additions & 1 deletion .github/workflows/build_docker_images.yml
@@ -41,6 +41,12 @@ jobs:
         run: |
           CHANGED_FILES="${{ steps.file_changes.outputs.all }}"
           for FILE in $CHANGED_FILES; do
+            # skip anything that isn't still on disk
+            if [[ ! -f "$FILE" ]]; then
+              echo "Skipping removed file $FILE"
+              continue
+            fi
+
             if [[ "$FILE" == docker/*Dockerfile ]]; then
               DOCKER_PATH="${FILE%/Dockerfile}"
               DOCKER_TAG=$(basename "$DOCKER_PATH")
@@ -65,7 +71,7 @@ jobs:
       image-name:
         - diffusers-pytorch-cpu
         - diffusers-pytorch-cuda
-        - diffusers-pytorch-compile-cuda
+        - diffusers-pytorch-cuda
         - diffusers-pytorch-xformers-cuda
         - diffusers-pytorch-minimum-cuda
         - diffusers-flax-cpu
2 changes: 1 addition & 1 deletion .github/workflows/nightly_tests.yml
@@ -187,7 +187,7 @@ jobs:
       group: aws-g4dn-2xlarge

     container:
-      image: diffusers/diffusers-pytorch-compile-cuda
+      image: diffusers/diffusers-pytorch-cuda
       options: --gpus 0 --shm-size "16gb" --ipc host

     steps:
2 changes: 1 addition & 1 deletion .github/workflows/push_tests.yml
@@ -262,7 +262,7 @@ jobs:
       group: aws-g4dn-2xlarge

     container:
-      image: diffusers/diffusers-pytorch-compile-cuda
+      image: diffusers/diffusers-pytorch-cuda
       options: --gpus 0 --shm-size "16gb" --ipc host

     steps:
2 changes: 1 addition & 1 deletion .github/workflows/release_tests_fast.yml
@@ -316,7 +316,7 @@ jobs:
       group: aws-g4dn-2xlarge

     container:
-      image: diffusers/diffusers-pytorch-compile-cuda
+      image: diffusers/diffusers-pytorch-cuda
       options: --gpus 0 --shm-size "16gb" --ipc host

     steps:
50 changes: 0 additions & 50 deletions docker/diffusers-pytorch-compile-cuda/Dockerfile

This file was deleted.

20 changes: 12 additions & 8 deletions tests/models/test_modeling_common.py
@@ -1721,14 +1721,14 @@ class TorchCompileTesterMixin:
     def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
-        torch._dynamo.reset()
+        torch.compiler.reset()
         gc.collect()
         backend_empty_cache(torch_device)

     def tearDown(self):
         # clean up the VRAM after each test in case of CUDA runtime errors
         super().tearDown()
-        torch._dynamo.reset()
+        torch.compiler.reset()
         gc.collect()
         backend_empty_cache(torch_device)

@@ -1737,13 +1737,17 @@ def tearDown(self):
     @is_torch_compile
     @slow
     def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
+        torch.compiler.reset()
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

         model = self.model_class(**init_dict).to(torch_device)
         model = torch.compile(model, fullgraph=True)

-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch.no_grad(),
+        ):
             _ = model(**inputs_dict)
             _ = model(**inputs_dict)
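Note on the updated test body: torch.compiler.reset() is the public counterpart of the private torch._dynamo.reset(), and the parenthesized with-statement is officially Python 3.10+ syntax. Below is a minimal sketch of the same compile-once-then-assert-no-recompile pattern, assuming a CUDA build of PyTorch 2.x; the toy nn.Linear model is illustrative, not from this PR:

    import torch
    from torch._dynamo import config as dynamo_config
    from torch._inductor.utils import fresh_inductor_cache

    torch.compiler.reset()  # clear all torch.compile state, like the mixin's setUp/tearDown

    model = torch.nn.Linear(8, 8).cuda()
    compiled = torch.compile(model, fullgraph=True)
    x = torch.randn(2, 8, device="cuda")

    with (
        fresh_inductor_cache(),                        # compile into a throwaway cache directory
        dynamo_config.patch(error_on_recompile=True),  # turn silent recompiles into hard failures
        torch.no_grad(),
    ):
        _ = compiled(x)  # first call compiles
        _ = compiled(x)  # second call must reuse the compiled graph, or dynamo raises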

@@ -1771,7 +1775,7 @@ def tearDown(self):
         # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
         # there will be recompilation errors, as torch caches the model when run in the same process.
         super().tearDown()
-        torch._dynamo.reset()
+        torch.compiler.reset()
         gc.collect()
         backend_empty_cache(torch_device)

@@ -1905,21 +1909,21 @@ def test_hotswapping_model(self, rank0, rank1):
     def test_hotswapping_compiled_model_linear(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
         target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
-        with torch._dynamo.config.patch(error_on_recompile=True):
+        with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)

     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
         target_modules = ["conv", "conv1", "conv2"]
-        with torch._dynamo.config.patch(error_on_recompile=True):
+        with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)

     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
         target_modules = ["to_q", "conv"]
-        with torch._dynamo.config.patch(error_on_recompile=True):
+        with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)

     def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
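Note: the hotswapping tests gain fresh_inductor_cache() for a similar reason: compiled artifacts cached on disk by one (rank0, rank1) parameterization could otherwise be picked up by the next and mask a real recompile. In current PyTorch releases the helper appears to work by pointing TORCHINDUCTOR_CACHE_DIR at a temporary directory for the duration of the block (an implementation detail that may change); a quick way to observe it:

    import os
    from torch._inductor.utils import fresh_inductor_cache

    print(os.environ.get("TORCHINDUCTOR_CACHE_DIR"))  # typically unset: the default on-disk cache
    with fresh_inductor_cache():
        # inside the block, inductor compiles into a private temporary directory
        print(os.environ.get("TORCHINDUCTOR_CACHE_DIR"))
    # the directory is removed on exit, so nothing leaks between test cases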
76 changes: 7 additions & 69 deletions tests/models/transformers/test_models_transformer_hunyuan_video.py
@@ -19,20 +19,16 @@
 from diffusers import HunyuanVideoTransformer3DModel
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
-    is_torch_compile,
-    require_torch_2,
-    require_torch_gpu,
-    slow,
     torch_device,
 )

-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


 enable_full_determinism()


-class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
     model_class = HunyuanVideoTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
@@ -96,23 +92,8 @@ def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HunyuanVideoTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
-
-
-class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
     model_class = HunyuanVideoTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
@@ -179,23 +160,8 @@ def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HunyuanVideoTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
-
-
-class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
     model_class = HunyuanVideoTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
@@ -260,23 +226,10 @@ def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HunyuanVideoTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
-
-
-class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
+    ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase
+):
     model_class = HunyuanVideoTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
@@ -342,18 +295,3 @@ def test_output(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HunyuanVideoTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
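Note: the four deleted methods above were identical copies of test_torch_compile_recompilation_and_graph_break, which now lives once in TorchCompileTesterMixin (see tests/models/test_modeling_common.py in this diff). A toy sketch of the consolidation pattern, using stand-in classes rather than diffusers code:

    import unittest

    class CompileTesterMixin:
        # shared tests live here once; concrete TestCases only supply model_class
        model_class = None

        def test_model_class_is_set(self):
            self.assertIsNotNone(self.model_class)

    class DummyModel:
        pass

    # base order mirrors the diffusers classes: mixins first, unittest.TestCase last
    class DummyModelTests(CompileTesterMixin, unittest.TestCase):
        model_class = DummyModel

    if __name__ == "__main__":
        unittest.main()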
23 changes: 2 additions & 21 deletions tests/models/transformers/test_models_transformer_wan.py
@@ -19,20 +19,16 @@
 from diffusers import WanTransformer3DModel
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
-    is_torch_compile,
-    require_torch_2,
-    require_torch_gpu,
-    slow,
     torch_device,
 )

-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


 enable_full_determinism()


-class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
     model_class = WanTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
@@ -86,18 +82,3 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"WanTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)