Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,18 @@ required-features = ["use_neovim_lib"]
name = "scorched_earth_as"
required-features = ["use_async-std"]

[[example]]
name = "nested_requests"
required-features = ["use_tokio"]

[[test]]
name = "nested_requests"
required-features = ["use_tokio"]

[[test]]
name = "connecting"
path = "tests/connecting/mod.rs"

[[test]]
name = "notifications"
required-features = ["use_tokio"]
2 changes: 2 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
* Check what we're doing with outgoing request parameters, there are 2 allocations going on in call_args! and rpc_args!, and we're just reading it in the end.

* Can we use the non-generic `split` methods from tokio for unixstream, tcpstream? Supposedly better performance, but introduces lifetimes...

* Propogate errors from `model::encode()` in `handler_loop()`
175 changes: 175 additions & 0 deletions examples/nested_requests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
use nvim_rs::{
compat::tokio::Compat,
create::tokio as create,
neovim::Neovim,
Handler,
};

use async_trait::async_trait;

use rmpv::Value;

use std::{
sync::Arc,
path::Path,
};

use tokio::{
self,
process::{ChildStdin, Command},
sync::Mutex,
spawn
};

const NVIM_BIN: &str = if cfg!(windows) {
"nvim.exe"
} else {
"nvim"
};
const NVIM_PATH: &str = if cfg!(windows) {
"neovim/build/bin/nvim.exe"
} else {
"neovim/build/bin/nvim"
};

#[derive(Clone)]
struct NeovimHandler {
froodle: Arc<Mutex<String>>,
}

#[async_trait]
impl Handler for NeovimHandler {
type Writer = Compat<ChildStdin>;

async fn handle_request(
&self,
name: String,
args: Vec<Value>,
neovim: Neovim<Compat<ChildStdin>>,
) -> Result<Value, Value> {
match name.as_ref() {
"dummy" => Ok(Value::from("o")),
"req" => {
let v = args[0].as_str().unwrap();

let neovim = neovim.clone();
match v {
"y" => {
let mut x: String = neovim
.get_vvar("progname")
.await
.unwrap()
.as_str()
.unwrap()
.into();
x.push_str(" - ");
x.push_str(
neovim.get_var("oogle").await.unwrap().as_str().unwrap(),
);
x.push_str(" - ");
x.push_str(
neovim
.eval("rpcrequest(1,'dummy')")
.await
.unwrap()
.as_str()
.unwrap(),
);
x.push_str(" - ");
x.push_str(
neovim
.eval("rpcrequest(1,'req', 'z')")
.await
.unwrap()
.as_str()
.unwrap(),
);
Ok(Value::from(x))
}
"z" => {
let x: String = neovim
.get_vvar("progname")
.await
.unwrap()
.as_str()
.unwrap()
.into();
Ok(Value::from(x))
}
&_ => Err(Value::from("wrong argument to req")),
}
}
&_ => Err(Value::from("wrong method name for request")),
}
}

async fn handle_notify(
&self,
name: String,
args: Vec<Value>,
_neovim: Neovim<Compat<ChildStdin>>,
) {
match name.as_ref() {
"set_froodle" => {
*self.froodle.lock().await = args[0].as_str().unwrap().to_string()
}
_ => {}
};
}
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
let rs = r#"exe ":fun M(timer)
call rpcnotify(1, 'set_froodle', rpcrequest(1, 'req', 'y'))
endfun""#;
let rs2 = r#"exe ":fun N(timer)
call chanclose(1)
endfun""#;

let froodle = Arc::new(Mutex::new(String::new()));
let handler = NeovimHandler {
froodle: froodle.clone(),
};

let path = if Path::new(NVIM_PATH).exists() {
NVIM_PATH
} else {
NVIM_BIN
};
let (nvim, io, _child) = create::new_child_cmd(
Command::new(path).args(&[
"-u",
"NONE",
"--embed",
"--headless",
"-c",
rs,
"-c",
":let timer = timer_start(500, 'M')",
"-c",
rs2,
"-c",
":let timer = timer_start(1500, 'N')",
]),
handler,
)
.await
.unwrap();

let nv = nvim.clone();
spawn(async move { nv.set_var("oogle", Value::from("doodle")).await });

// The 2nd timer closes the channel, which will be returned as an error from
// the io handler. We only fail the test if we got another error
if let Err(err) = io.await.unwrap() {
if !err.is_channel_closed() {
panic!("Error in io: '{:?}'", err);
}
}

assert_eq!(
format!("{nvim} - doodle - o - {nvim}", nvim = NVIM_BIN),
*froodle.lock().await
);
}
120 changes: 75 additions & 45 deletions src/neovim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,15 @@ use std::{
};

use futures::{
channel::oneshot,
channel::{
mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
oneshot,
},
io::{AsyncRead, AsyncWrite, BufWriter},
lock::Mutex,
sink::SinkExt,
stream::StreamExt,
future, TryFutureExt,
};

use crate::{
Expand Down Expand Up @@ -100,8 +106,12 @@ where
queue: Arc::new(Mutex::new(Vec::new())),
};

let req_t = req.clone();
let fut = Self::io_loop(handler, reader, req_t);
let (sender, receiver) = unbounded();
let fut = future::try_join(
req.clone().io_loop(reader, sender),
req.clone().handler_loop(handler, receiver)
)
.map_ok(|_| ());

(req, fut)
}
Expand Down Expand Up @@ -180,39 +190,39 @@ where
}
}

async fn io_loop<H, R>(
async fn handler_loop<H>(
self,
handler: H,
mut reader: R,
neovim: Neovim<H::Writer>,
mut receiver: UnboundedReceiver<RpcMessage>,
) -> Result<(), Box<LoopError>>
where
H: Handler + Spawner,
R: AsyncRead + Send + Unpin + 'static,
H: Handler<Writer = W> + Spawner,
{
let mut rest: Vec<u8> = vec![];

loop {
let msg = match model::decode(&mut reader, &mut rest).await {
Ok(msg) => msg,
Err(err) => {
let e = neovim.send_error_to_callers(&neovim.queue, *err).await?;
return Err(Box::new(LoopError::DecodeError(e, None)));
}
let msg = match receiver.next().await {
Some(msg) => msg,
/* If our receiver closes, that just means that io_handler started
* shutting down. This is normal, so shut down along with it and don't
* report an error
*/
None => break Ok(()),
};

debug!("Get message {:?}", msg);
match msg {
RpcMessage::RpcRequest {
msgid,
method,
params,
} => {
let neovim = neovim.clone();
let handler_c = handler.clone();
let neovim = self.clone();
let writer = self.writer.clone();

handler.spawn(async move {
let neovim_t = neovim.clone();
let response =
match handler_c.handle_request(method, params, neovim_t).await {
let response = match handler_c
.handle_request(method, params, neovim)
.await
{
Ok(result) => RpcMessage::RpcResponse {
msgid,
result,
Expand All @@ -225,37 +235,57 @@ where
},
};

model::encode(neovim.writer, response)
model::encode(writer, response)
.await
.unwrap_or_else(|e| {
error!("Error sending response to request {}: '{}'", msgid, e);
});
});
}
RpcMessage::RpcResponse {
msgid,
result,
error,
} => {
let sender = find_sender(&neovim.queue, msgid).await?;
if error == Value::Nil {
sender
.send(Ok(Ok(result)))
.map_err(|r| (msgid, r.expect("This was an OK(_)")))?;
} else {
sender
.send(Ok(Err(error)))
.map_err(|r| (msgid, r.expect("This was an OK(_)")))?;
}
}
RpcMessage::RpcNotification { method, params } => {
let handler_c = handler.clone();
let neovim = neovim.clone();
handler.spawn(async move {
handler_c.handle_notify(method, params, neovim).await
});
},
RpcMessage::RpcNotification {
method,
params
} => handler.handle_notify(method, params, self.clone()).await,
_ => unreachable!(),
}
}
}

async fn io_loop<R>(
self,
mut reader: R,
mut sender: UnboundedSender<RpcMessage>,
) -> Result<(), Box<LoopError>>
where
R: AsyncRead + Send + Unpin + 'static,
{
let mut rest: Vec<u8> = vec![];

loop {
let msg = match model::decode(&mut reader, &mut rest).await {
Ok(msg) => msg,
Err(err) => {
let e = self.send_error_to_callers(&self.queue, *err).await?;
return Err(Box::new(LoopError::DecodeError(e, None)));
}
};

debug!("Get message {:?}", msg);
if let RpcMessage::RpcResponse { msgid, result, error, } = msg {
let sender = find_sender(&self.queue, msgid).await?;
if error == Value::Nil {
sender
.send(Ok(Ok(result)))
.map_err(|r| (msgid, r.expect("This was an OK(_)")))?;
} else {
sender
.send(Ok(Err(error)))
.map_err(|r| (msgid, r.expect("This was an OK(_)")))?;
}
} else {
// Send message to handler_loop()
sender.send(msg).await.unwrap();
}
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/rpc/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ pub trait Handler: Send + Sync + Clone + 'static {
Err(Value::from("Not implemented"))
}

/// Handling an rpc notification.
/// Handling an rpc notification. Notifications are handled one at a time in
/// the order in which they were received, and will block new requests from
/// being received until handle_notify returns.
async fn handle_notify(
&self,
_name: String,
Expand Down
Loading