Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions pyrefly/lib/commands/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::commands::files::get_project_config_for_current_dir;
use crate::commands::util::CommandExitStatus;
use crate::config::error_kind::ErrorKind;
use crate::lsp::wasm::inlay_hints::ParameterAnnotation;
use crate::state::ide::ImportEdit;
use crate::state::ide::insert_import_edit_with_forced_import_format;
use crate::state::lsp::AnnotationKind;
use crate::state::require::Require;
Expand Down Expand Up @@ -292,7 +293,7 @@ impl InferArgs {
if let Some(ast) = transaction.get_ast(&handle) {
let error_range = error.range();
let unknown_name = module_info.code_at(error_range);
let imports: Vec<(TextSize, String, String)> = transaction
let imports: Vec<ImportEdit> = transaction
.search_exports_exact(unknown_name)
.into_iter()
.map(|handle_to_import_from| {
Expand All @@ -302,6 +303,7 @@ impl InferArgs {
handle_to_import_from.dupe(),
unknown_name,
true,
/*merge_with_existing=*/ false,
)
})
.collect();
Expand Down Expand Up @@ -337,16 +339,16 @@ impl InferArgs {
fs_anyhow::write(file_path, result)
}

fn add_imports_to_file(
file_path: &Path,
imports: Vec<(TextSize, String, String)>,
) -> anyhow::Result<()> {
fn add_imports_to_file(file_path: &Path, imports: Vec<ImportEdit>) -> anyhow::Result<()> {
let file_content = fs_anyhow::read_to_string(file_path)?;
let mut result = file_content;
for (position, import, _) in imports {
let offset = (position).into();
if !result.contains(&import) {
result.insert_str(offset, &import);
for import_edit in imports {
if import_edit.insert_text.is_empty() {
continue;
}
let offset = (import_edit.position).into();
if offset <= result.len() && !result.contains(&import_edit.insert_text) {
result.insert_str(offset, &import_edit.insert_text);
}
}
fs_anyhow::write(file_path, result)
Expand Down
94 changes: 86 additions & 8 deletions pyrefly/lib/state/ide.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use pyrefly_python::symbol_kind::SymbolKind;
use pyrefly_util::gas::Gas;
use ruff_python_ast::Expr;
use ruff_python_ast::ModModule;
use ruff_python_ast::Stmt;
use ruff_python_ast::StmtImportFrom;
use ruff_python_ast::helpers::is_docstring_stmt;
use ruff_python_ast::name::Name;
use ruff_text_size::Ranged;
Expand All @@ -33,6 +35,14 @@ use crate::state::lsp::ImportFormat;

const KEY_TO_DEFINITION_INITIAL_GAS: Gas = Gas::new(100);

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ImportEdit {
pub position: TextSize,
pub insert_text: String,
pub display_text: String,
pub module_name: String,
}

pub enum IntermediateDefinition {
Local(Export),
NamedImport(TextRange, ModuleName, Name, Option<TextRange>),
Expand Down Expand Up @@ -203,7 +213,7 @@ pub fn insert_import_edit(
handle_to_import_from: Handle,
export_name: &str,
import_format: ImportFormat,
) -> (TextSize, String, String) {
) -> ImportEdit {
let use_absolute_import = match import_format {
ImportFormat::Absolute => true,
ImportFormat::Relative => {
Expand All @@ -216,6 +226,7 @@ pub fn insert_import_edit(
handle_to_import_from,
export_name,
use_absolute_import,
true,
)
}

Expand All @@ -240,12 +251,8 @@ pub fn insert_import_edit_with_forced_import_format(
handle_to_import_from: Handle,
export_name: &str,
use_absolute_import: bool,
) -> (TextSize, String, String) {
let position = if let Some(first_stmt) = ast.body.iter().find(|stmt| !is_docstring_stmt(stmt)) {
first_stmt.range().start()
} else {
ast.range.end()
};
merge_with_existing: bool,
) -> ImportEdit {
let module_name_to_import = if use_absolute_import {
handle_to_import_from.module()
} else if let Some(relative_module) = ModuleName::relative_module_name_between(
Expand All @@ -256,12 +263,38 @@ pub fn insert_import_edit_with_forced_import_format(
} else {
handle_to_import_from.module()
};
let display_text = format!(
"from {} import {}",
module_name_to_import.as_str(),
export_name
);
if merge_with_existing
&& let Some(edit) = try_extend_existing_from_import(
ast,
module_name_to_import.as_str(),
export_name,
display_text.clone(),
module_name_to_import.as_str(),
)
{
return edit;
}
let position = if let Some(first_stmt) = ast.body.iter().find(|stmt| !is_docstring_stmt(stmt)) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there is some way we can avoid traversing the entire AST again? It might lead to slow performance for large modules.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One way to avoid the second walk would be to fold both concerns into a single pass (track the insertion position while scanning for compatible imports)

Or to compute the insertion point once before calling try_extend_existing_from_import and pass it in.

first_stmt.range().start()
} else {
ast.range.end()
};
let insert_text = format!(
"from {} import {}\n",
module_name_to_import.as_str(),
export_name
);
(position, insert_text, module_name_to_import.to_string())
ImportEdit {
position,
insert_text,
display_text,
module_name: module_name_to_import.to_string(),
}
}

/// Some handles must be imported in absolute style,
Expand All @@ -285,3 +318,48 @@ fn handle_require_absolute_import(config_finder: &ConfigFinder, handle: &Handle)
.site_package_path()
.any(|search_path| handle.path().as_path().starts_with(search_path))
}

fn try_extend_existing_from_import(
ast: &ModModule,
target_module_name: &str,
export_name: &str,
display_text: String,
module_name: &str,
) -> Option<ImportEdit> {
for stmt in &ast.body {
if let Stmt::ImportFrom(import_from) = stmt
&& import_from_module_name(import_from) == target_module_name
{
if import_from
.names
.iter()
.any(|alias| alias.asname.is_none() && alias.name.as_str() == export_name)
{
// Already imported; don't propose a duplicate edit.
return None;
}
if let Some(last_alias) = import_from.names.last() {
let position = last_alias.range.end();
let insert_text = format!(", {}", export_name);
return Some(ImportEdit {
position,
insert_text,
display_text,
module_name: module_name.to_owned(),
});
}
}
}
None
}

fn import_from_module_name(import_from: &StmtImportFrom) -> String {
let mut module_name = String::new();
if import_from.level > 0 {
module_name.push_str(&".".repeat(import_from.level as usize));
}
if let Some(module) = &import_from.module {
module_name.push_str(module.as_str());
}
module_name
}
45 changes: 32 additions & 13 deletions pyrefly/lib/state/lsp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1642,17 +1642,26 @@ impl<'a> Transaction<'a> {
if error_range.contains_range(range) {
let unknown_name = module_info.code_at(error_range);
for handle_to_import_from in self.search_exports_exact(unknown_name) {
let (position, insert_text, _) = insert_import_edit(
let import_edit = insert_import_edit(
&ast,
self.config_finder(),
handle.dupe(),
handle_to_import_from,
unknown_name,
import_format,
);
let range = TextRange::at(position, TextSize::new(0));
let title = format!("Insert import: `{}`", insert_text.trim());
code_actions.push((title, module_info.dupe(), range, insert_text));
// If the symbol was already imported we get an empty edit; skip it.
if import_edit.insert_text.is_empty() {
continue;
}
let range = TextRange::at(import_edit.position, TextSize::new(0));
let title = format!("Insert import: `{}`", import_edit.display_text);
code_actions.push((
title,
module_info.dupe(),
range,
import_edit.insert_text,
));
}

for module_name in self.search_modules_fuzzy(unknown_name) {
Expand Down Expand Up @@ -2176,30 +2185,40 @@ impl<'a> Transaction<'a> {
&& let Some(ast) = self.get_ast(handle)
&& let Some(module_info) = self.get_module_info(handle)
{
for (handle_to_import_from, name, export) in
self.search_exports_fuzzy(identifier.as_str())
{
let search_results = self.search_exports_fuzzy(identifier.as_str());
for (handle_to_import_from, name, export) in search_results {
if !identifier.as_str().starts_with('_') && name.starts_with('_') {
continue;
}
// Using handle itself doesn't always work because handles can be made separately and have different hashes
if handle_to_import_from.module() == handle.module()
|| handle_to_import_from.module() == ModuleName::builtins()
{
continue;
}
let module_description = handle_to_import_from.module().as_str().to_owned();
let (insert_text, additional_text_edits, imported_module) = {
let (position, insert_text, module_name) = insert_import_edit(
let (detail_text, additional_text_edits, imported_module) = {
let import_edit = insert_import_edit(
&ast,
self.config_finder(),
handle.dupe(),
handle_to_import_from,
&name,
import_format,
);
if import_edit.insert_text.is_empty() {
continue;
}
let import_text_edit = TextEdit {
range: module_info.to_lsp_range(TextRange::at(position, TextSize::new(0))),
new_text: insert_text.clone(),
range: module_info
.to_lsp_range(TextRange::at(import_edit.position, TextSize::new(0))),
new_text: import_edit.insert_text.clone(),
};
(insert_text, Some(vec![import_text_edit]), module_name)
(
Some(import_edit.insert_text.clone()),
Some(vec![import_text_edit]),
import_edit.module_name,
)
};
let auto_import_label_detail = format!(" (import {imported_module})");
let (label, label_details) = if supports_completion_item_details {
Expand All @@ -2215,7 +2234,7 @@ impl<'a> Transaction<'a> {
};
completions.push(CompletionItem {
label,
detail: Some(insert_text),
detail: detail_text,
kind: export
.symbol_kind
.map_or(Some(CompletionItemKind::VARIABLE), |k| {
Expand Down
6 changes: 2 additions & 4 deletions pyrefly/lib/test/lsp/code_actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,7 @@ fn insertion_test_duplicate_imports() {
],
get_test_report,
);
// The insertion won't attempt to merge imports from the same module.
// It's not illegal, but it would be nice if we do merge.
// When another import from the same module already exists, we should append to it.
assert_eq!(
r#"
# a.py
Expand All @@ -227,8 +226,7 @@ from a import another_thing
my_export
# ^
## After:
from a import my_export
from a import another_thing
from a import another_thing, my_export
my_export
# ^
"#
Expand Down
4 changes: 2 additions & 2 deletions pyrefly/lib/test/lsp/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ fn get_test_report(
report.push_str("[DEPRECATED] ");
}
report.push_str(&label);
if let Some(detail) = detail {
if let Some(detail) = &detail {
report.push_str(": ");
report.push_str(&detail);
}
Expand All @@ -120,7 +120,7 @@ fn get_test_report(
report.push_str(" with text edit: ");
report.push_str(&format!("{:?}", &text_edit));
}
if let Some(documentation) = documentation {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why were these changes needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were trying to match the historical golden strings for completion reports.

Previously, every entry, even plain keywords or literal values, had a blank line appended because the reporter always added one after the optional docstring block. That produced extra empty lines compared to the expected fixtures, which is why dozens of completion tests started failing.

The adjustment keeps the docstring formatting logic intact, but stops unconditionally inserting that trailing blank line unless there was actually extra content to separate. This wa,y the rendered report matches the snapshots again without altering runtime behavior.

if let Some(ref documentation) = documentation {
report.push('\n');
match documentation {
lsp_types::Documentation::String(s) => {
Expand Down