diff --git a/tests/pytorch/test_pytorch_autolog.py b/tests/pytorch/test_pytorch_autolog.py index 5a71aa2c8aa74..40e90553f7e67 100644 --- a/tests/pytorch/test_pytorch_autolog.py +++ b/tests/pytorch/test_pytorch_autolog.py @@ -48,7 +48,7 @@ def pytorch_model_without_validation(): @pytest.mark.parametrize("log_models", [True, False]) def test_pytorch_autolog_log_models_configuration(log_models): mlflow.pytorch.autolog(log_models=log_models) - model = IrisClassificationWithoutValidation() + model = IrisClassification() dm = IrisDataModule() dm.prepare_data() dm.setup(stage="fit")