diff --git a/libmproxy/cmdline.py b/libmproxy/cmdline.py index 5114f371a2..388607afa9 100644 --- a/libmproxy/cmdline.py +++ b/libmproxy/cmdline.py @@ -209,7 +209,7 @@ def common_options(parser): ) parser.add_argument( "-s", - action="append", type=lambda x: shlex.split(x,posix=(os.name != "nt")), dest="scripts", default=[], + action="append", type=str, dest="scripts", default=[], metavar='"script.py --bar"', help="Run a script. Surround with quotes to pass script arguments. Can be passed multiple times." ) diff --git a/libmproxy/console/__init__.py b/libmproxy/console/__init__.py index f68084fff3..536b0bac45 100644 --- a/libmproxy/console/__init__.py +++ b/libmproxy/console/__init__.py @@ -442,13 +442,13 @@ def _run_script_method(self, method, s, f): else: self.add_event("Method %s error: %s"%(method, val[1])) - def run_script_once(self, path, f): - if not path: + def run_script_once(self, command, f): + if not command: return - self.add_event("Running script on flow: %s"%path) + self.add_event("Running script on flow: %s"%command) try: - s = script.Script(shlex.split(path, posix=(os.name != "nt")), self) + s = script.Script(command, self) except script.ScriptError, v: self.statusbar.message("Error loading script.") self.add_event("Error loading script:\n%s"%v.args[0]) @@ -462,15 +462,15 @@ def run_script_once(self, path, f): self._run_script_method("error", s, f) s.unload() self.refresh_flow(f) - self.state.last_script = path + self.state.last_script = command - def set_script(self, path): - if not path: + def set_script(self, command): + if not command: return - ret = self.load_script(path) + ret = self.load_script(command) if ret: self.statusbar.message(ret) - self.state.last_script = path + self.state.last_script = command def toggle_eventlog(self): self.eventlog = not self.eventlog diff --git a/libmproxy/dump.py b/libmproxy/dump.py index 83139b46b9..e76ea1ce33 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -111,8 +111,8 @@ def __init__(self, server, options, filtstr, outfile=sys.stdout): ) scripts = options.scripts or [] - for script_argv in scripts: - err = self.load_script(script_argv) + for command in scripts: + err = self.load_script(command) if err: raise DumpError(err) diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 0aa0319821..e0013f1ecf 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -1398,13 +1398,13 @@ def unload_script(self, script): script.unload() self.scripts.remove(script) - def load_script(self, script_argv): + def load_script(self, command): """ Loads a script. Returns an error description if something went wrong. """ try: - s = script.Script(script_argv, self) + s = script.Script(command, self) except script.ScriptError, v: return v.args[0] self.scripts.append(s) diff --git a/libmproxy/script.py b/libmproxy/script.py index 747d1567cc..a80f4694d3 100644 --- a/libmproxy/script.py +++ b/libmproxy/script.py @@ -1,4 +1,4 @@ -import os, traceback, threading +import os, traceback, threading, shlex import controller class ScriptError(Exception): @@ -45,8 +45,9 @@ class Script: s = Script(argv, master) s.load() """ - def __init__(self, argv, master): - self.argv = argv + def __init__(self, command, master): + self.command = command + self.argv = shlex.split(command, posix=(os.name != "nt")) self.ctx = ScriptContext(master) self.ns = None self.load() @@ -99,7 +100,6 @@ def run(self, name, *args, **kwargs): def _handle_concurrent_reply(fn, o, args=[], kwargs={}): reply = o.reply o.reply = controller.DummyReply() - def run(): fn(*args, **kwargs) reply(o) diff --git a/test/test_cmdline.py b/test/test_cmdline.py index 92e2adbd14..dbc61bfc21 100644 --- a/test/test_cmdline.py +++ b/test/test_cmdline.py @@ -40,19 +40,6 @@ def test_parse_setheaders(): x = cmdline.parse_setheader("/foo/bar/voing") assert x == ("foo", "bar", "voing") -def test_shlex(): - """ - shlex.split assumes posix=True by default, we do manual detection for windows. - Test whether script paths are parsed correctly - """ - absfilepath = os.path.normcase(os.path.abspath(__file__)) - - parser = argparse.ArgumentParser() - cmdline.common_options(parser) - opts = parser.parse_args(args=["-s",absfilepath]) - - assert os.path.isfile(opts.scripts[0][0]) - def test_common(): parser = argparse.ArgumentParser() cmdline.common_options(parser) diff --git a/test/test_dump.py b/test/test_dump.py index 9874b650d7..a958a2ec5b 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -153,7 +153,7 @@ def test_write_err(self): def test_script(self): ret = self._dummy_cycle( 1, None, "", - scripts=[[tutils.test_data.path("scripts/all.py")]], verbosity=0, eventlog=True + scripts=[tutils.test_data.path("scripts/all.py")], verbosity=0, eventlog=True ) assert "XCLIENTCONNECT" in ret assert "XSERVERCONNECT" in ret @@ -162,11 +162,11 @@ def test_script(self): assert "XCLIENTDISCONNECT" in ret tutils.raises( dump.DumpError, - self._dummy_cycle, 1, None, "", scripts=[["nonexistent"]] + self._dummy_cycle, 1, None, "", scripts=["nonexistent"] ) tutils.raises( dump.DumpError, - self._dummy_cycle, 1, None, "", scripts=[["starterr.py"]] + self._dummy_cycle, 1, None, "", scripts=["starterr.py"] ) def test_stickycookie(self): diff --git a/test/test_flow.py b/test/test_flow.py index bf6a7a4220..680d59e55b 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -563,12 +563,12 @@ class TestFlowMaster: def test_load_script(self): s = flow.State() fm = flow.FlowMaster(None, s) - assert not fm.load_script([tutils.test_data.path("scripts/a.py")]) - assert not fm.load_script([tutils.test_data.path("scripts/a.py")]) + assert not fm.load_script(tutils.test_data.path("scripts/a.py")) + assert not fm.load_script(tutils.test_data.path("scripts/a.py")) assert not fm.unload_script(fm.scripts[0]) assert not fm.unload_script(fm.scripts[0]) - assert fm.load_script(["nonexistent"]) - assert "ValueError" in fm.load_script([tutils.test_data.path("scripts/starterr.py")]) + assert fm.load_script("nonexistent") + assert "ValueError" in fm.load_script(tutils.test_data.path("scripts/starterr.py")) assert len(fm.scripts) == 0 def test_replay(self): @@ -584,7 +584,7 @@ def test_replay(self): def test_script_reqerr(self): s = flow.State() fm = flow.FlowMaster(None, s) - assert not fm.load_script([tutils.test_data.path("scripts/reqerr.py")]) + assert not fm.load_script(tutils.test_data.path("scripts/reqerr.py")) req = tutils.treq() fm.handle_clientconnect(req.client_conn) assert fm.handle_request(req) @@ -592,7 +592,7 @@ def test_script_reqerr(self): def test_script(self): s = flow.State() fm = flow.FlowMaster(None, s) - assert not fm.load_script([tutils.test_data.path("scripts/all.py")]) + assert not fm.load_script(tutils.test_data.path("scripts/all.py")) req = tutils.treq() fm.handle_clientconnect(req.client_conn) assert fm.scripts[0].ns["log"][-1] == "clientconnect" @@ -606,7 +606,7 @@ def test_script(self): fm.handle_response(resp) assert fm.scripts[0].ns["log"][-1] == "response" #load second script - assert not fm.load_script([tutils.test_data.path("scripts/all.py")]) + assert not fm.load_script(tutils.test_data.path("scripts/all.py")) assert len(fm.scripts) == 2 dc = flow.ClientDisconnect(req.client_conn) dc.reply = controller.DummyReply() @@ -659,7 +659,7 @@ def test_all(self): err.reply = controller.DummyReply() fm.handle_error(err) - fm.load_script([tutils.test_data.path("scripts/a.py")]) + fm.load_script(tutils.test_data.path("scripts/a.py")) fm.shutdown() def test_client_playback(self): diff --git a/test/test_script.py b/test/test_script.py index 2664b84072..39aa12e9c0 100644 --- a/test/test_script.py +++ b/test/test_script.py @@ -10,10 +10,8 @@ class TestScript: def test_simple(self): s = flow.State() fm = flow.FlowMaster(None, s) - p = script.Script( - shlex.split(tutils.test_data.path("scripts/a.py")+" --var 40",posix=(os.name != "nt")), fm - ) - p.load() + sp = tutils.test_data.path("scripts/a.py") + p = script.Script("%s --var 40"%sp, fm) assert "here" in p.ns assert p.run("here") == (True, 41) @@ -30,7 +28,7 @@ def test_simple(self): def test_duplicate_flow(self): s = flow.State() fm = flow.FlowMaster(None, s) - fm.load_script([tutils.test_data.path("scripts/duplicate_flow.py")]) + fm.load_script(tutils.test_data.path("scripts/duplicate_flow.py")) r = tutils.treq() fm.handle_request(r) assert fm.state.flow_count() == 2 @@ -43,28 +41,28 @@ def test_err(self): tutils.raises( "no such file", - script.Script, ["nonexistent"], fm + script.Script, "nonexistent", fm ) tutils.raises( "not a file", - script.Script, [tutils.test_data.path("scripts")], fm + script.Script, tutils.test_data.path("scripts"), fm ) tutils.raises( script.ScriptError, - script.Script, [tutils.test_data.path("scripts/syntaxerr.py")], fm + script.Script, tutils.test_data.path("scripts/syntaxerr.py"), fm ) tutils.raises( script.ScriptError, - script.Script, [tutils.test_data.path("scripts/loaderr.py")], fm + script.Script, tutils.test_data.path("scripts/loaderr.py"), fm ) def test_concurrent(self): s = flow.State() fm = flow.FlowMaster(None, s) - fm.load_script([tutils.test_data.path("scripts/concurrent_decorator.py")]) + fm.load_script(tutils.test_data.path("scripts/concurrent_decorator.py")) with mock.patch("libmproxy.controller.DummyReply.__call__") as m: r1, r2 = tutils.treq(), tutils.treq() @@ -84,7 +82,7 @@ def test_concurrent(self): def test_concurrent2(self): s = flow.State() fm = flow.FlowMaster(None, s) - s = script.Script([tutils.test_data.path("scripts/concurrent_decorator.py")], fm) + s = script.Script(tutils.test_data.path("scripts/concurrent_decorator.py"), fm) s.load() f = tutils.tflow_full() f.error = tutils.terr(f.request) @@ -104,5 +102,15 @@ def test_concurrent_err(self): fm = flow.FlowMaster(None, s) tutils.raises( "decorator not supported for this method", - script.Script, [tutils.test_data.path("scripts/concurrent_decorator_err.py")], fm + script.Script, tutils.test_data.path("scripts/concurrent_decorator_err.py"), fm ) + + +def test_command_parsing(): + s = flow.State() + fm = flow.FlowMaster(None, s) + absfilepath = os.path.normcase(tutils.test_data.path("scripts/a.py")) + s = script.Script(absfilepath, fm) + assert os.path.isfile(s.argv[0]) + +