Skip to content

Commit ae661e8

Browse files
Copilotgkorland
andcommitted
Improve import resolution logic and add test cases
Co-authored-by: gkorland <753206+gkorland@users.noreply.github.com>
1 parent 28872e3 commit ae661e8

File tree

4 files changed

+136
-34
lines changed

4 files changed

+136
-34
lines changed

api/analyzers/python/analyzer.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -152,42 +152,54 @@ def resolve_import(self, files: dict[Path, File], lsp: SyncLanguageServer, file_
152152
"""
153153
res = []
154154

155-
# For import statements like "import os" or "from pathlib import Path"
156-
# We need to find the dotted_name nodes that represent the imported modules/names
157-
import warnings
158-
with warnings.catch_warnings():
159-
warnings.simplefilter("ignore")
155+
try:
160156
if import_node.type == 'import_statement':
161157
# Handle "import module" or "import module as alias"
162-
# Look for dotted_name or aliased_import
163-
query = self.language.query("(dotted_name) @module (aliased_import) @aliased")
164-
else: # import_from_statement
165-
# Handle "from module import name"
166-
# Get the imported names (after the 'import' keyword)
167-
query = self.language.query("""
168-
(import_from_statement
169-
(dotted_name) @imported_name)
170-
""")
171-
172-
captures = query.captures(import_node)
173-
174-
# Try to resolve each imported name
175-
if 'module' in captures:
176-
for module_node in captures['module']:
177-
resolved = self.resolve_type(files, lsp, file_path, path, module_node)
178-
res.extend(resolved)
179-
180-
if 'aliased' in captures:
181-
for aliased_node in captures['aliased']:
182-
# Get the actual module name from the aliased import
183-
if aliased_node.child_count > 0:
184-
module_name_node = aliased_node.children[0]
185-
resolved = self.resolve_type(files, lsp, file_path, path, module_name_node)
186-
res.extend(resolved)
158+
# Find all dotted_name and aliased_import nodes
159+
for child in import_node.children:
160+
if child.type == 'dotted_name':
161+
# Try to resolve the module/name
162+
identifier = child.children[0] if child.child_count > 0 else child
163+
resolved = self.resolve_type(files, lsp, file_path, path, identifier)
164+
res.extend(resolved)
165+
elif child.type == 'aliased_import':
166+
# Get the actual name from aliased import (before 'as')
167+
if child.child_count > 0:
168+
actual_name = child.children[0]
169+
if actual_name.type == 'dotted_name' and actual_name.child_count > 0:
170+
identifier = actual_name.children[0]
171+
else:
172+
identifier = actual_name
173+
resolved = self.resolve_type(files, lsp, file_path, path, identifier)
174+
res.extend(resolved)
175+
176+
elif import_node.type == 'import_from_statement':
177+
# Handle "from module import name1, name2"
178+
# Find the 'import' keyword to know where imported names start
179+
import_keyword_found = False
180+
for child in import_node.children:
181+
if child.type == 'import':
182+
import_keyword_found = True
183+
continue
184+
185+
# After 'import' keyword, dotted_name nodes are the imported names
186+
if import_keyword_found and child.type == 'dotted_name':
187+
# Try to resolve the imported name
188+
identifier = child.children[0] if child.child_count > 0 else child
189+
resolved = self.resolve_type(files, lsp, file_path, path, identifier)
190+
res.extend(resolved)
191+
elif import_keyword_found and child.type == 'aliased_import':
192+
# Handle "from module import name as alias"
193+
if child.child_count > 0:
194+
actual_name = child.children[0]
195+
if actual_name.type == 'dotted_name' and actual_name.child_count > 0:
196+
identifier = actual_name.children[0]
197+
else:
198+
identifier = actual_name
199+
resolved = self.resolve_type(files, lsp, file_path, path, identifier)
200+
res.extend(resolved)
187201

188-
if 'imported_name' in captures:
189-
for name_node in captures['imported_name']:
190-
resolved = self.resolve_type(files, lsp, file_path, path, name_node)
191-
res.extend(resolved)
202+
except Exception as e:
203+
logger.debug(f"Failed to resolve import: {e}")
192204

193205
return res
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""Module A with a class definition."""
2+
3+
class ClassA:
4+
"""A simple class in module A."""
5+
6+
def method_a(self):
7+
"""A method in ClassA."""
8+
return "Method A"
9+
10+
def function_a():
11+
"""A function in module A."""
12+
return "Function A"
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Module B that imports from module A."""
2+
3+
from module_a import ClassA, function_a
4+
5+
class ClassB(ClassA):
6+
"""A class that extends ClassA."""
7+
8+
def method_b(self):
9+
"""A method in ClassB."""
10+
result = function_a()
11+
return f"Method B: {result}"

tests/test_py_imports.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import os
2+
import unittest
3+
from pathlib import Path
4+
5+
from api import SourceAnalyzer, File, Graph
6+
7+
8+
class Test_PY_Imports(unittest.TestCase):
9+
def test_import_tracking(self):
10+
"""Test that Python imports are tracked correctly."""
11+
# Get test file path
12+
current_dir = os.path.dirname(os.path.abspath(__file__))
13+
test_path = os.path.join(current_dir, 'source_files', 'py_imports')
14+
15+
# Create graph and analyze
16+
g = Graph("py_imports_test")
17+
analyzer = SourceAnalyzer()
18+
19+
try:
20+
analyzer.analyze_local_folder(test_path, g)
21+
22+
# Verify files were created
23+
module_a = g.get_file('', 'module_a.py', '.py')
24+
self.assertIsNotNone(module_a, "module_a.py should be in the graph")
25+
26+
module_b = g.get_file('', 'module_b.py', '.py')
27+
self.assertIsNotNone(module_b, "module_b.py should be in the graph")
28+
29+
# Verify classes were created
30+
class_a = g.get_class_by_name('ClassA')
31+
self.assertIsNotNone(class_a, "ClassA should be in the graph")
32+
33+
class_b = g.get_class_by_name('ClassB')
34+
self.assertIsNotNone(class_b, "ClassB should be in the graph")
35+
36+
# Verify function was created
37+
func_a = g.get_function_by_name('function_a')
38+
self.assertIsNotNone(func_a, "function_a should be in the graph")
39+
40+
# Test: module_b should have IMPORTS relationship to ClassA
41+
# Query to check if module_b imports ClassA
42+
query = """
43+
MATCH (f:File {name: 'module_b.py'})-[:IMPORTS]->(c:Class {name: 'ClassA'})
44+
RETURN c
45+
"""
46+
result = g._query(query, {})
47+
self.assertGreater(len(result.result_set), 0,
48+
"module_b.py should import ClassA")
49+
50+
# Test: module_b should have IMPORTS relationship to function_a
51+
query = """
52+
MATCH (f:File {name: 'module_b.py'})-[:IMPORTS]->(fn:Function {name: 'function_a'})
53+
RETURN fn
54+
"""
55+
result = g._query(query, {})
56+
self.assertGreater(len(result.result_set), 0,
57+
"module_b.py should import function_a")
58+
59+
print("✓ Import tracking test passed")
60+
61+
finally:
62+
# Cleanup: delete the test graph
63+
g.delete()
64+
65+
66+
if __name__ == '__main__':
67+
unittest.main()

0 commit comments

Comments
 (0)