66
77import magic
88
9- from .compat import BaseHTTPRequestHandler , urlsplit , parse_qs , encode_to_bytes , decode_from_bytes
9+ from .compat import (
10+ BaseHTTPRequestHandler ,
11+ urlsplit ,
12+ parse_qs ,
13+ encode_to_bytes ,
14+ decode_from_bytes ,
15+ unquote_utf8 ,
16+ )
1017from .mocket import Mocket , MocketEntry
1118
1219
@@ -22,41 +29,57 @@ def __init__(self, data):
2229 self .error_code = self .error_message = None
2330 self .parse_request ()
2431 self .method = self .command
32+ self .querystring = parse_qs (unquote_utf8 (urlsplit (self .path ).query ), keep_blank_values = True )
33+
34+ def __str__ (self ):
35+ return "{} - {} - {}" .format (self .method , self .path , self .headers )
2536
2637 def __str__ (self ):
2738 return "{} - {} - {}" .format (self .method , self .path , self .headers )
2839
2940
3041class Response (object ):
42+ headers = None
43+ is_file_object = False
44+
3145 def __init__ (self , body = '' , status = 200 , headers = None ):
3246 headers = headers or {}
33- is_file_object = False
3447 try :
3548 # File Objects
3649 self .body = body .read ()
37- is_file_object = True
50+ self . is_file_object = True
3851 except AttributeError :
3952 self .body = encode_to_bytes (body )
4053 self .status = status
54+
55+ self .set_base_headers ()
56+
57+ if headers is not None :
58+ self .set_extra_headers (headers )
59+
60+ self .data = self .get_protocol_data () + self .body
61+
62+ def get_protocol_data (self ):
63+ status_line = 'HTTP/1.1 {status_code} {status}' .format (status_code = self .status , status = STATUS [self .status ])
64+ header_lines = CRLF .join (['{0}: {1}' .format (k .capitalize (), v ) for k , v in self .headers .items ()])
65+ return '{0}\r \n {1}\r \n \r \n ' .format (status_line , header_lines ).encode ('utf-8' )
66+
67+ def set_base_headers (self ):
4168 self .headers = {
4269 'Status' : str (self .status ),
4370 'Date' : time .strftime ('%a, %d %b %Y %H:%M:%S GMT' , time .gmtime ()),
4471 'Server' : 'Python/Mocket' ,
4572 'Connection' : 'close' ,
4673 'Content-Length' : str (len (self .body )),
4774 }
48- if not is_file_object :
75+ if not self . is_file_object :
4976 self .headers ['Content-Type' ] = 'text/plain; charset=utf-8'
5077 else :
5178 self .headers ['Content-Type' ] = decode_from_bytes (magic .from_buffer (self .body , mime = True ))
79+
80+ def set_extra_headers (self , headers ):
5281 for k , v in headers .items ():
5382 self .headers ['-' .join ([token .capitalize () for token in k .split ('-' )])] = v
54- self .data = self .get_protocol_data () + self .body
55-
56- def get_protocol_data (self ):
57- status_line = 'HTTP/1.1 {status_code} {status}' .format (status_code = self .status , status = STATUS [self .status ])
58- header_lines = CRLF .join (['{0}: {1}' .format (k .capitalize (), v ) for k , v in self .headers .items ()])
59- return '{0}\r \n {1}\r \n \r \n ' .format (status_line , header_lines ).encode ('utf-8' )
6083
6184
6285class Entry (MocketEntry ):
@@ -75,7 +98,7 @@ class Entry(MocketEntry):
7598 request_cls = Request
7699 response_cls = Response
77100
78- def __init__ (self , uri , method , responses ):
101+ def __init__ (self , uri , method , responses , match_querystring = True ):
79102 uri = urlsplit (uri )
80103
81104 if not uri .port :
@@ -90,6 +113,7 @@ def __init__(self, uri, method, responses):
90113 self .query = uri .query
91114 self .method = method .upper ()
92115 self ._sent_data = b''
116+ self ._match_querystring = match_querystring
93117
94118 def collect (self , data ):
95119 self ._sent_data += data
@@ -116,11 +140,13 @@ def can_handle(self, data):
116140 except AttributeError :
117141 return False
118142 uri = urlsplit (path )
119- kw = dict (keep_blank_values = True )
120- ch = uri .path == self .path and parse_qs (uri .query , ** kw ) == parse_qs (self .query , ** kw ) and method == self .method
121- if ch :
143+ can_handle = uri .path == self .path and method == self .method
144+ if self ._match_querystring :
145+ kw = dict (keep_blank_values = True )
146+ can_handle = can_handle and parse_qs (uri .query , ** kw ) == parse_qs (self .query , ** kw )
147+ if can_handle :
122148 Mocket ._last_entry = self
123- return ch
149+ return can_handle
124150
125151 @staticmethod
126152 def _parse_requestline (line ):
@@ -144,9 +170,21 @@ def _parse_requestline(line):
144170 raise ValueError ('Not a Request-Line' )
145171
146172 @classmethod
147- def register (cls , method , uri , * responses ):
148- Mocket .register (cls (uri , method , responses ))
173+ def register (cls , method , uri , * responses , ** config ):
174+
175+ default_config = dict (match_querystring = True , add_trailing_slash = True )
176+ default_config .update (config )
177+ config = default_config
178+
179+ if config ['add_trailing_slash' ] and not urlsplit (uri ).path :
180+ uri += '/'
181+
182+ Mocket .register (cls (uri , method , responses , match_querystring = config ['match_querystring' ]))
149183
150184 @classmethod
151- def single_register (cls , method , uri , body = '' , status = 200 , headers = None ):
152- cls .register (method , uri , Response (body = body , status = status , headers = headers ))
185+ def single_register (cls , method , uri , body = '' , status = 200 , headers = None , match_querystring = True ):
186+ cls .register (
187+ method , uri , cls .response_cls (
188+ body = body , status = status , headers = headers
189+ ), match_querystring = match_querystring
190+ )
0 commit comments