Skip to content

Commit

Permalink
assistant: Add imports in a single area when using workflows (zed-ind…
Browse files Browse the repository at this point in the history
…ustries#16355)

Co-Authored-by: Kirill <kirill@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Kirill <kirill@zed.dev>
Co-authored-by: Thorsten <thorsten@zed.dev>
  • Loading branch information
3 people authored Aug 19, 2024
1 parent 7fbea39 commit 9089770
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 65 deletions.
3 changes: 3 additions & 0 deletions assets/prompts/step_resolution.hbs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ With each location, you will produce a brief, one-line description of the change
- When generating multiple suggestions, ensure the descriptions are specific to each individual operation.
- Avoid referring to the location in the description. Focus on the change to be made, not the location where it's made. That's implicit with the symbol you provide.
- Don't generate multiple suggestions at the same location. Instead, combine them together in a single operation with a succinct combined description.
- To add imports respond with a suggestion where the `"symbol"` key is set to `"#imports"`
</guidelines>
</overview>

Expand Down Expand Up @@ -203,6 +204,7 @@ Add a 'use std::fmt;' statement at the beginning of the file
{
"kind": "PrependChild",
"path": "src/vehicle.rs",
"symbol": "#imports",
"description": "Add 'use std::fmt' statement"
}
]
Expand Down Expand Up @@ -413,6 +415,7 @@ Add a 'load_from_file' method to Config and import necessary modules
{
"kind": "PrependChild",
"path": "src/config.rs",
"symbol": "#imports",
"description": "Import std::fs and std::io modules"
},
{
Expand Down
50 changes: 18 additions & 32 deletions crates/assistant/src/assistant_panel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1719,7 +1719,6 @@ struct WorkflowAssist {
editor: WeakView<Editor>,
editor_was_open: bool,
assist_ids: Vec<InlineAssistId>,
_observe_assist_status: Task<()>,
}

pub struct ContextEditor {
Expand Down Expand Up @@ -1862,13 +1861,25 @@ impl ContextEditor {
if let Some(workflow_step) = self.workflow_steps.get(&range) {
if let Some(assist) = workflow_step.assist.as_ref() {
let assist_ids = assist.assist_ids.clone();
cx.window_context().defer(|cx| {
InlineAssistant::update_global(cx, |assistant, cx| {
for assist_id in assist_ids {
assistant.start_assist(assist_id, cx);
cx.spawn(|this, mut cx| async move {
for assist_id in assist_ids {
let mut receiver = this.update(&mut cx, |_, cx| {
cx.window_context().defer(move |cx| {
InlineAssistant::update_global(cx, |assistant, cx| {
assistant.start_assist(assist_id, cx);
})
});
InlineAssistant::update_global(cx, |assistant, _| {
assistant.observe_assist(assist_id)
})
})?;
while !receiver.borrow().is_done() {
let _ = receiver.changed().await;
}
})
});
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
}
}
Expand Down Expand Up @@ -3006,35 +3017,10 @@ impl ContextEditor {
}
}

let mut observations = Vec::new();
InlineAssistant::update_global(cx, |assistant, _cx| {
for assist_id in &assist_ids {
observations.push(assistant.observe_assist(*assist_id));
}
});

Some(WorkflowAssist {
assist_ids,
editor: editor.downgrade(),
editor_was_open,
_observe_assist_status: cx.spawn(|this, mut cx| async move {
while !observations.is_empty() {
let (result, ix, _) = futures::future::select_all(
observations
.iter_mut()
.map(|observation| Box::pin(observation.changed())),
)
.await;

if result.is_err() {
observations.remove(ix);
}

if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
break;
}
}
}),
})
}

Expand Down
35 changes: 28 additions & 7 deletions crates/assistant/src/inline_assistant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,33 @@ pub struct InlineAssistant {
assists: HashMap<InlineAssistId, InlineAssist>,
assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
assist_observations:
HashMap<InlineAssistId, (async_watch::Sender<()>, async_watch::Receiver<()>)>,
assist_observations: HashMap<
InlineAssistId,
(
async_watch::Sender<AssistStatus>,
async_watch::Receiver<AssistStatus>,
),
>,
confirmed_assists: HashMap<InlineAssistId, Model<Codegen>>,
prompt_history: VecDeque<String>,
prompt_builder: Arc<PromptBuilder>,
telemetry: Option<Arc<Telemetry>>,
fs: Arc<dyn Fs>,
}

pub enum AssistStatus {
Idle,
Started,
Stopped,
Finished,
}

impl AssistStatus {
pub fn is_done(&self) -> bool {
matches!(self, Self::Stopped | Self::Finished)
}
}

impl Global for InlineAssistant {}

impl InlineAssistant {
Expand Down Expand Up @@ -925,7 +943,7 @@ impl InlineAssistant {
.log_err();

if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
tx.send(()).ok();
tx.send(AssistStatus::Started).ok();
}
}

Expand All @@ -939,7 +957,7 @@ impl InlineAssistant {
assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));

if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
tx.send(()).ok();
tx.send(AssistStatus::Stopped).ok();
}
}

Expand Down Expand Up @@ -1141,11 +1159,14 @@ impl InlineAssistant {
})
}

pub fn observe_assist(&mut self, assist_id: InlineAssistId) -> async_watch::Receiver<()> {
pub fn observe_assist(
&mut self,
assist_id: InlineAssistId,
) -> async_watch::Receiver<AssistStatus> {
if let Some((_, rx)) = self.assist_observations.get(&assist_id) {
rx.clone()
} else {
let (tx, rx) = async_watch::channel(());
let (tx, rx) = async_watch::channel(AssistStatus::Idle);
self.assist_observations.insert(assist_id, (tx, rx.clone()));
rx
}
Expand Down Expand Up @@ -2079,7 +2100,7 @@ impl InlineAssist {
if assist.decorations.is_none() {
this.finish_assist(assist_id, false, cx);
} else if let Some(tx) = this.assist_observations.get(&assist_id) {
tx.0.send(()).ok();
tx.0.send(AssistStatus::Finished).ok();
}
}
})
Expand Down
90 changes: 65 additions & 25 deletions crates/assistant/src/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ use workspace::Workspace;

pub use step_view::WorkflowStepView;

const IMPORTS_SYMBOL: &str = "#imports";

pub struct WorkflowStep {
context: WeakModel<Context>,
context_buffer_range: Range<Anchor>,
Expand Down Expand Up @@ -467,7 +469,7 @@ pub mod tool {
use super::*;
use anyhow::Context as _;
use gpui::AsyncAppContext;
use language::ParseStatus;
use language::{Outline, OutlineItem, ParseStatus};
use language_model::LanguageModelTool;
use project::ProjectPath;
use schemars::JsonSchema;
Expand Down Expand Up @@ -562,10 +564,7 @@ pub mod tool {
symbol,
description,
} => {
let (symbol_path, symbol) = outline
.find_most_similar(&symbol)
.with_context(|| format!("symbol not found: {:?}", symbol))?;
let symbol = symbol.to_point(&snapshot);
let (symbol_path, symbol) = Self::resolve_symbol(&snapshot, &outline, &symbol)?;
let start = symbol
.annotation_range
.map_or(symbol.range.start, |range| range.start);
Expand All @@ -588,10 +587,7 @@ pub mod tool {
symbol,
description,
} => {
let (symbol_path, symbol) = outline
.find_most_similar(&symbol)
.with_context(|| format!("symbol not found: {:?}", symbol))?;
let symbol = symbol.to_point(&snapshot);
let (symbol_path, symbol) = Self::resolve_symbol(&snapshot, &outline, &symbol)?;
let position = snapshot.anchor_before(
symbol
.annotation_range
Expand All @@ -609,10 +605,7 @@ pub mod tool {
symbol,
description,
} => {
let (symbol_path, symbol) = outline
.find_most_similar(&symbol)
.with_context(|| format!("symbol not found: {:?}", symbol))?;
let symbol = symbol.to_point(&snapshot);
let (symbol_path, symbol) = Self::resolve_symbol(&snapshot, &outline, &symbol)?;
let position = snapshot.anchor_after(symbol.range.end);
WorkflowSuggestion::InsertSiblingAfter {
position,
Expand All @@ -625,10 +618,8 @@ pub mod tool {
description,
} => {
if let Some(symbol) = symbol {
let (symbol_path, symbol) = outline
.find_most_similar(&symbol)
.with_context(|| format!("symbol not found: {:?}", symbol))?;
let symbol = symbol.to_point(&snapshot);
let (symbol_path, symbol) =
Self::resolve_symbol(&snapshot, &outline, &symbol)?;

let position = snapshot.anchor_after(
symbol
Expand All @@ -653,10 +644,8 @@ pub mod tool {
description,
} => {
if let Some(symbol) = symbol {
let (symbol_path, symbol) = outline
.find_most_similar(&symbol)
.with_context(|| format!("symbol not found: {:?}", symbol))?;
let symbol = symbol.to_point(&snapshot);
let (symbol_path, symbol) =
Self::resolve_symbol(&snapshot, &outline, &symbol)?;

let position = snapshot.anchor_before(
symbol
Expand All @@ -677,10 +666,7 @@ pub mod tool {
}
}
WorkflowSuggestionToolKind::Delete { symbol } => {
let (symbol_path, symbol) = outline
.find_most_similar(&symbol)
.with_context(|| format!("symbol not found: {:?}", symbol))?;
let symbol = symbol.to_point(&snapshot);
let (symbol_path, symbol) = Self::resolve_symbol(&snapshot, &outline, &symbol)?;
let start = symbol
.annotation_range
.map_or(symbol.range.start, |range| range.start);
Expand All @@ -696,6 +682,60 @@ pub mod tool {

Ok((buffer, suggestion))
}

fn resolve_symbol(
snapshot: &BufferSnapshot,
outline: &Outline<Anchor>,
symbol: &str,
) -> Result<(SymbolPath, OutlineItem<Point>)> {
if symbol == IMPORTS_SYMBOL {
let target_row = find_first_non_comment_line(snapshot);
Ok((
SymbolPath(IMPORTS_SYMBOL.to_string()),
OutlineItem {
range: Point::new(target_row, 0)..Point::new(target_row + 1, 0),
..Default::default()
},
))
} else {
let (symbol_path, symbol) = outline
.find_most_similar(symbol)
.with_context(|| format!("symbol not found: {symbol}"))?;
Ok((symbol_path, symbol.to_point(snapshot)))
}
}
}

fn find_first_non_comment_line(snapshot: &BufferSnapshot) -> u32 {
let Some(language) = snapshot.language() else {
return 0;
};

let scope = language.default_scope();
let comment_prefixes = scope.line_comment_prefixes();

let mut chunks = snapshot.as_rope().chunks();
let mut target_row = 0;
loop {
let starts_with_comment = chunks
.peek()
.map(|chunk| {
comment_prefixes
.iter()
.any(|s| chunk.starts_with(s.as_ref().trim_end()))
})
.unwrap_or(false);

if !starts_with_comment {
break;
}

target_row += 1;
if !chunks.next_line() {
break;
}
}
target_row
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
Expand Down
2 changes: 1 addition & 1 deletion crates/language/src/outline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub struct Outline<T> {
path_candidate_prefixes: Vec<usize>,
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct OutlineItem<T> {
pub depth: usize,
pub range: Range<T>,
Expand Down

0 comments on commit 9089770

Please sign in to comment.