Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Sep 22, 2024
1 parent 73590ad commit ae095c5
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions utils/modular_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,7 @@ def SUPER_CALL_NODE(func_name):

def merge_docstrings(original_docstring, updated_docstring):
if " Args:\n " not in updated_docstring:
logger.warning("We detected a docstring that will be appended to the super's doc")
# Split the docstring at the example section, assuming `"""` or `'''` is used to define the docstring
# Split the docstring at the example section, assuming `"""` is used to define the docstring
parts = original_docstring.split("```")
if "```" in updated_docstring and len(parts) > 0:
# an example is provide! Overwrite the other example
Expand All @@ -263,19 +262,21 @@ def merge_docstrings(original_docstring, updated_docstring):
updated_docstring = "".join(split_updated_docstring[:1] + split_updated_docstring[2:])

if len(parts) > 1:
doc = re.sub(r"\n\s{3}\s", "\n ", updated_docstring, count=1)
doc = doc.lstrip('r\"').replace('"""', "")
# tabulation when tabulation is missing
if False:
updated_docstring = re.sub(r"\n\s{4}", "\n ", updated_docstring, count=1)
doc = updated_docstring.lstrip('r\"').replace('"""', "")
updated_docstring = "".join(
[
parts[0].rstrip(" \n") + doc,
parts[0].rstrip(" \n") + doc,
"\n ```",
parts[1],
"```",
parts[2],
]
)
elif updated_docstring not in original_docstring:
# add tabulation:
# add tabulation if we are at the lowest level.
updated_docstring = original_docstring.rstrip('\"')+ "\n" + updated_docstring.lstrip('r\"')
return updated_docstring

Expand Down Expand Up @@ -720,7 +721,7 @@ def save_modeling_file(diff_file, converted_file):
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
)
if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0:
logger.warning("The modeling code contains erros, it's written without formatting")
logger.warning("The modeling code contains errors, it's written without formatting")
with open(diff_file.replace("modular_", f"{file_type}_"), "w") as f:
f.write(converted_file[file_type][1])

Expand All @@ -729,7 +730,7 @@ def save_modeling_file(diff_file, converted_file):
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["examples/diff-conversion/modular_my_new_model.py"],
default=["all"],
nargs="+",
help="A list of `diff_xxxx` files that should be converted to single model file",
)
Expand Down

0 comments on commit ae095c5

Please sign in to comment.