From f8100600068e6d6e656a5a9c4b6a4d4dfff4928a Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 21 Sep 2022 11:17:15 +0200 Subject: [PATCH] Fix flax from_pretrained pytorch weight check (#603) --- src/diffusers/modeling_flax_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index ed62b5fe57..e06e7fb7e6 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -307,7 +307,7 @@ def from_pretrained( # Load model if os.path.isdir(pretrained_model_name_or_path): if from_pt: - if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): + if not os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): raise EnvironmentError( f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " ) @@ -315,8 +315,8 @@ def from_pretrained( elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): # Load from a Flax checkpoint model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) - # At this stage we don't have a weight file so we will raise an error. - elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): + # Check if pytorch weights exist instead + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): raise EnvironmentError( f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model" " using `from_pt=True`."