1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved
5
+
6
+ Licensed under the Apache License, Version 2.0 (the "License");
7
+ you may not use this file except in compliance with the License.
8
+ You may obtain a copy of the License at
9
+
10
+ http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ Unless required by applicable law or agreed to in writing, software
13
+ distributed under the License is distributed on an "AS IS" BASIS,
14
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ See the License for the specific language governing permissions and
16
+ limitations under the License.
17
+ """
18
+
19
+ # stdlib
20
+ import socket
21
+ import ssl
22
+ import threading
23
+ import time
24
+ import unittest
25
+
26
+ from SocketServer import StreamRequestHandler
27
+ from xmlrpclib import Transport
28
+
29
+ # Spring Python
30
+ from springpython .remoting .xmlrpc import SSLServer , SSLClient , RequestHandler , \
31
+ SSLClientTransport , VerificationException
32
+
33
+ RESULT_OK = "All good"
34
+
35
+ server_key = "./support/pki/server-key.pem"
36
+ server_cert = "./support/pki/server-cert.pem"
37
+ client_key = "./support/pki/client-key.pem"
38
+ client_cert = "./support/pki/client-cert.pem"
39
+ ca_certs = "./support/pki/ca-chain.pem"
40
+
41
+ class MySSLServer (SSLServer ):
42
+
43
+ def test_server (self ):
44
+ return RESULT_OK
45
+
46
+ def register_functions (self ):
47
+ self .register_function (self .shutdown )
48
+ self .register_function (self .test_server )
49
+
50
+ class _DummyServer (SSLServer ):
51
+ pass
52
+
53
+ class _DummyRequest ():
54
+ def recv (self , * ignored_args , ** ignored_kwargs ):
55
+ pass
56
+
57
+ class _MyClientTransport (object ):
58
+ def __init__ (self , ca_certs = None , keyfile = None , certfile = None , cert_reqs = None ,
59
+ ssl_version = None , timeout = None , strict = None ):
60
+ self .ca_certs = ca_certs
61
+ self .keyfile = keyfile
62
+ self .certfile = certfile
63
+ self .cert_reqs = cert_reqs
64
+ self .ssl_version = ssl_version
65
+ self .timeout = timeout
66
+ self .strict = strict
67
+
68
+ class TestInitDefaultArguments (unittest .TestCase ):
69
+ def test_init_default_arguments (self ):
70
+ """ Tests various defaults various and those passed to __init__'s.
71
+ """
72
+
73
+ self .assertTrue (issubclass (VerificationException , Exception ))
74
+ self .assertEqual (RequestHandler .rpc_paths , ("/" , "/RPC2" ))
75
+ self .assertEqual (SSLClientTransport .user_agent ,
76
+ "SSL XML-RPC Client (by http://springpython.webfactional.com)" )
77
+
78
+ server1 = MySSLServer ("127.0.0.1" , 8001 )
79
+
80
+ self .assertEqual (server1 .keyfile , None )
81
+ self .assertEqual (server1 .certfile , None )
82
+ self .assertEqual (server1 .ca_certs , None )
83
+ self .assertEqual (server1 .cert_reqs , ssl .CERT_NONE )
84
+ self .assertEqual (server1 .ssl_version , ssl .PROTOCOL_TLSv1 )
85
+ self .assertEqual (server1 .do_handshake_on_connect , True )
86
+ self .assertEqual (server1 .suppress_ragged_eofs , True )
87
+ self .assertEqual (server1 .ciphers , None )
88
+ self .assertEqual (server1 .verify_fields , None )
89
+
90
+ server_host = "127.0.0.1"
91
+ server_port = 8002
92
+ server_keyfile = "server_keyfile"
93
+ server_certfile = "server_certfile"
94
+ server_ca_certs = "server_ca_certs"
95
+ server_cert_reqs = ssl .CERT_OPTIONAL
96
+ server_ssl_version = ssl .PROTOCOL_SSLv3
97
+ server_do_handshake_on_connect = False
98
+ server_suppress_ragged_eofs = False
99
+ server_ciphers = "ALL"
100
+ server_verify_fields = {"commonName" : "Foo" , "organizationName" :"Baz" }
101
+
102
+ server2 = MySSLServer (server_host , server_port , server_keyfile ,
103
+ server_certfile , server_ca_certs , server_cert_reqs ,
104
+ server_ssl_version , server_do_handshake_on_connect ,
105
+ server_suppress_ragged_eofs , server_ciphers ,
106
+ verify_fields = server_verify_fields )
107
+
108
+ # inherited from SocketServer.BaseServer
109
+ self .assertEqual (server2 .server_address , (server_host , server_port ))
110
+
111
+ self .assertEqual (server2 .keyfile , server_keyfile )
112
+ self .assertEqual (server2 .certfile , server_certfile )
113
+ self .assertEqual (server2 .ca_certs , server_ca_certs )
114
+ self .assertEqual (server2 .cert_reqs , server_cert_reqs )
115
+ self .assertEqual (server2 .ssl_version , server_ssl_version )
116
+ self .assertEqual (server2 .do_handshake_on_connect , server_do_handshake_on_connect )
117
+ self .assertEqual (server2 .suppress_ragged_eofs , server_suppress_ragged_eofs )
118
+ self .assertEqual (server2 .ciphers , server_ciphers )
119
+ self .assertEqual (sorted (server2 .verify_fields ), sorted (server_verify_fields ))
120
+
121
+ client_uri = "https://127.0.0.1:8000/RPC2"
122
+ client_ca_certs = "client_ca_certs"
123
+ client_keyfile = "client_keyfile"
124
+ client_certfile = "client_certfile"
125
+ client_cert_reqs = ssl .CERT_OPTIONAL
126
+ client_ssl_version = ssl .PROTOCOL_SSLv23
127
+ client_transport = _MyClientTransport
128
+ client_encoding = "utf-16"
129
+ client_verbose = 1
130
+ client_allow_none = False
131
+ client_use_datetime = False
132
+ client_timeout = 13
133
+ client_strict = True
134
+
135
+ client2 = SSLClient (client_uri , client_ca_certs , client_keyfile ,
136
+ client_certfile , client_cert_reqs , client_ssl_version ,
137
+ client_transport , client_encoding , client_verbose ,
138
+ client_allow_none , client_use_datetime , client_timeout ,
139
+ client_strict )
140
+
141
+ self .assertEqual (client2 ._ServerProxy__host , "127.0.0.1:8000" )
142
+ self .assertEqual (client2 ._ServerProxy__transport .ca_certs , client_ca_certs )
143
+ self .assertEqual (client2 ._ServerProxy__transport .keyfile , client_keyfile )
144
+ self .assertEqual (client2 ._ServerProxy__transport .certfile , client_certfile )
145
+ self .assertEqual (client2 ._ServerProxy__transport .cert_reqs , client_cert_reqs )
146
+ self .assertEqual (client2 ._ServerProxy__transport .ssl_version , client_ssl_version )
147
+ self .assertTrue (isinstance (client2 ._ServerProxy__transport , _MyClientTransport ))
148
+ self .assertEqual (client2 ._ServerProxy__encoding , client_encoding )
149
+ self .assertEqual (client2 ._ServerProxy__verbose , client_verbose )
150
+ self .assertEqual (client2 ._ServerProxy__allow_none , client_allow_none )
151
+ self .assertEqual (client2 ._ServerProxy__transport .timeout , client_timeout )
152
+ self .assertEqual (client2 ._ServerProxy__transport .strict , client_strict )
153
+
154
+ self .assertRaises (NotImplementedError , _DummyServer , "127.0.0.1" , 8003 )
155
+
156
+ def test_request_handler (self ):
157
+ request = _DummyRequest ()
158
+ rh = RequestHandler (request , None , None )
159
+ rh .setup ()
160
+ self .assertTrue (rh .connection is request )
161
+ self .assertTrue (isinstance (rh .rfile , socket ._fileobject ))
162
+ self .assertTrue (isinstance (rh .wfile , socket ._fileobject ))
163
+ self .assertTrue (rh .rfile ._sock is request )
164
+ self .assertEqual (rh .rfile .mode , "rb" )
165
+ self .assertEqual (rh .rfile .bufsize , socket ._fileobject .default_bufsize )
166
+ self .assertTrue (rh .wfile ._sock is request )
167
+ self .assertEqual (rh .wfile .mode , "wb" )
168
+ self .assertEqual (rh .wfile .bufsize , StreamRequestHandler .wbufsize )
169
+
170
+ def xtest_imports (self ):
171
+ raise NotImplemented ()
172
+
173
+ class TestSSL (unittest .TestCase ):
174
+
175
+ class _ClientServerContextManager (object ):
176
+ def __init__ (self , server_port , cert_reqs = ssl .CERT_NONE , verify_fields = {}):
177
+ self .server_port = server_port
178
+ self .cert_reqs = cert_reqs
179
+ self .verify_fields = verify_fields
180
+
181
+ def __enter__ (self ):
182
+ server = MySSLServer ("127.0.0.1" , self .server_port , server_key ,
183
+ server_cert , ca_certs , cert_reqs = self .cert_reqs ,
184
+ verify_fields = self .verify_fields )
185
+ self .server_thread = self ._start_server (server )
186
+ time .sleep (0.5 )
187
+
188
+ def __exit__ (self , * ignored_args ):
189
+ self .server_thread .server .shutdown ()
190
+
191
+ def _start_server (self , server ):
192
+
193
+ class _ServerController (threading .Thread ):
194
+ def __init__ (self , server ):
195
+ self .server = server
196
+ self .isDaemon = False
197
+ super (_ServerController , self ).__init__ ()
198
+
199
+ def run (self ):
200
+ self .server .serve_forever ()
201
+
202
+ server_thread = _ServerController (server )
203
+ server_thread .start ()
204
+
205
+ return server_thread
206
+
207
+
208
+ def xtest_simple_ssl (self ):
209
+ """ Server uses its cert, client uses none.
210
+ """
211
+ server_port = 9001
212
+ with TestSSL ._ClientServerContextManager (server_port ):
213
+ client = SSLClient ("https://localhost:%d/RPC2" % server_port , ca_certs )
214
+ self .assertEqual (client .test_server (), RESULT_OK )
215
+
216
+ def xtest_client_cert (self ):
217
+ """ Server & client use certs.
218
+ """
219
+ server_port = 9002
220
+ with TestSSL ._ClientServerContextManager (server_port , ssl .CERT_REQUIRED ):
221
+ client = SSLClient ("https://localhost:%d/RPC2" % server_port , ca_certs ,
222
+ client_key , client_cert )
223
+ self .assertEqual (client .test_server (), RESULT_OK )
224
+
225
+ def xtest_client_cert_ok (self ):
226
+ """ Server & client use certs. Server succesfully validates client certificate's fields.
227
+ """
228
+ server_port = 9003
229
+ verify_fields = {"commonName" :"My Client" , "countryName" :"US" ,
230
+ "organizationalUnitName" :"My Unit" , "organizationName" :"My Company" ,
231
+ "stateOrProvinceName" :"My State" }
232
+
233
+ with TestSSL ._ClientServerContextManager (server_port , ssl .CERT_REQUIRED , verify_fields ):
234
+ client = SSLClient ("https://localhost:%d/RPC2" % server_port , ca_certs ,
235
+ client_key , client_cert )
236
+ self .assertEqual (client .test_server (), RESULT_OK )
237
+
238
+ def xtest_client_cert_failure_missing_field (self ):
239
+ """ Server & client use certs. Server fails to validate client certificate's fields
240
+ (a field is missing).
241
+ """
242
+ server_port = 9004
243
+ verify_fields = {"commonName" :"My Client" , "countryName" :"US" ,
244
+ "organizationalUnitName" :"My Unit" , "organizationName" :"My Company" ,
245
+ "stateOrProvinceName" :"My State" , "FOO" : "BAR" }
246
+
247
+ with TestSSL ._ClientServerContextManager (server_port , ssl .CERT_REQUIRED , verify_fields ):
248
+ client = SSLClient ("https://localhost:%d/RPC2" % server_port , ca_certs ,
249
+ client_key , client_cert )
250
+ self .assertRaises (Exception , client .test_server )
251
+
252
+ def xtest_client_cert_failure_field_incorrect_value (self ):
253
+ """ Server & client use certs. Server fails to validate client certificate's fields
254
+ (all fields are in place, but commonName has an incorrect value).
255
+ """
256
+ server_port = 9005
257
+ verify_fields = {"commonName" :"Invalid" }
258
+ with TestSSL ._ClientServerContextManager (server_port , ssl .CERT_REQUIRED , verify_fields ):
259
+ client = SSLClient ("https://localhost:%d/RPC2" % server_port , ca_certs ,
260
+ client_key , client_cert )
261
+ self .assertRaises (Exception , client .test_server )
262
+
263
+ def test_client_cert_failure_no_client_cert (self ):
264
+ """ Server optionally requires a client to send the certificate
265
+ and validates its fields but client sends none.
266
+ """
267
+ server_port = 9006
268
+ verify_fields = {"commonName" :"My Client" }
269
+ with TestSSL ._ClientServerContextManager (server_port , ssl .CERT_OPTIONAL , verify_fields ):
270
+ client = SSLClient ("https://localhost:%d/RPC2" % server_port , ca_certs )
271
+ self .assertRaises (Exception , client .test_server )
272
+
273
+ if __name__ == "__main__" :
274
+ unittest .main ()
0 commit comments