From 7c3496a76acfad992ebd7f1af1c2b22174b79530 Mon Sep 17 00:00:00 2001 From: Richard Newton Date: Thu, 12 Sep 2013 18:09:37 +0100 Subject: [PATCH] Fix race condition and support multiple socket connects before bind. --- .gitignore | 1 + src/command.hpp | 1 + src/ctx.cpp | 43 ++++-- src/ctx.hpp | 4 +- src/object.cpp | 16 +- src/object.hpp | 3 +- src/socket_base.cpp | 46 +----- tests/Makefile.am | 4 +- tests/test_inproc_connect_before_bind.cpp | 180 +++++++++++++++++++++- 9 files changed, 232 insertions(+), 66 deletions(-) diff --git a/.gitignore b/.gitignore index 519d591d3f..d86376e2cb 100644 --- a/.gitignore +++ b/.gitignore @@ -64,6 +64,7 @@ tests/test_req_request_ids tests/test_req_strict tests/test_fork tests/test_conflate +tests/test_inproc_connect_before_bind tests/test_linger tests/test_security_null tests/test_security_null.opp diff --git a/src/command.hpp b/src/command.hpp index ef48850451..fc0181b44f 100644 --- a/src/command.hpp +++ b/src/command.hpp @@ -55,6 +55,7 @@ namespace zmq term_ack, reap, reaped, + inproc_connected, done } type; diff --git a/src/ctx.cpp b/src/ctx.cpp index 0dcbf93943..e6271a86dd 100644 --- a/src/ctx.cpp +++ b/src/ctx.cpp @@ -396,28 +396,49 @@ void zmq::ctx_t::pend_connection (const char *addr_, const pending_connection_t { endpoints_sync.lock (); - // Todo, use multimap to support multiple pending connections - pending_connections[addr_] = pending_connection_; + endpoints_t::iterator it = endpoints.find (addr_); + if (it == endpoints.end ()) + { + // Still no bind. + pending_connection_.socket->inc_seqnum (); + pending_connections.insert (pending_connections_t::value_type (std::string (addr_), pending_connection_)); + } + else + { + // Bind has happened in the mean time, connect directly + pending_connection_t copy = pending_connection_; + it->second.socket->inc_seqnum(); + copy.pipe->set_tid(it->second.socket->get_tid()); + command_t cmd; + cmd.type = command_t::bind; + cmd.args.bind.pipe = copy.pipe; + it->second.socket->process_command(cmd); + } endpoints_sync.unlock (); } -zmq::pending_connection_t zmq::ctx_t::next_pending_connection(const char *addr_) +void zmq::ctx_t::connect_pending (const char *addr_, zmq::socket_base_t *bind_socket_) { endpoints_sync.lock (); - pending_connections_t::iterator it = pending_connections.find (addr_); - if (it == pending_connections.end ()) { + std::pair pending = pending_connections.equal_range(addr_); - endpoints_sync.unlock (); - pending_connection_t empty = {NULL, NULL}; - return empty; + for (pending_connections_t::iterator p = pending.first; p != pending.second; ++p) + { + bind_socket_->inc_seqnum(); + p->second.pipe->set_tid(bind_socket_->get_tid()); + command_t cmd; + cmd.type = command_t::bind; + cmd.args.bind.pipe = p->second.pipe; + bind_socket_->process_command(cmd); + + bind_socket_->send_inproc_connected(p->second.socket); } - pending_connection_t pending_connection = it->second; - pending_connections.erase(it); + + pending_connections.erase(pending.first, pending.second); endpoints_sync.unlock (); - return pending_connection; } // The last used socket ID, or 0 if no socket was used so far. Note that this diff --git a/src/ctx.hpp b/src/ctx.hpp index 10a34dec66..1a5cc43061 100644 --- a/src/ctx.hpp +++ b/src/ctx.hpp @@ -109,7 +109,7 @@ namespace zmq void unregister_endpoints (zmq::socket_base_t *socket_); endpoint_t find_endpoint (const char *addr_); void pend_connection (const char *addr_, const pending_connection_t &pending_connection_); - pending_connection_t next_pending_connection (const char *addr_); + void connect_pending (const char *addr_, zmq::socket_base_t *bind_socket_); enum { term_tid = 0, @@ -166,7 +166,7 @@ namespace zmq endpoints_t endpoints; // List of inproc connection endpoints pending a bind - typedef std::map pending_connections_t; + typedef std::multimap pending_connections_t; pending_connections_t pending_connections; // Synchronisation of access to the list of inproc endpoints. diff --git a/src/object.cpp b/src/object.cpp index fec067554a..ba20b8b099 100644 --- a/src/object.cpp +++ b/src/object.cpp @@ -127,6 +127,10 @@ void zmq::object_t::process_command (command_t &cmd_) process_reaped (); break; + case command_t::inproc_connected: + process_seqnum (); + break; + case command_t::done: default: zmq_assert (false); @@ -153,9 +157,9 @@ void zmq::object_t::pend_connection (const char *addr_, const pending_connection ctx->pend_connection (addr_, pending_connection_); } -zmq::pending_connection_t zmq::object_t::next_pending_connection (const char *addr_) +void zmq::object_t::connect_pending (const char *addr_, zmq::socket_base_t *bind_socket_) { - return ctx->next_pending_connection(addr_); + return ctx->connect_pending(addr_, bind_socket_); } void zmq::object_t::destroy_socket (socket_base_t *socket_) @@ -312,6 +316,14 @@ void zmq::object_t::send_reaped () send_command (cmd); } +void zmq::object_t::send_inproc_connected (zmq::socket_base_t *socket_) +{ + command_t cmd; + cmd.destination = socket_; + cmd.type = command_t::inproc_connected; + send_command (cmd); +} + void zmq::object_t::send_done () { command_t cmd; diff --git a/src/object.hpp b/src/object.hpp index 13851652f3..ab171e78aa 100644 --- a/src/object.hpp +++ b/src/object.hpp @@ -51,6 +51,7 @@ namespace zmq void set_tid(uint32_t id); ctx_t *get_ctx (); void process_command (zmq::command_t &cmd_); + void send_inproc_connected (zmq::socket_base_t *socket_); protected: @@ -60,7 +61,7 @@ namespace zmq void unregister_endpoints (zmq::socket_base_t *socket_); zmq::endpoint_t find_endpoint (const char *addr_); void pend_connection (const char *addr_, const pending_connection_t &pending_connection_); - zmq::pending_connection_t next_pending_connection (const char *addr_); + void connect_pending (const char *addr_, zmq::socket_base_t *bind_socket_); void destroy_socket (zmq::socket_base_t *socket_); diff --git a/src/socket_base.cpp b/src/socket_base.cpp index ebadd08089..8709b3c452 100644 --- a/src/socket_base.cpp +++ b/src/socket_base.cpp @@ -342,52 +342,8 @@ int zmq::socket_base_t::bind (const char *addr_) endpoint_t endpoint = {this, options}; int rc = register_endpoint (addr_, endpoint); if (rc == 0) { - // Save last endpoint URI + connect_pending(addr_, this); last_endpoint.assign (addr_); - - pending_connection_t pending_connection = next_pending_connection(addr_); - while (pending_connection.pipe != NULL) - { - inc_seqnum(); - //// If required, send the identity of the local socket to the peer. - //if (peer.options.recv_identity) { - // msg_t id; - // rc = id.init_size (options.identity_size); - // errno_assert (rc == 0); - // memcpy (id.data (), options.identity, options.identity_size); - // id.set_flags (msg_t::identity); - // bool written = new_pipes [0]->write (&id); - // zmq_assert (written); - // new_pipes [0]->flush (); - //} - - //// If required, send the identity of the peer to the local socket. - //if (options.recv_identity) { - // msg_t id; - // rc = id.init_size (peer.options.identity_size); - // errno_assert (rc == 0); - // memcpy (id.data (), peer.options.identity, peer.options.identity_size); - // id.set_flags (msg_t::identity); - // bool written = new_pipes [1]->write (&id); - // zmq_assert (written); - // new_pipes [1]->flush (); - //} - - //// Attach remote end of the pipe to the peer socket. Note that peer's - //// seqnum was incremented in find_endpoint function. We don't need it - //// increased here. - //send_bind (peer.socket, new_pipes [1], false); - - pending_connection.pipe->set_tid(get_tid()); - - command_t cmd; - cmd.type = command_t::bind; - cmd.args.bind.pipe = pending_connection.pipe; - process_command(cmd); - - - pending_connection = next_pending_connection(addr_); - } } return rc; } diff --git a/tests/Makefile.am b/tests/Makefile.am index cb4cfca221..7bcc09f9f9 100644 --- a/tests/Makefile.am +++ b/tests/Makefile.am @@ -36,7 +36,8 @@ noinst_PROGRAMS = test_system \ test_spec_pushpull \ test_req_request_ids \ test_req_strict \ - test_conflate + test_conflate \ + test_inproc_connect_before_bind if !ON_MINGW noinst_PROGRAMS += test_shutdown_stress \ @@ -80,6 +81,7 @@ test_spec_pushpull_SOURCES = test_spec_pushpull.cpp test_req_request_ids_SOURCES = test_req_request_ids.cpp test_req_strict_SOURCES = test_req_strict.cpp test_conflate_SOURCES = test_conflate.cpp +test_inproc_connect_before_bind_SOURCES = test_inproc_connect_before_bind.cpp if !ON_MINGW test_shutdown_stress_SOURCES = test_shutdown_stress.cpp test_pair_ipc_SOURCES = test_pair_ipc.cpp testutil.hpp diff --git a/tests/test_inproc_connect_before_bind.cpp b/tests/test_inproc_connect_before_bind.cpp index 04f78aac17..90004e8050 100644 --- a/tests/test_inproc_connect_before_bind.cpp +++ b/tests/test_inproc_connect_before_bind.cpp @@ -17,9 +17,27 @@ along with this program. If not, see . */ +#include "../include/zmq_utils.h" #include #include "testutil.hpp" +static void pusher (void *ctx) +{ + // Connect first + void *connectSocket = zmq_socket (ctx, ZMQ_PAIR); + assert (connectSocket); + int rc = zmq_connect (connectSocket, "inproc://a"); + assert (rc == 0); + + // Queue up some data + rc = zmq_send_const (connectSocket, "foobar", 6, 0); + assert (rc == 6); + + // Cleanup + rc = zmq_close (connectSocket); + assert (rc == 0); +} + void test_bind_before_connect() { void *ctx = zmq_ctx_new (); @@ -45,7 +63,7 @@ void test_bind_before_connect() zmq_msg_t msg; rc = zmq_msg_init (&msg); assert (rc == 0); - rc = zmq_msg_recv (&msg, bindSocket, ZMQ_NOBLOCK); + rc = zmq_msg_recv (&msg, bindSocket, 0); assert (rc == 6); void *data = zmq_msg_data (&msg); assert (memcmp ("foobar", data, 6) == 0); @@ -72,7 +90,6 @@ void test_connect_before_bind() int rc = zmq_connect (connectSocket, "inproc://a"); assert (rc == 0); - // Queue up some data rc = zmq_send_const (connectSocket, "foobar", 6, 0); assert (rc == 6); @@ -87,7 +104,7 @@ void test_connect_before_bind() zmq_msg_t msg; rc = zmq_msg_init (&msg); assert (rc == 0); - rc = zmq_msg_recv (&msg, bindSocket, ZMQ_NOBLOCK); + rc = zmq_msg_recv (&msg, bindSocket, 0); assert (rc == 6); void *data = zmq_msg_data (&msg); assert (memcmp ("foobar", data, 6) == 0); @@ -103,12 +120,167 @@ void test_connect_before_bind() assert (rc == 0); } +void test_connect_before_bind_pub_sub() +{ + void *ctx = zmq_ctx_new (); + assert (ctx); + + // Connect first + void *connectSocket = zmq_socket (ctx, ZMQ_PUB); + assert (connectSocket); + int rc = zmq_connect (connectSocket, "inproc://a"); + assert (rc == 0); + + // Queue up some data, this will be dropped + rc = zmq_send_const (connectSocket, "before", 6, 0); + assert (rc == 6); + + // Now bind + void *bindSocket = zmq_socket (ctx, ZMQ_SUB); + assert (bindSocket); + rc = zmq_setsockopt (bindSocket, ZMQ_SUBSCRIBE, "", 0); + assert (rc == 0); + rc = zmq_bind (bindSocket, "inproc://a"); + assert (rc == 0); + + // Wait for pub-sub connection to happen + zmq_sleep (1); + + // Queue up some data, this not will be dropped + rc = zmq_send_const (connectSocket, "after", 6, 0); + assert (rc == 6); + + // Read pending message + zmq_msg_t msg; + rc = zmq_msg_init (&msg); + assert (rc == 0); + rc = zmq_msg_recv (&msg, bindSocket, 0); + assert (rc == 6); + void *data = zmq_msg_data (&msg); + assert (memcmp ("after", data, 5) == 0); + + // Cleanup + rc = zmq_close (connectSocket); + assert (rc == 0); + + rc = zmq_close (bindSocket); + assert (rc == 0); + + rc = zmq_ctx_term (ctx); + assert (rc == 0); +} + +void test_multiple_connects() +{ + const unsigned int no_of_connects = 10; + void *ctx = zmq_ctx_new (); + assert (ctx); + + int rc; + void *connectSocket[no_of_connects]; + + // Connect first + for (unsigned int i = 0; i < no_of_connects; ++i) + { + connectSocket [i] = zmq_socket (ctx, ZMQ_PUSH); + assert (connectSocket [i]); + rc = zmq_connect (connectSocket [i], "inproc://a"); + assert (rc == 0); + + // Queue up some data + rc = zmq_send_const (connectSocket [i], "foobar", 6, 0); + assert (rc == 6); + } + + // Now bind + void *bindSocket = zmq_socket (ctx, ZMQ_PULL); + assert (bindSocket); + rc = zmq_bind (bindSocket, "inproc://a"); + assert (rc == 0); + + for (unsigned int i = 0; i < no_of_connects; ++i) + { + // Read pending message + zmq_msg_t msg; + rc = zmq_msg_init (&msg); + assert (rc == 0); + rc = zmq_msg_recv (&msg, bindSocket, 0); + assert (rc == 6); + void *data = zmq_msg_data (&msg); + assert (memcmp ("foobar", data, 6) == 0); + } + + // Cleanup + for (unsigned int i = 0; i < no_of_connects; ++i) + { + rc = zmq_close (connectSocket [i]); + assert (rc == 0); + } + + rc = zmq_close (bindSocket); + assert (rc == 0); + + rc = zmq_ctx_term (ctx); + assert (rc == 0); +} + +void test_multiple_threads() +{ + const unsigned int no_of_threads = 10; + void *ctx = zmq_ctx_new (); + assert (ctx); + + int rc; + void *threads [no_of_threads]; + + // Connect first + for (unsigned int i = 0; i < no_of_threads; ++i) + { + threads [i] = zmq_threadstart (&pusher, ctx); + } + + //zmq_sleep(1); + + // Now bind + void *bindSocket = zmq_socket (ctx, ZMQ_PULL); + assert (bindSocket); + rc = zmq_bind (bindSocket, "inproc://a"); + assert (rc == 0); + + for (unsigned int i = 0; i < no_of_threads; ++i) + { + // Read pending message + zmq_msg_t msg; + rc = zmq_msg_init (&msg); + assert (rc == 0); + rc = zmq_msg_recv (&msg, bindSocket, 0); + assert (rc == 6); + void *data = zmq_msg_data (&msg); + assert (memcmp ("foobar", data, 6) == 0); + } + + // Cleanup + for (unsigned int i = 0; i < no_of_threads; ++i) + { + zmq_threadclose (threads [i]); + } + + rc = zmq_close (bindSocket); + assert (rc == 0); + + rc = zmq_ctx_term (ctx); + assert (rc == 0); +} + int main (void) { setup_test_environment(); test_bind_before_connect(); test_connect_before_bind(); + test_connect_before_bind_pub_sub(); + test_multiple_connects(); + test_multiple_threads(); - return 0 ; + return 0; }