Skip to content

Commit b5aef48

Browse files
committed
Use option instead of long cast statements. Also try to automate and delegate resolution to flag initializer. This makes flag class independent of cli arg names
1 parent 263c067 commit b5aef48

File tree

1 file changed

+39
-91
lines changed

1 file changed

+39
-91
lines changed

proxy/common/flag.py

Lines changed: 39 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from typing import Optional, List, Any, cast
2121

2222
from .plugins import Plugins
23-
from .types import IpAddress
2423
from .utils import bytes_, is_py2, is_threadless, set_open_file_limit
2524
from .constants import COMMA, DEFAULT_DATA_DIRECTORY_PATH, DEFAULT_NUM_ACCEPTORS, DEFAULT_NUM_WORKERS
2625
from .constants import DEFAULT_DEVTOOLS_WS_PATH, DEFAULT_DISABLE_HEADERS, PY2_DEPRECATION_MESSAGE
@@ -109,12 +108,27 @@ def initialize(
109108
print(__version__)
110109
sys.exit(0)
111110

111+
# https://github.com/python/mypy/issues/5865
112+
def option(t: object, key: str, default: Optional[Any] = None) -> Any:
113+
return cast(
114+
t, # type: ignore
115+
opts.get(
116+
key,
117+
default or getattr(args, key),
118+
),
119+
)
120+
121+
# Command line arguments MUST always take preference
122+
# over kwargs passed to the program constructor.
123+
# for f in args.__dict__.keys():
124+
# print(f)
125+
# print(option(Any, f))
126+
112127
# proxy.py currently cannot serve over HTTPS and also perform TLS interception
113128
# at the same time. Check if user is trying to enable both feature
114129
# at the same time.
115130
#
116-
# TODO: Use parser.add_mutually_exclusive_group()
117-
# and remove this logic from here.
131+
# TODO: Use parser.add_mutually_exclusive_group() and remove this logic from here.
118132
if (args.cert_file and args.key_file) and \
119133
(args.ca_key_file and args.ca_cert_file and args.ca_signing_key_file):
120134
print(
@@ -157,27 +171,9 @@ def initialize(
157171

158172
# --enable flags must be parsed before loading plugins
159173
# otherwise we will miss the plugins passed via constructor
160-
args.enable_web_server = cast(
161-
bool,
162-
opts.get(
163-
'enable_web_server',
164-
args.enable_web_server,
165-
),
166-
)
167-
args.enable_static_server = cast(
168-
bool,
169-
opts.get(
170-
'enable_static_server',
171-
args.enable_static_server,
172-
),
173-
)
174-
args.enable_events = cast(
175-
bool,
176-
opts.get(
177-
'enable_events',
178-
args.enable_events,
179-
),
180-
)
174+
args.enable_web_server = option(bool, 'enable_web_server')
175+
args.enable_static_server = option(bool, 'enable_static_server')
176+
args.enable_events = option(bool, 'enable_events')
181177

182178
# Load default plugins along with user provided --plugins
183179
default_plugins = [
@@ -191,10 +187,6 @@ def initialize(
191187
default_plugins + auth_plugins + requested_plugins,
192188
)
193189

194-
# https://github.com/python/mypy/issues/5865
195-
#
196-
# def option(t: object, key: str, default: Any) -> Any:
197-
# return cast(t, opts.get(key, default))
198190
args.work_klass = work_klass
199191
args.plugins = plugins
200192
args.auth_code = cast(
@@ -204,20 +196,8 @@ def initialize(
204196
auth_code,
205197
),
206198
)
207-
args.server_recvbuf_size = cast(
208-
int,
209-
opts.get(
210-
'server_recvbuf_size',
211-
args.server_recvbuf_size,
212-
),
213-
)
214-
args.client_recvbuf_size = cast(
215-
int,
216-
opts.get(
217-
'client_recvbuf_size',
218-
args.client_recvbuf_size,
219-
),
220-
)
199+
args.server_recvbuf_size = option(int, 'server_recvbuf_size')
200+
args.client_recvbuf_size = option(int, 'client_recvbuf_size')
221201
args.pac_file = cast(
222202
Optional[str], opts.get(
223203
'pac_file', bytes_(
@@ -241,44 +221,18 @@ def initialize(
241221
],
242222
),
243223
)
244-
args.disable_headers = disabled_headers if disabled_headers is not None else DEFAULT_DISABLE_HEADERS
245-
args.certfile = cast(
246-
Optional[str], opts.get(
247-
'cert_file', args.cert_file,
248-
),
249-
)
250-
args.keyfile = cast(Optional[str], opts.get('key_file', args.key_file))
251-
args.ca_key_file = cast(
252-
Optional[str], opts.get(
253-
'ca_key_file', args.ca_key_file,
254-
),
255-
)
256-
args.ca_cert_file = cast(
257-
Optional[str], opts.get(
258-
'ca_cert_file', args.ca_cert_file,
259-
),
260-
)
261-
args.ca_signing_key_file = cast(
262-
Optional[str],
263-
opts.get(
264-
'ca_signing_key_file',
265-
args.ca_signing_key_file,
266-
),
267-
)
268-
args.ca_file = cast(
269-
Optional[str],
270-
opts.get(
271-
'ca_file',
272-
args.ca_file,
273-
),
274-
)
275-
args.hostname = cast(
276-
IpAddress,
277-
opts.get('hostname', ipaddress.ip_address(args.hostname)),
278-
)
279-
args.unix_socket_path = opts.get(
280-
'unix_socket_path', args.unix_socket_path,
281-
)
224+
args.disable_headers = disabled_headers \
225+
if disabled_headers is not None \
226+
else DEFAULT_DISABLE_HEADERS
227+
args.certfile = option(Optional[str], 'cert_file')
228+
args.keyfile = option(Optional[str], 'key_file')
229+
args.ca_key_file = option(Optional[str], 'ca_key_file')
230+
args.ca_cert_file = option(Optional[str], 'ca_cert_file')
231+
args.ca_signing_key_file = option(Optional[str], 'ca_signing_key_file')
232+
args.ca_file = option(Optional[str], 'ca_file')
233+
args.hostname = option(str, 'hostname')
234+
args.hostname = ipaddress.ip_address(args.hostname)
235+
args.unix_socket_path = option(str, 'unix_socket_path')
282236
# AF_UNIX is not available on Windows
283237
# See https://bugs.python.org/issue33408
284238
if not IS_WINDOWS:
@@ -294,13 +248,13 @@ def initialize(
294248
#
295249
# assert args.unix_socket_path is None
296250
args.family = socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET
297-
args.port = cast(int, opts.get('port', args.port))
298-
args.backlog = cast(int, opts.get('backlog', args.backlog))
299-
num_workers = opts.get('num_workers', args.num_workers)
251+
args.port = option(int, 'port')
252+
args.backlog = option(int, 'backlog')
253+
num_workers = option(int, 'num_workers')
300254
args.num_workers = cast(
301255
int, num_workers if num_workers > 0 else multiprocessing.cpu_count(),
302256
)
303-
num_acceptors = opts.get('num_acceptors', args.num_acceptors)
257+
num_acceptors = option(int, 'num_acceptors')
304258
# See https://github.com/abhinavsingh/proxy.py/pull/714 description
305259
# to understand rationale behind the following logic.
306260
#
@@ -314,13 +268,7 @@ def initialize(
314268
int, num_acceptors if num_acceptors > 0 else multiprocessing.cpu_count(),
315269
)
316270

317-
args.static_server_dir = cast(
318-
str,
319-
opts.get(
320-
'static_server_dir',
321-
args.static_server_dir,
322-
),
323-
)
271+
args.static_server_dir = option(str, 'static_server_dir')
324272
args.min_compression_limit = cast(
325273
bool,
326274
opts.get(

0 commit comments

Comments
 (0)