Skip to content

Commit eb4b3fa

Browse files
authored
Fix selection of thread-local wasm when enqueuing (proxy-wasm#36)
Signed-off-by: Gregory Brail <gregbrail@google.com>
1 parent 1cb51cf commit eb4b3fa

File tree

1 file changed

+25
-13
lines changed

1 file changed

+25
-13
lines changed

src/context.cc

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class SharedData {
8787
}
8888

8989
uint32_t registerQueue(string_view vm_id, string_view queue_name, uint32_t context_id,
90-
CallOnThreadFunction call_on_thread) {
90+
CallOnThreadFunction call_on_thread, string_view vm_key) {
9191
std::lock_guard<std::mutex> lock(mutex_);
9292
auto key = std::make_pair(std::string(vm_id), std::string(queue_name));
9393
auto it = queue_tokens_.insert(std::make_pair(key, static_cast<uint32_t>(0)));
@@ -97,7 +97,7 @@ class SharedData {
9797
}
9898
uint32_t token = it.first->second;
9999
auto &q = queues_[token];
100-
q.vm_id = std::string(vm_id);
100+
q.vm_key = std::string(vm_key);
101101
q.context_id = context_id;
102102
q.call_on_thread = std::move(call_on_thread);
103103
// Preserve any existing data.
@@ -127,17 +127,29 @@ class SharedData {
127127
it->second.queue.pop_front();
128128
return WasmResult::Ok;
129129
}
130+
130131
WasmResult enqueue(uint32_t token, string_view value) {
131-
std::lock_guard<std::mutex> lock(mutex_);
132-
auto it = queues_.find(token);
133-
if (it == queues_.end()) {
134-
return WasmResult::NotFound;
132+
std::string vm_key;
133+
uint32_t context_id;
134+
CallOnThreadFunction call_on_thread;
135+
136+
{
137+
std::lock_guard<std::mutex> lock(mutex_);
138+
auto it = queues_.find(token);
139+
if (it == queues_.end()) {
140+
return WasmResult::NotFound;
141+
}
142+
Queue *target_queue = &(it->second);
143+
vm_key = target_queue->vm_key;
144+
context_id = target_queue->context_id;
145+
call_on_thread = target_queue->call_on_thread;
146+
target_queue->queue.push_back(std::string(value));
135147
}
136-
it->second.queue.push_back(std::string(value));
137-
auto vm_id = it->second.vm_id;
138-
auto context_id = it->second.context_id;
139-
it->second.call_on_thread([vm_id, context_id, token] {
140-
auto wasm = getThreadLocalWasm(vm_id);
148+
149+
call_on_thread([vm_key, context_id, token] {
150+
// This code may or may not execute in another thread.
151+
// Make sure that the lock is no longer held here.
152+
auto wasm = getThreadLocalWasm(vm_key);
141153
if (wasm) {
142154
auto context = wasm->wasm()->getContext(context_id);
143155
if (context) {
@@ -171,7 +183,7 @@ class SharedData {
171183
}
172184

173185
struct Queue {
174-
std::string vm_id;
186+
std::string vm_key;
175187
uint32_t context_id;
176188
CallOnThreadFunction call_on_thread;
177189
std::deque<std::string> queue;
@@ -341,7 +353,7 @@ WasmResult ContextBase::registerSharedQueue(string_view queue_name,
341353
// root.
342354
*result = global_shared_data.registerQueue(wasm_->vm_id(), queue_name,
343355
isRootContext() ? id_ : parent_context_id_,
344-
wasm_->callOnThreadFunction());
356+
wasm_->callOnThreadFunction(), wasm_->vm_key());
345357
return WasmResult::Ok;
346358
}
347359

0 commit comments

Comments
 (0)