Skip to content

Commit

Permalink
imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Cyrilvallez committed Oct 29, 2024
1 parent 422dec9 commit f0b932a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,6 @@ def get_video_features(
video_features = torch.split(video_features, frames, dim=0)
return video_features

@replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class="LlavaNextVideoConfig")
def forward(
self,
input_ids: torch.LongTensor = None,
Expand Down
21 changes: 15 additions & 6 deletions utils/modular_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,7 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) ->
ref_count = len(assignment.references)
name = assignment.name
# Similar imports may be redefined, and only used between their 1st and 2nd definition
# so if we already have a ref count > 0, the imports is not unused
# so if we already have a ref count > 0, the imports is actually used
if (ref_count == 0 and import_ref_count.get(name, -1) <= 0) or name in body.keys():
unused_imports.add(name)
import_ref_count[name] = ref_count
Expand All @@ -976,7 +976,15 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) ->
else:
append_new_import_node(node, unused_imports, imports_to_keep, idx)

return list(imports_to_keep.values())
protected_import_nodes = [node for node in imports_to_keep.values() if m.matches(node, m.If())]
usual_import_nodes = [node for node in imports_to_keep.values() if not m.matches(node, m.If())]
# If the same import is both protected and unprotected, only keep the protected one
for protected_node in protected_import_nodes:
for stmt_node in protected_node.body.body:
usual_import_nodes = [node for node in usual_import_nodes if node.body[0] != stmt_node.body[0]]

# Protected imports always appear at the end of all imports
return usual_import_nodes + protected_import_nodes


class ModularFileMapper(ModuleMapper):
Expand Down Expand Up @@ -1239,12 +1247,13 @@ def create_modules(self) -> dict[str, cst.Module]:
idx = current_file_indices[file_type]
files[file_type]["__all__"] = {"insert_idx": idx, "node": node}

# Aggregate all the imports statements (we look for duplicates with the code_for_node, not the nodes themselves)
# Aggregate all the imports statements (we look for duplicates with the code_for_node, not the nodes themselves because
# they are wrapped in SimpleStatementLine or If which could have different newlines, blanks etc)
all_imports = self.imports.copy()
all_imports_code = {self.python_module.code_for_node(node) for node in all_imports}
all_imports_code = {self.python_module.code_for_node(node).strip() for node in all_imports}
for file, mapper in self.visited_modules.items():
new_imports = [node for node in mapper.imports if mapper.python_module.code_for_node(node) not in all_imports_code]
new_imports_code = {mapper.python_module.code_for_node(node) for node in new_imports}
new_imports = [node for node in mapper.imports if mapper.python_module.code_for_node(node).strip() not in all_imports_code]
new_imports_code = {mapper.python_module.code_for_node(node).strip() for node in new_imports}
all_imports.extend(new_imports)
all_imports_code.update(new_imports_code)

Expand Down

0 comments on commit f0b932a

Please sign in to comment.