Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(plugins): reload plugin at runtime #2372

Merged
merged 17 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor(plugins): load/reload plugin
  • Loading branch information
imsnif committed Apr 17, 2023
commit 4b9a3043a947266c93cc640668f9b58e5de8b32f
1 change: 1 addition & 0 deletions zellij-server/src/plugins/plugin_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ impl <'a> PluginLoader <'a>{
plugin_loader.load_plugin_instance(&instance, &plugin_env)?;
plugin_loader.clone_instance_for_other_clients(&instance, &plugin_env, &connected_clients)
})
.map(|_| plugin_loader.apply_plugin_size())
.with_context(err_context)?;
display_loading_stage!(end, loading_indication, senders, plugin_id);
Ok(())
Expand Down
147 changes: 63 additions & 84 deletions zellij-server/src/plugins/wasm_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use crate::plugins::plugin_loader::{PluginLoader, VersionMismatchError};
use log::{debug, info, warn};
use serde::{de::DeserializeOwned, Serialize};
use std::{
fmt::Display,
collections::{HashMap, HashSet},
fmt,
path::PathBuf,
process,
str::FromStr,
Expand Down Expand Up @@ -164,27 +164,10 @@ impl WasmBridge {
connected_clients.clone(),
&mut loading_indication,
) {
Ok(_) => {
let _ = senders.send_to_background_jobs(
BackgroundJob::StopPluginLoadingAnimation(plugin_id),
);
let _ =
senders.send_to_plugin(PluginInstruction::ApplyCachedEvents(vec![plugin_id]));
},
Err(e) => {
let _ = senders.send_to_background_jobs(
BackgroundJob::StopPluginLoadingAnimation(plugin_id),
);
let _ =
senders.send_to_plugin(PluginInstruction::ApplyCachedEvents(vec![plugin_id]));
loading_indication.indicate_loading_error(e.to_string());
let _ =
senders.send_to_screen(ScreenInstruction::UpdatePluginLoadingStage(
plugin_id,
loading_indication.clone(),
));
},
Ok(_) => handle_plugin_successful_loading(&senders, plugin_id),
Err(e) => handle_plugin_loading_failure(&senders, plugin_id, &mut loading_indication, e),
}
let _ = senders.send_to_plugin(PluginInstruction::ApplyCachedEvents(vec![plugin_id]));
}
});
self.loading_plugins.insert((plugin_id, run.location.clone()), load_plugin_task);
Expand All @@ -204,26 +187,15 @@ impl WasmBridge {
Ok(())
}
pub fn reload_plugin(&mut self, run_plugin: &RunPlugin) -> Result<()> {
// TODO: CONTINUE HERE - break down this function into smaller parts and combine with
// load_plugin
let err_context = || "Failed to reload plugin";
let plugin_is_currently_being_loaded = self.loading_plugins.iter().find(|((_plugin_id, run_plugin_location), _)| {
run_plugin_location == &run_plugin.location
}).is_some();
if plugin_is_currently_being_loaded {
if self.plugin_is_currently_being_loaded(&run_plugin.location) {
self.pending_plugin_reloads.insert(run_plugin.clone());
return Ok(());
}
let mut plugin_ids: Vec<PluginId> = self.plugin_map.lock().unwrap().iter().filter(|((plugin_id, client_id), (instance, plugin_env, size))| {
plugin_env.plugin.location == run_plugin.location
})
.map(|((plugin_id, _client_id), _)| *plugin_id)
.collect();
if plugin_ids.is_empty() {
return Err(ZellijError::PluginDoesNotExist).with_context(err_context);
}
let first_plugin_id = *plugin_ids.get(0).unwrap();

let plugin_ids = self.all_plugin_ids_for_plugin_location(&run_plugin.location)?;
let first_plugin_id = *plugin_ids.get(0).unwrap(); // this is safe becaise the above
// methods always returns at least 1 id
let mut loading_indication = LoadingIndication::new("".into());
self.start_plugin_loading_indication(&plugin_ids, &loading_indication);
let load_plugin_task = task::spawn({
let plugin_dir = self.plugin_dir.clone();
let plugin_cache = self.plugin_cache.clone();
Expand All @@ -232,15 +204,6 @@ impl WasmBridge {
let plugin_map = self.plugin_map.clone();
let connected_clients = self.connected_clients.clone();
async move {
let mut loading_indication = LoadingIndication::new("".into());
plugin_ids.push(first_plugin_id);
for plugin_id in &plugin_ids {
let _ =
senders.send_to_screen(ScreenInstruction::StartPluginLoadingIndication(*plugin_id, loading_indication.clone()));
let _ =
senders.send_to_background_jobs(BackgroundJob::AnimatePluginLoading(*plugin_id));
}
// the plugin name will be set inside the reload_plugin function
match PluginLoader::reload_plugin(
first_plugin_id,
plugin_dir.clone(),
Expand All @@ -252,12 +215,12 @@ impl WasmBridge {
&mut loading_indication,
) {
Ok(_) => {
let _ = senders.send_to_background_jobs(
BackgroundJob::StopPluginLoadingAnimation(first_plugin_id),
);
let _ = senders.send_to_screen(ScreenInstruction::RequestStateUpdateForPlugin(first_plugin_id));
let _ = plugin_ids.pop(); // remove the first plugin we just reloaded
handle_plugin_successful_loading(&senders, first_plugin_id);
for plugin_id in &plugin_ids {
if plugin_id == &first_plugin_id {
// no need to reload the plugin we just reloaded
continue;
}
let mut loading_indication = LoadingIndication::new("".into());
match PluginLoader::reload_plugin_from_memory(
*plugin_id,
Expand All @@ -269,46 +232,18 @@ impl WasmBridge {
connected_clients.clone(),
&mut loading_indication
) {
Ok(_) => {
// TODO: combine with above (and with start_plugin?)
let _ = senders.send_to_background_jobs(
BackgroundJob::StopPluginLoadingAnimation(*plugin_id),
);
let _ = senders.send_to_screen(ScreenInstruction::RequestStateUpdateForPlugin(*plugin_id));
},
Err(e) => {
let _ = senders.send_to_background_jobs(
BackgroundJob::StopPluginLoadingAnimation(*plugin_id),
);
loading_indication.indicate_loading_error(e.to_string());
let _ =
senders.send_to_screen(ScreenInstruction::UpdatePluginLoadingStage(
*plugin_id,
loading_indication.clone(),
));
}
Ok(_) => handle_plugin_successful_loading(&senders, *plugin_id),
Err(e) => handle_plugin_loading_failure(&senders, *plugin_id, &mut loading_indication, e),
}
}
let _ = senders.send_to_plugin(PluginInstruction::ApplyCachedEvents(plugin_ids));
},
Err(e) => {
for plugin_id in &plugin_ids {
let _ = senders.send_to_background_jobs(
BackgroundJob::StopPluginLoadingAnimation(*plugin_id),
);
// let _ =
// senders.send_to_plugin(PluginInstruction::ApplyCachedEvents(plugin_id));
loading_indication.indicate_loading_error(e.to_string());
let _ =
senders.send_to_screen(ScreenInstruction::UpdatePluginLoadingStage(
*plugin_id,
loading_indication.clone(),
));
handle_plugin_loading_failure(&senders, *plugin_id, &mut loading_indication, &e);
}
let _ =
senders.send_to_plugin(PluginInstruction::ApplyCachedEvents(plugin_ids));
},
}
let _ = senders.send_to_plugin(PluginInstruction::ApplyCachedEvents(plugin_ids));
}
});
self.loading_plugins.insert((first_plugin_id, run_plugin.location.clone()), load_plugin_task);
Expand Down Expand Up @@ -523,6 +458,50 @@ impl WasmBridge {
drop(loading_plugin_task.cancel());
}
}
fn plugin_is_currently_being_loaded(&self, plugin_location: &RunPluginLocation) -> bool {
self.loading_plugins.iter().find(|((_plugin_id, run_plugin_location), _)| {
run_plugin_location == plugin_location
}).is_some()
}
fn all_plugin_ids_for_plugin_location(&self, plugin_location: &RunPluginLocation) -> Result<Vec<PluginId>> {
let err_context = || format!("Failed to get plugin ids for location {plugin_location}");
let plugin_ids: Vec<PluginId> = self.plugin_map.lock().unwrap().iter().filter(|((_plugin_id, _client_id), (_instance, plugin_env, _size))| {
&plugin_env.plugin.location == plugin_location
})
.map(|((plugin_id, _client_id), _)| *plugin_id)
.collect();
if plugin_ids.is_empty() {
return Err(ZellijError::PluginDoesNotExist).with_context(err_context);
}
Ok(plugin_ids)
}
fn start_plugin_loading_indication(&self, plugin_ids: &[PluginId], loading_indication: &LoadingIndication) {
for plugin_id in plugin_ids {
let _ =
self.senders.send_to_screen(ScreenInstruction::StartPluginLoadingIndication(*plugin_id, loading_indication.clone()));
let _ =
self.senders.send_to_background_jobs(BackgroundJob::AnimatePluginLoading(*plugin_id));
}
}
}

fn handle_plugin_successful_loading(senders: &ThreadSenders, plugin_id: PluginId) {
let _ = senders.send_to_background_jobs(
BackgroundJob::StopPluginLoadingAnimation(plugin_id),
);
let _ = senders.send_to_screen(ScreenInstruction::RequestStateUpdateForPlugin(plugin_id));
}

fn handle_plugin_loading_failure(senders: &ThreadSenders, plugin_id: PluginId, loading_indication: &mut LoadingIndication, error: impl Display) {
let _ = senders.send_to_background_jobs(
BackgroundJob::StopPluginLoadingAnimation(plugin_id),
);
loading_indication.indicate_loading_error(error.to_string());
let _ =
senders.send_to_screen(ScreenInstruction::UpdatePluginLoadingStage(
plugin_id,
loading_indication.clone(),
));
}

fn load_plugin_instance(instance: &mut Instance) -> Result<()> {
Expand Down