Skip to content

Commit

Permalink
Better renaming to avoid visiting same file multiple times
Browse files Browse the repository at this point in the history
  • Loading branch information
Cyrilvallez committed Oct 29, 2024
1 parent c7b20dc commit 01f7119
Showing 1 changed file with 32 additions and 56 deletions.
88 changes: 32 additions & 56 deletions utils/modular_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@ def get_module_source_from_name(module_name: str) -> str:
source_code = file.read()
return source_code

def preserve_case_replace(text, patterns: dict, default_name: str):
# Create a regex pattern to match all variations
regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
compiled_regex = re.compile(regex_pattern, re.IGNORECASE)

def replace(match):
word = match.group(0)
result = patterns.get(word, default_name)
return result

return compiled_regex.sub(replace, text)

def convert_to_camelcase(text, old_name: str, default_old_name: str):
# Regex pattern to match consecutive uppercase letters and lowercase the first set
result = re.sub(
rf"^({old_name})(?=[a-z]+)", lambda m: default_old_name, text, flags=re.IGNORECASE, count=1
)
return result

class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
"""A transformer that replaces `old_name` with `new_name` in comments, string and any references.
Expand Down Expand Up @@ -92,34 +110,15 @@ def __init__(
if self.default_old_name.isupper():
self.default_old_name = self.default_old_name.capitalize()

def preserve_case_replace(self, text):
# Create a regex pattern to match all variations
regex_pattern = "|".join(re.escape(key) for key in self.patterns.keys())
compiled_regex = re.compile(regex_pattern, re.IGNORECASE)

def replace(match):
word = match.group(0)
result = self.patterns.get(word, self.default_name)
return result

return compiled_regex.sub(replace, text)

def convert_to_camelcase(self, text):
# Regex pattern to match consecutive uppercase letters and lowercase the first set
result = re.sub(
rf"^({self.old_name})(?=[a-z]+)", lambda m: self.default_old_name, text, flags=re.IGNORECASE, count=1
)
return result

@m.leave(m.Name() | m.SimpleString() | m.Comment())
def replace_name(self, original_node, updated_node):
if re.findall(r"# Copied from", updated_node.value):
return cst.RemoveFromParent()
update = self.preserve_case_replace(updated_node.value)
update = preserve_case_replace(updated_node.value, self.patterns, self.default_name)
return updated_node.with_changes(value=update)

def leave_ClassDef(self, original_node, updated_node):
return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value)))
return updated_node.with_changes(name=cst.Name(convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name)))


DOCSTRING_NODE = m.SimpleStatementLine(
Expand Down Expand Up @@ -236,13 +235,12 @@ def merge_docstrings(original_docstring, updated_docstring):
class SuperTransformer(cst.CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider,)

def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None):
def __init__(self, python_module: cst.Module, original_methods, updated_methods, all_bases=None):
self.python_module = python_module
self.original_methods = original_methods
self.updated_methods = updated_methods
self.all_assign_target = {}
self.deleted_targets = {} # child node can delete some arguments
self.class_name = class_name
self.all_bases = all_bases or []
self.transformer = ReplaceMethodCallTransformer(set(self.all_bases))

Expand Down Expand Up @@ -743,7 +741,7 @@ def visit_and_merge_dependencies(
return mapper


def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef):
def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str):
"""
Replace a class node which inherits from an imported model-class. This function works in the following way:
- start from the class node of the inherited class
Expand All @@ -769,9 +767,8 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef):
| ```
"""
all_bases = [k.value.value for k in class_node.bases]
class_name = class_node.name.value

original_node = mapper.classes[class_name]
original_node = mapper.classes[renamed_super_class]
original_methods = {
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f
for f in original_node.body.body
Expand Down Expand Up @@ -846,9 +843,7 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef):
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
temp_module = cst.Module(body=[result_node])
new_module = MetadataWrapper(temp_module)
new_replacement_class = new_module.visit(
SuperTransformer(temp_module, original_methods, updated_methods, class_name, all_bases)
)
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods, all_bases))
new_replacement_body = new_replacement_class.body[0].body # get the indented block

# Use decorators redefined in `modular_xxx.py` if any
Expand All @@ -867,32 +862,6 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef):
"FeatureExtractor": "feature_extractor",
}


def get_new_part(class_name, base_class):
"""
When `MyClassNameAttention` inherits from `MistralAttention`, we need
to process the name to properly find dependencies.
Here we take what is the same (Attention) and what is different
when finding the dependencies.
"""
common_suffix_len = 0
for i in range(1, min(len(class_name), len(base_class)) + 1):
if class_name[-i] == base_class[-i]:
common_suffix_len += 1
else:
break

if common_suffix_len > 0:
new_part = class_name[:-common_suffix_len]
else:
new_part = class_name

# Convert the remaining new part to snake_case
snake_case = re.sub(r"(?<!^)(?=[A-Z])", "_", new_part).lower()
return snake_case


def find_file_type(class_name: str) -> str:
"""Based on a class name, find the file type corresponding to the class."""
match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())
Expand Down Expand Up @@ -1102,6 +1071,7 @@ def leave_Module(self, node):

# Now, visit every model-specific files found in the imports, and merge their dependencies
self.visited_modules = {}
self.renamers = {}
for file, module in self.model_specific_modules.items():
file_model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", file).groups()[0]
renamer = ReplaceNameTransformer(
Expand All @@ -1115,6 +1085,8 @@ def leave_Module(self, node):
self.assignments,
self.start_lines,
)
# We record it so that we can rename classes later the exact same way
self.renamers[file] = renamer

# In turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
# definitions found in the visited files
Expand Down Expand Up @@ -1193,9 +1165,13 @@ class node based on the inherited classes if needed.

# Get the mapper corresponding to the inherited class
mapper = self.visited_modules[super_file_name]
# Rename the super class according to the exact same rule we used when renaming the whole module
renamer = self.renamers[super_file_name]
renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.default_name)
renamed_super_class = convert_to_camelcase(renamed_super_class, renamer.old_name, renamer.default_old_name)

# Create the new class node
updated_node = replace_class_node(mapper, node)
updated_node = replace_class_node(mapper, node, renamed_super_class)

# The node was modified -> look for all dependencies (recursively) of the new node
new_node_dependencies = ClassDependencyMapper.dependencies_for_new_node(updated_node, mapper)
Expand Down

0 comments on commit 01f7119

Please sign in to comment.