Skip to content

Commit 36ebdb0

Browse files
committed
fix(vscode): fix returning too many references
1 parent bbd5dbb commit 36ebdb0

File tree

3 files changed

+109
-3
lines changed

3 files changed

+109
-3
lines changed

sqlmesh/lsp/main.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sqlmesh.lsp.completions import get_sql_completions
1515
from sqlmesh.lsp.context import LSPContext, ModelTarget
1616
from sqlmesh.lsp.custom import ALL_MODELS_FEATURE, AllModelsRequest, AllModelsResponse
17-
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
17+
from sqlmesh.lsp.reference import get_model_definitions_for_a_path, filter_references_by_position
1818

1919

2020
class SQLMeshLanguageServer:
@@ -186,12 +186,18 @@ def goto_definition(
186186
if self.lsp_context is None:
187187
raise RuntimeError(f"No context found for document: {document.path}")
188188

189+
# Get all possible references
189190
references = get_model_definitions_for_a_path(
190191
self.lsp_context, params.text_document.uri
191192
)
192193
if not references:
193194
return []
194195

196+
# Filter references by cursor position
197+
filtered_references = filter_references_by_position(references, params.position)
198+
if not filtered_references:
199+
return []
200+
195201
return [
196202
types.LocationLink(
197203
target_uri=reference.uri,
@@ -205,7 +211,7 @@ def goto_definition(
205211
),
206212
origin_selection_range=reference.range,
207213
)
208-
for reference in references
214+
for reference in filtered_references
209215
]
210216

211217
except Exception as e:

sqlmesh/lsp/reference.py

+33
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,39 @@ class Reference(PydanticModel):
1414
uri: str
1515

1616

17+
def filter_references_by_position(
18+
references: t.List[Reference], position: Position
19+
) -> t.List[Reference]:
20+
"""
21+
Filter references to only include those that contain the given position.
22+
23+
Args:
24+
references: List of Reference objects
25+
position: The cursor position to check
26+
27+
Returns:
28+
List of Reference objects that contain the position
29+
"""
30+
filtered_references = []
31+
32+
for reference in references:
33+
# Check if position is within the reference range
34+
range_start = reference.range.start
35+
range_end = reference.range.end
36+
37+
# Position is within range if it's after or at start and before or at end
38+
if (
39+
range_start.line < position.line
40+
or (range_start.line == position.line and range_start.character <= position.character)
41+
) and (
42+
range_end.line > position.line
43+
or (range_end.line == position.line and range_end.character >= position.character)
44+
):
45+
filtered_references.append(reference)
46+
47+
return filtered_references
48+
49+
1750
def get_model_definitions_for_a_path(
1851
lint_context: LSPContext, document_uri: str
1952
) -> t.List[Reference]:

tests/lsp/test_reference.py

+68-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import pytest
2+
from lsprotocol.types import Position
23
from sqlmesh.core.context import Context
34
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
4-
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
5+
from sqlmesh.lsp.reference import get_model_definitions_for_a_path, filter_references_by_position
56

67

78
@pytest.mark.fast
@@ -108,3 +109,69 @@ def get_string_from_range(file_lines, range_obj) -> str:
108109
result += file_lines[line_num]
109110
result += file_lines[end_line][:end_character] # Last line up to end_character
110111
return result
112+
113+
114+
@pytest.mark.fast
115+
def test_filter_references_by_position() -> None:
116+
"""Test that we can filter references correctly based on cursor position."""
117+
context = Context(paths=["examples/sushi"])
118+
lsp_context = LSPContext(context)
119+
120+
# Use a file with multiple references (waiter_revenue_by_day)
121+
waiter_revenue_by_day_uri = next(
122+
uri
123+
for uri, info in lsp_context.map.items()
124+
if isinstance(info, ModelTarget) and "sushi.waiter_revenue_by_day" in info.names
125+
)
126+
127+
# Get all references in the file
128+
all_references = get_model_definitions_for_a_path(lsp_context, waiter_revenue_by_day_uri)
129+
assert len(all_references) == 3
130+
131+
# Get file contents to locate positions for testing
132+
path = waiter_revenue_by_day_uri.removeprefix("file://")
133+
read_file = open(path, "r").readlines()
134+
135+
# Test positions for each reference
136+
for i, reference in enumerate(all_references):
137+
# Position inside the reference - should return exactly one reference
138+
middle_line = (reference.range.start.line + reference.range.end.line) // 2
139+
middle_char = (reference.range.start.character + reference.range.end.character) // 2
140+
position_inside = Position(line=middle_line, character=middle_char)
141+
filtered = filter_references_by_position(all_references, position_inside)
142+
assert len(filtered) == 1
143+
assert filtered[0].uri == reference.uri
144+
assert filtered[0].range == reference.range
145+
146+
# For testing outside position, use a position before the current reference
147+
# or after the last reference for the last one
148+
if i == 0:
149+
outside_line = reference.range.start.line
150+
outside_char = max(0, reference.range.start.character - 5)
151+
else:
152+
prev_ref = all_references[i - 1]
153+
outside_line = prev_ref.range.end.line
154+
outside_char = prev_ref.range.end.character + 5
155+
156+
position_outside = Position(line=outside_line, character=outside_char)
157+
filtered_outside = filter_references_by_position(all_references, position_outside)
158+
assert reference not in filtered_outside, (
159+
f"Reference {i} should not match position outside its range"
160+
)
161+
162+
# Test case: cursor at beginning of file - no references should match
163+
position_start = Position(line=0, character=0)
164+
filtered_start = filter_references_by_position(all_references, position_start)
165+
assert len(filtered_start) == 0 or all(
166+
ref.range.start.line == 0 and ref.range.start.character <= 0 for ref in filtered_start
167+
)
168+
169+
# Test case: cursor at end of file - no references should match (unless there's a reference at the end)
170+
last_line = len(read_file) - 1
171+
last_char = len(read_file[last_line]) - 1
172+
position_end = Position(line=last_line, character=last_char)
173+
filtered_end = filter_references_by_position(all_references, position_end)
174+
assert len(filtered_end) == 0 or all(
175+
ref.range.end.line >= last_line and ref.range.end.character >= last_char
176+
for ref in filtered_end
177+
)

0 commit comments

Comments
 (0)