Skip to content

Commit

Permalink
assistant panel: Fix entering credentials not updating view (zed-indu…
Browse files Browse the repository at this point in the history
…stries#15527)

Co-authored-by: Bennet <bennet@zed.dev>

Release Notes:

- N/A

Co-authored-by: Bennet <bennet@zed.dev>
  • Loading branch information
mrnugget and bennetbo authored Jul 31, 2024
1 parent c78ea0d commit b571bc8
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 52 deletions.
38 changes: 28 additions & 10 deletions crates/assistant/src/assistant_panel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,13 @@ impl AssistantPanel {
cx.subscribe(&context_store, Self::handle_context_store_event),
cx.subscribe(
&LanguageModelRegistry::global(cx),
|this, _, _: &language_model::ActiveModelChanged, cx| {
this.completion_provider_changed(cx);
|this, _, event: &language_model::Event, cx| match event {
language_model::Event::ActiveModelChanged => {
this.completion_provider_changed(cx);
}
language_model::Event::ProviderStateChanged => {
this.ensure_authenticated(cx);
}
},
),
];
Expand Down Expand Up @@ -587,6 +592,16 @@ impl AssistantPanel {
}

fn ensure_authenticated(&mut self, cx: &mut ViewContext<Self>) {
if self.is_authenticated(cx) {
for context_editor in self.context_editors(cx) {
context_editor.update(cx, |editor, cx| {
editor.set_authentication_prompt(None, cx);
});
}
cx.notify();
return;
}

let Some(provider_id) = LanguageModelRegistry::read_global(cx)
.active_provider()
.map(|p| p.id())
Expand All @@ -595,15 +610,18 @@ impl AssistantPanel {
};

let load_credentials = self.authenticate(cx);
let task = cx.spawn(|this, mut cx| async move {
let _ = load_credentials.await;
this.update(&mut cx, |this, cx| {
this.show_authentication_prompt(cx);
})
.log_err();
});

self.authenticate_provider_task = Some((provider_id, task));
self.authenticate_provider_task = Some((
provider_id,
cx.spawn(|this, mut cx| async move {
let _ = load_credentials.await;
this.update(&mut cx, |this, cx| {
this.show_authentication_prompt(cx);
this.authenticate_provider_task = None;
})
.log_err();
}),
));
}

fn show_authentication_prompt(&mut self, cx: &mut ViewContext<Self>) {
Expand Down
17 changes: 16 additions & 1 deletion crates/language_model/src/language_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,25 @@ pub trait LanguageModelProvider: 'static {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;

// fn observable_entity(&self) ;
}

pub trait LanguageModelProviderState: 'static {
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
type ObservableEntity;

fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;

fn subscribe<T: 'static>(
&self,
cx: &mut gpui::ModelContext<T>,
callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
) -> Option<gpui::Subscription> {
let entity = self.observable_entity()?;
Some(cx.observe(&entity, move |this, _, cx| {
callback(this, cx);
}))
}
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
Expand Down
11 changes: 6 additions & 5 deletions crates/language_model/src/provider/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub struct AnthropicLanguageModelProvider {
state: gpui::Model<State>,
}

struct State {
pub struct State {
api_key: Option<String>,
_subscription: Subscription,
}
Expand All @@ -61,11 +61,12 @@ impl AnthropicLanguageModelProvider {
Self { http_client, state }
}
}

impl LanguageModelProviderState for AnthropicLanguageModelProvider {
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
Some(cx.observe(&self.state, |_, _, cx| {
cx.notify();
}))
type ObservableEntity = State;

fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
Some(self.state.clone())
}
}

Expand Down
10 changes: 5 additions & 5 deletions crates/language_model/src/provider/cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub struct CloudLanguageModelProvider {
_maintain_client_status: Task<()>,
}

struct State {
pub struct State {
client: Arc<Client>,
status: client::Status,
_subscription: Subscription,
Expand Down Expand Up @@ -99,10 +99,10 @@ impl CloudLanguageModelProvider {
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
Some(cx.observe(&self.state, |_, _, cx| {
cx.notify();
}))
type ObservableEntity = State;

fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
Some(self.state.clone())
}
}

Expand Down
12 changes: 6 additions & 6 deletions crates/language_model/src/provider/copilot_chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use futures::future::BoxFuture;
use futures::stream::BoxStream;
use futures::{FutureExt, StreamExt};
use gpui::{
percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model,
ModelContext, Render, Subscription, Task, Transformation,
percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, Render,
Subscription, Task, Transformation,
};
use settings::{Settings, SettingsStore};
use std::time::Duration;
Expand Down Expand Up @@ -67,10 +67,10 @@ impl CopilotChatLanguageModelProvider {
}

impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
fn subscribe<T: 'static>(&self, cx: &mut ModelContext<T>) -> Option<Subscription> {
Some(cx.observe(&self.state, |_, _, cx| {
cx.notify();
}))
type ObservableEntity = State;

fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
Some(self.state.clone())
}
}

Expand Down
4 changes: 3 additions & 1 deletion crates/language_model/src/provider/fake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ pub struct FakeLanguageModelProvider {
}

impl LanguageModelProviderState for FakeLanguageModelProvider {
fn subscribe<T: 'static>(&self, _: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
type ObservableEntity = ();

fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
None
}
}
Expand Down
10 changes: 5 additions & 5 deletions crates/language_model/src/provider/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub struct GoogleLanguageModelProvider {
state: gpui::Model<State>,
}

struct State {
pub struct State {
api_key: Option<String>,
_subscription: Subscription,
}
Expand All @@ -63,10 +63,10 @@ impl GoogleLanguageModelProvider {
}

impl LanguageModelProviderState for GoogleLanguageModelProvider {
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
Some(cx.observe(&self.state, |_, _, cx| {
cx.notify();
}))
type ObservableEntity = State;

fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
Some(self.state.clone())
}
}

Expand Down
10 changes: 5 additions & 5 deletions crates/language_model/src/provider/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub struct OllamaLanguageModelProvider {
state: gpui::Model<State>,
}

struct State {
pub struct State {
http_client: Arc<dyn HttpClient>,
available_models: Vec<ollama::Model>,
_subscription: Subscription,
Expand Down Expand Up @@ -87,10 +87,10 @@ impl OllamaLanguageModelProvider {
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
Some(cx.observe(&self.state, |_, _, cx| {
cx.notify();
}))
type ObservableEntity = State;

fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
Some(self.state.clone())
}
}

Expand Down
10 changes: 5 additions & 5 deletions crates/language_model/src/provider/open_ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub struct OpenAiLanguageModelProvider {
state: gpui::Model<State>,
}

struct State {
pub struct State {
api_key: Option<String>,
_subscription: Subscription,
}
Expand All @@ -64,10 +64,10 @@ impl OpenAiLanguageModelProvider {
}

impl LanguageModelProviderState for OpenAiLanguageModelProvider {
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
Some(cx.observe(&self.state, |_, _, cx| {
cx.notify();
}))
type ObservableEntity = State;

fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
Some(self.state.clone())
}
}

Expand Down
22 changes: 13 additions & 9 deletions crates/language_model/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ fn register_language_model_providers(
registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
} else {
registry.unregister_provider(
&LanguageModelProviderId::from(
crate::provider::cloud::PROVIDER_NAME.to_string(),
),
&LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
cx,
);
}
Expand All @@ -80,9 +78,12 @@ pub struct ActiveModel {
model: Option<Arc<dyn LanguageModel>>,
}

pub struct ActiveModelChanged;
pub enum Event {
ActiveModelChanged,
ProviderStateChanged,
}

impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
impl EventEmitter<Event> for LanguageModelRegistry {}

impl LanguageModelRegistry {
pub fn global(cx: &AppContext) -> Model<Self> {
Expand Down Expand Up @@ -114,7 +115,10 @@ impl LanguageModelRegistry {
) {
let name = provider.id();

if let Some(subscription) = provider.subscribe(cx) {
let subscription = provider.subscribe(cx, |_, cx| {
cx.emit(Event::ProviderStateChanged);
});
if let Some(subscription) = subscription {
subscription.detach();
}

Expand Down Expand Up @@ -187,7 +191,7 @@ impl LanguageModelRegistry {
provider,
model: None,
});
cx.emit(ActiveModelChanged);
cx.emit(Event::ActiveModelChanged);
}

pub fn set_active_model(
Expand All @@ -202,13 +206,13 @@ impl LanguageModelRegistry {
provider,
model: Some(model),
});
cx.emit(ActiveModelChanged);
cx.emit(Event::ActiveModelChanged);
} else {
log::warn!("Active model's provider not found in registry");
}
} else {
self.active_model = None;
cx.emit(ActiveModelChanged);
cx.emit(Event::ActiveModelChanged);
}
}

Expand Down

0 comments on commit b571bc8

Please sign in to comment.