Skip to content

Support for Unix Domain Socket #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
Masahiro Nakagawa <repeatedly _at_ gmail.com>
INADA Naoki <songofacandy _at_ gmail.com>
Harish Vishwanath <harish dot shastry at gmail dot com>
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,39 @@
<!--
[![Build Status](https://travis-ci.org/msgpack/msgpack-rpc-python.png)](https://travis-ci.org/msgpack/msgpack-rpc-python)
-->
# Unix Domain Socket support
Unix domain socket support is now available for msgpack-rpc. Sample examples below.

## UDS examples

### Server

```python
import msgpackrpc.udsaddress
from msgpackrpc.transport import uds
class SumServer(object):
def sum(self, x, y):
return x + y

# Use builder as uds. default builder is tcp which creates tcp sockets
server = msgpackrpc.Server(SumServer(), builder=uds)
# Use UDSAddress instead of msgpackrpc.Address
server.listen(msgpackrpc.udsaddress.UDSAddress('/tmp/exrpc'))
server.start()
```

### Client
```python
import msgpackrpc.udsaddress
from msgpackrpc.transport import uds

#Use UDSAddress instead of default Address object
client = msgpackrpc.Client(msgpackrpc.udsaddress.UDSAddress("/tmp/exrpc"), builder=uds)
result = client.call('sum', 1, 2) # = >
print "Sum of 1 and 2 : %d" % result
```

Go through the below sections for general usage of Message Pack RPC Library

# MessagePack RPC for Python

Expand Down
11 changes: 11 additions & 0 deletions example/uds_simpleclient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
'''
@author: hvishwanath | harish.shastry@gmail.com
'''

import msgpackrpc.udsaddress
from msgpackrpc.transport import uds

#Use UDSAddress instead of default Address object
client = msgpackrpc.Client(msgpackrpc.udsaddress.UDSAddress("/tmp/exrpc"), builder=uds)
result = client.call('sum', 1, 2) # = >
print "Sum of 1 and 2 : %d" % result
15 changes: 15 additions & 0 deletions example/uds_simpleserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
'''
@author: hvishwanath | harish.shastry@gmail.com
'''

import msgpackrpc.udsaddress
from msgpackrpc.transport import uds
class SumServer(object):
def sum(self, x, y):
return x + y

# Use builder as uds. default builder is tcp which creates tcp sockets
server = msgpackrpc.Server(SumServer(), builder=uds)
# Use UDSAddress instead of msgpackrpc.Address
server.listen(msgpackrpc.udsaddress.UDSAddress('/tmp/exrpc'))
server.start()
1 change: 1 addition & 0 deletions msgpackrpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from msgpackrpc.client import Client
from msgpackrpc.server import Server
from msgpackrpc.address import Address
from msgpackrpc.udsaddress import UDSAddress
53 changes: 53 additions & 0 deletions msgpackrpc/transport/uds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
'''
@author: hvishwanath | harish.shastry@gmail.com
'''

import msgpackrpc.transport
from tornado.netutil import bind_unix_socket
from tornado import tcpserver
from tornado.iostream import IOStream

# Much of the implementation will be same as that of tcp module
# Changes required for unix domain socket support are done in this module
# Rest will be automatically used from tcp

# Create namespace equals
BaseSocket = msgpackrpc.transport.tcp.BaseSocket
ClientSocket = msgpackrpc.transport.tcp.ClientSocket
ClientTransport = msgpackrpc.transport.tcp.ClientTransport

ServerSocket = msgpackrpc.transport.tcp.ServerSocket
ServerTransport = msgpackrpc.transport.tcp.ServerTransport


class UDSServer(tcpserver.TCPServer):
"""Define a Unix domain socket server.
Instead of binding to TCP/IP socket, bind to UDS socket and listen"""

def __init__(self, io_loop=None, ssl_options=None):
tcpserver.TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options)

def listen(self, port, address=""):
"""Bind to a unix domain socket and add to self.
Note that port in our case actually contains the uds file name"""

# Create a Unix domain socket and bind
socket = bind_unix_socket(port)

# Add to self
self.add_socket(socket)

class MessagePackServer(UDSServer):
"""The MessagePackServer inherits from UDSServer
instead of tornado's TCP Server"""

def __init__(self, transport, io_loop=None, encodings=None):
self._transport = transport
self._encodings = encodings
UDSServer.__init__(self, io_loop=io_loop)

def handle_stream(self, stream, address):
ServerSocket(stream, self._transport, self._encodings)

#Monkey patch the MessagePackServer
msgpackrpc.transport.tcp.MessagePackServer = MessagePackServer
40 changes: 40 additions & 0 deletions msgpackrpc/udsaddress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
'''
@author: hvishwanath | harish.shastry@gmail.com
'''

import socket
from tornado.platform.auto import set_close_exec

class UDSAddress(object):
"""This class abstracts Unix domain socket address.
For compatibility with other code in the library, port is always equal to host"""

def __init__(self, host, port=None):
self._host = host

# Passed value for port is ignored.
# Port is also made equal to host.
# This is because some of the code in transport.tcp uses address._port to connect.
# For a unix socket, there is no port. Hence if port = host, that code should work.
self._port = host

@property
def host(self):
return self._host

@property
def port(self):
return self._port

def unpack(self):
# Return only the host
return self._host

def socket(self, family=socket.AF_UNSPEC):
"""Return a Unix domain socket instead of tcp socket"""

sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
set_close_exec(sock.fileno())
sock.setblocking(0)

return sock
202 changes: 202 additions & 0 deletions test/test_uds_msgpackrpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
'''
@author: hvishwanath | harish.shastry@gmail.com
'''

from msgpackrpc.transport import uds
from time import sleep
import threading
try:
import unittest2 as unittest
except ImportError:
import unittest

import helper
import msgpackrpc
from msgpackrpc import error

class TestMessagePackRPC(unittest.TestCase):
ENABLE_TIMEOUT_TEST = False

class TestArg:
''' this class must know completely how to deserialize '''
def __init__(self, a, b, c):
self.a = a
self.b = b
self.c = c

def to_msgpack(self):
return (self.a, self.b, self.c)

def add(self, rhs):
self.a += rhs.a
self.b -= rhs.b
self.c *= rhs.c
return self

def __eq__(self, rhs):
return (self.a == rhs.a and self.b == rhs.b and self.c == rhs.c)

@staticmethod
def from_msgpack(arg):
return TestMessagePackRPC.TestArg(arg[0], arg[1], arg[2])

class TestServer(object):
def hello(self):
return "world"

def sum(self, x, y):
return x + y

def nil(self):
return None

def add_arg(self, arg0, arg1):
lhs = TestMessagePackRPC.TestArg.from_msgpack(arg0)
rhs = TestMessagePackRPC.TestArg.from_msgpack(arg1)
return lhs.add(rhs)

def raise_error(self):
raise Exception('error')

def long_exec(self):
sleep(3)
return 'finish!'

def async_result(self):
ar = msgpackrpc.server.AsyncResult()
def do_async():
sleep(2)
ar.set_result("You are async!")
threading.Thread(target=do_async).start()
return ar

def setUp(self):
# Create UDSAddress
self._address = msgpackrpc.UDSAddress('/tmp/unusedsocket')

def setup_env(self):
def _on_started():
self._server._loop.dettach_periodic_callback()
lock.release()
def _start_server(server):
server._loop.attach_periodic_callback(_on_started, 1)
server.start()
server.close()

# Use builder=uds
self._server = msgpackrpc.Server(TestMessagePackRPC.TestServer(), builder=uds)
self._server.listen(self._address)
self._thread = threading.Thread(target=_start_server, args=(self._server,))

lock = threading.Lock()
self._thread.start()
lock.acquire()
lock.acquire() # wait for the server to start

self._client = msgpackrpc.Client(self._address, unpack_encoding='utf-8')
return self._client;

def tearDown(self):
self._client.close();
self._server.stop();
self._thread.join();

def test_call(self):
client = self.setup_env();

result1 = client.call('hello')
result2 = client.call('sum', 1, 2)
result3 = client.call('nil')

self.assertEqual(result1, "world", "'hello' result is incorrect")
self.assertEqual(result2, 3, "'sum' result is incorrect")
self.assertIsNone(result3, "'nil' result is incorrect")

def test_call_userdefined_arg(self):
client = self.setup_env();

arg = TestMessagePackRPC.TestArg(0, 1, 2)
arg2 = TestMessagePackRPC.TestArg(23, 3, -23)

result1 = TestMessagePackRPC.TestArg.from_msgpack(client.call('add_arg', arg, arg2))
self.assertEqual(result1, arg.add(arg2))

result2 = TestMessagePackRPC.TestArg.from_msgpack(client.call('add_arg', arg2, arg))
self.assertEqual(result2, arg2.add(arg))

result3 = TestMessagePackRPC.TestArg.from_msgpack(client.call('add_arg', result1, result2))
self.assertEqual(result3, result1.add(result2))

def test_call_async(self):
client = self.setup_env();

future1 = client.call_async('hello')
future2 = client.call_async('sum', 1, 2)
future3 = client.call_async('nil')
future1.join()
future2.join()
future3.join()

self.assertEqual(future1.result, "world", "'hello' result is incorrect in call_async")
self.assertEqual(future2.result, 3, "'sum' result is incorrect in call_async")
self.assertIsNone(future3.result, "'nil' result is incorrect in call_async")

def test_notify(self):
client = self.setup_env();

result = True
try:
client.notify('hello')
client.notify('sum', 1, 2)
client.notify('nil')
except:
result = False

self.assertTrue(result)

def test_raise_error(self):
client = self.setup_env();
self.assertRaises(error.RPCError, lambda: client.call('raise_error'))

def test_unknown_method(self):
client = self.setup_env();
self.assertRaises(error.RPCError, lambda: client.call('unknown', True))
try:
client.call('unknown', True)
self.assertTrue(False)
except error.RPCError as e:
message = e.args[0]
self.assertEqual(message, "'unknown' method not found", "Error message mismatched")

def test_async_result(self):
client = self.setup_env();
self.assertEqual(client.call('async_result'), "You are async!")

def test_connect_failed(self):
client = self.setup_env();
port = helper.unused_port()
client = msgpackrpc.Client(msgpackrpc.Address('localhost', port), unpack_encoding='utf-8')
self.assertRaises(error.TransportError, lambda: client.call('hello'))

def test_timeout(self):
client = self.setup_env();

if self.__class__.ENABLE_TIMEOUT_TEST:
self.assertEqual(client.call('long_exec'), 'finish!', "'long_exec' result is incorrect")

client = msgpackrpc.Client(self._address, timeout=1, unpack_encoding='utf-8')
self.assertRaises(error.TimeoutError, lambda: client.call('long_exec'))
else:
print("Skip test_timeout")


if __name__ == '__main__':
import sys

try:
sys.argv.remove('--timeout-test')
TestMessagePackRPC.ENABLE_TIMEOUT_TEST = True
except:
pass

unittest.main()
Loading