Skip to content

Optimize and improve the proc_macro RPC interface for cross-thread execution #86816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
proc_macro: add an optimized CrossThread execution strategy, and a debug flag to use it

This new strategy supports avoiding waiting for a reply for noblock messages.
This strategy requires using a channel-like approach (similar to the previous
CrossThread1 approach).

This new CrossThread execution strategy takes a type parameter for the channel
to use, allowing rustc to use a more efficient channel which the proc_macro
crate could not declare as a dependency.
  • Loading branch information
mystor committed Jul 2, 2021
commit 6ce595a65e79e9f7fb2201b299bcb5bcc23c7e22
52 changes: 31 additions & 21 deletions compiler/rustc_expand/src/proc_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ use rustc_parse::parser::ForceCollect;
use rustc_span::def_id::CrateNum;
use rustc_span::{Span, DUMMY_SP};

const EXEC_STRATEGY: pm::bridge::server::SameThread = pm::bridge::server::SameThread;
/// Build the execution strategy used to run a proc-macro client, honoring the
/// `-Z proc-macro-cross-thread` debugging flag: when set, the macro runs on a
/// separate thread via `MaybeCrossThread`; otherwise it runs on this thread.
fn exec_strategy(ecx: &ExtCtxt<'_>) -> impl pm::bridge::server::ExecutionStrategy {
    // Read the debug flag once; it selects the cross-thread path at runtime.
    let cross_thread = ecx.sess.opts.debugging_opts.proc_macro_cross_thread;
    <pm::bridge::server::MaybeCrossThread<pm::bridge::server::StdMessagePipe<_>>>::new(cross_thread)
}

pub struct BangProcMacro {
pub client: pm::bridge::client::Client<fn(pm::TokenStream) -> pm::TokenStream>,
Expand All @@ -27,14 +31,16 @@ impl base::ProcMacro for BangProcMacro {
input: TokenStream,
) -> Result<TokenStream, ErrorReported> {
let server = proc_macro_server::Rustc::new(ecx, self.krate);
self.client.run(&EXEC_STRATEGY, server, input, ecx.ecfg.proc_macro_backtrace).map_err(|e| {
let mut err = ecx.struct_span_err(span, "proc macro panicked");
if let Some(s) = e.as_str() {
err.help(&format!("message: {}", s));
}
err.emit();
ErrorReported
})
self.client.run(&exec_strategy(ecx), server, input, ecx.ecfg.proc_macro_backtrace).map_err(
|e| {
let mut err = ecx.struct_span_err(span, "proc macro panicked");
if let Some(s) = e.as_str() {
err.help(&format!("message: {}", s));
}
err.emit();
ErrorReported
},
)
}
}

Expand All @@ -53,7 +59,7 @@ impl base::AttrProcMacro for AttrProcMacro {
) -> Result<TokenStream, ErrorReported> {
let server = proc_macro_server::Rustc::new(ecx, self.krate);
self.client
.run(&EXEC_STRATEGY, server, annotation, annotated, ecx.ecfg.proc_macro_backtrace)
.run(&exec_strategy(ecx), server, annotation, annotated, ecx.ecfg.proc_macro_backtrace)
.map_err(|e| {
let mut err = ecx.struct_span_err(span, "custom attribute panicked");
if let Some(s) = e.as_str() {
Expand Down Expand Up @@ -102,18 +108,22 @@ impl MultiItemModifier for ProcMacroDerive {
};

let server = proc_macro_server::Rustc::new(ecx, self.krate);
let stream =
match self.client.run(&EXEC_STRATEGY, server, input, ecx.ecfg.proc_macro_backtrace) {
Ok(stream) => stream,
Err(e) => {
let mut err = ecx.struct_span_err(span, "proc-macro derive panicked");
if let Some(s) = e.as_str() {
err.help(&format!("message: {}", s));
}
err.emit();
return ExpandResult::Ready(vec![]);
let stream = match self.client.run(
&exec_strategy(ecx),
server,
input,
ecx.ecfg.proc_macro_backtrace,
) {
Ok(stream) => stream,
Err(e) => {
let mut err = ecx.struct_span_err(span, "proc-macro derive panicked");
if let Some(s) = e.as_str() {
err.help(&format!("message: {}", s));
}
};
err.emit();
return ExpandResult::Ready(vec![]);
}
};

let error_count_before = ecx.sess.parse_sess.span_diagnostic.err_count();
let mut parser =
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_session/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,8 @@ options! {
"print layout information for each type encountered (default: no)"),
proc_macro_backtrace: bool = (false, parse_bool, [UNTRACKED],
"show backtraces for panics during proc-macro execution (default: no)"),
proc_macro_cross_thread: bool = (false, parse_bool, [UNTRACKED],
"run proc-macro code on a separate thread (default: no)"),
profile: bool = (false, parse_bool, [TRACKED],
"insert profiling code (default: no)"),
profile_closures: bool = (false, parse_no_flag, [UNTRACKED],
Expand Down
6 changes: 5 additions & 1 deletion library/proc_macro/src/bridge/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,11 @@ macro_rules! client_send_impl {

b = bridge.dispatch.call(b);

let r = Result::<(), PanicMessage>::decode(&mut &b[..], &mut ());
let r = if b.len() > 0 {
Result::<(), PanicMessage>::decode(&mut &b[..], &mut ())
} else {
Ok(())
};

bridge.cached_buffer = b;

Expand Down
169 changes: 103 additions & 66 deletions library/proc_macro/src/bridge/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

use super::*;

use std::marker::PhantomData;

// FIXME(eddyb) generate the definition of `HandleStore` in `server.rs`.
use super::client::HandleStore;

Expand Down Expand Up @@ -174,6 +176,50 @@ pub trait ExecutionStrategy {
) -> Buffer<u8>;
}

/// An `ExecutionStrategy` that decides at runtime whether to run the client on
/// the current thread or on a dedicated thread, based on a flag supplied at
/// construction time. `P` is the message-pipe type used for the cross-thread
/// case; it is not stored, only carried in the type.
pub struct MaybeCrossThread<P> {
    // Zero-sized marker tying this strategy to the pipe type `P`.
    marker: PhantomData<P>,
    // When `true`, delegate to the cross-thread strategy; otherwise same-thread.
    cross_thread: bool,
}

impl<P> MaybeCrossThread<P> {
    /// Create a new strategy. `cross_thread` selects cross-thread execution.
    pub const fn new(cross_thread: bool) -> Self {
        Self { marker: PhantomData, cross_thread }
    }
}

impl<P> ExecutionStrategy for MaybeCrossThread<P>
where
    P: MessagePipe<Buffer<u8>> + Send + 'static,
{
    /// Forward the run to the strategy selected at construction time: the
    /// same-thread strategy when `cross_thread` is unset, otherwise a
    /// freshly-created `CrossThread<P>` strategy. All arguments are passed
    /// through unchanged.
    fn run_bridge_and_client<D: Copy + Send + 'static>(
        &self,
        dispatcher: &mut impl DispatcherTrait,
        input: Buffer<u8>,
        run_client: extern "C" fn(BridgeConfig<'_>, D) -> Buffer<u8>,
        client_data: D,
        force_show_panics: bool,
    ) -> Buffer<u8> {
        // Fast path: run on the invoking thread when the flag is off.
        if !self.cross_thread {
            return SameThread.run_bridge_and_client(
                dispatcher,
                input,
                run_client,
                client_data,
                force_show_panics,
            );
        }
        <CrossThread<P>>::new().run_bridge_and_client(
            dispatcher,
            input,
            run_client,
            client_data,
            force_show_panics,
        )
    }
}

/// Execution strategy that runs the proc-macro client directly on the
/// invoking (server) thread, with no pipe or extra thread involved.
#[derive(Default)]
pub struct SameThread;

impl ExecutionStrategy for SameThread {
Expand All @@ -194,12 +240,18 @@ impl ExecutionStrategy for SameThread {
}
}

// NOTE(eddyb) Two implementations are provided, the second one is a bit
// faster but neither is anywhere near as fast as same-thread execution.
/// Execution strategy that runs the proc-macro client on a spawned thread,
/// exchanging request/response buffers with the server over a pipe of type `P`.
pub struct CrossThread<P>(PhantomData<P>);

pub struct CrossThread1;
impl<P> CrossThread<P> {
pub const fn new() -> Self {
CrossThread(PhantomData)
}
}

impl ExecutionStrategy for CrossThread1 {
impl<P> ExecutionStrategy for CrossThread<P>
where
P: MessagePipe<Buffer<u8>> + Send + 'static,
{
fn run_bridge_and_client<D: Copy + Send + 'static>(
&self,
dispatcher: &mut impl DispatcherTrait,
Expand All @@ -208,15 +260,18 @@ impl ExecutionStrategy for CrossThread1 {
client_data: D,
force_show_panics: bool,
) -> Buffer<u8> {
use std::sync::mpsc::channel;

let (req_tx, req_rx) = channel();
let (res_tx, res_rx) = channel();
let (mut server, mut client) = P::new();

let join_handle = thread::spawn(move || {
let mut dispatch = |b| {
req_tx.send(b).unwrap();
res_rx.recv().unwrap()
let mut dispatch = |b: Buffer<u8>| -> Buffer<u8> {
let method_tag = api_tags::Method::decode(&mut &b[..], &mut ());
client.send(b);

if method_tag.should_wait() {
client.recv().expect("server died while client waiting for reply")
} else {
Buffer::new()
}
};

run_client(
Expand All @@ -225,73 +280,55 @@ impl ExecutionStrategy for CrossThread1 {
)
});

for b in req_rx {
res_tx.send(dispatcher.dispatch(b)).unwrap();
while let Some(b) = server.recv() {
let method_tag = api_tags::Method::decode(&mut &b[..], &mut ());
let b = dispatcher.dispatch(b);

if method_tag.should_wait() {
server.send(b);
} else if let Err(err) = <Result<(), PanicMessage>>::decode(&mut &b[..], &mut ()) {
panic::resume_unwind(err.into());
}
}

join_handle.join().unwrap()
}
}

pub struct CrossThread2;

impl ExecutionStrategy for CrossThread2 {
fn run_bridge_and_client<D: Copy + Send + 'static>(
&self,
dispatcher: &mut impl DispatcherTrait,
input: Buffer<u8>,
run_client: extern "C" fn(BridgeConfig<'_>, D) -> Buffer<u8>,
client_data: D,
force_show_panics: bool,
) -> Buffer<u8> {
use std::sync::{Arc, Mutex};

enum State<T> {
Req(T),
Res(T),
}

let mut state = Arc::new(Mutex::new(State::Res(Buffer::new())));
/// A message pipe used for communicating between server and client threads.
pub trait MessagePipe<T>: Sized {
/// Create a new pair of endpoints for the message pipe.
fn new() -> (Self, Self);

let server_thread = thread::current();
let state2 = state.clone();
let join_handle = thread::spawn(move || {
let mut dispatch = |b| {
*state2.lock().unwrap() = State::Req(b);
server_thread.unpark();
loop {
thread::park();
if let State::Res(b) = &mut *state2.lock().unwrap() {
break b.take();
}
}
};
/// Send a message to the other endpoint of this pipe.
fn send(&mut self, value: T);

let r = run_client(
BridgeConfig { input, dispatch: (&mut dispatch).into(), force_show_panics },
client_data,
);
/// Receive a message from the other endpoint of this pipe.
///
/// Returns `None` if the other end of the pipe has been destroyed, and no
/// message was received.
fn recv(&mut self) -> Option<T>;
}

// Wake up the server so it can exit the dispatch loop.
drop(state2);
server_thread.unpark();
/// Implementation of `MessagePipe` using `std::sync::mpsc`
pub struct StdMessagePipe<T> {
    // Sending half: messages sent here arrive at the paired endpoint's `rx`.
    tx: std::sync::mpsc::Sender<T>,
    // Receiving half: yields messages sent from the paired endpoint's `tx`.
    rx: std::sync::mpsc::Receiver<T>,
}

r
});
impl<T> MessagePipe<T> for StdMessagePipe<T> {
    // Create both endpoints from two mpsc channels, cross-wired so that each
    // endpoint's `tx` feeds the other endpoint's `rx`.
    fn new() -> (Self, Self) {
        let (tx1, rx1) = std::sync::mpsc::channel();
        let (tx2, rx2) = std::sync::mpsc::channel();
        (StdMessagePipe { tx: tx1, rx: rx2 }, StdMessagePipe { tx: tx2, rx: rx1 })
    }

// Check whether `state2` was dropped, to know when to stop.
while Arc::get_mut(&mut state).is_none() {
thread::park();
let mut b = match &mut *state.lock().unwrap() {
State::Req(b) => b.take(),
_ => continue,
};
b = dispatcher.dispatch(b.take());
*state.lock().unwrap() = State::Res(b);
join_handle.thread().unpark();
}
    // Send a value to the paired endpoint. Panics if the receiving end has
    // been dropped (mpsc `send` errors only in that case).
    fn send(&mut self, v: T) {
        self.tx.send(v).unwrap();
    }

join_handle.join().unwrap()
    // Receive the next value, blocking until one arrives; returns `None` once
    // the paired sending endpoint has been dropped and the channel is drained.
    fn recv(&mut self) -> Option<T> {
        self.rx.recv().ok()
    }
}

Expand Down