Skip to content

Commit 7f929b0

Browse files
committed
Fix key binding registration for bound methods, add unit tests
1 parent 1c6d74d commit 7f929b0

File tree

2 files changed

+83
-9
lines changed

2 files changed

+83
-9
lines changed

mpv-test.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,79 @@ def foo(*args, **kwargs):
243243
mock.call('loop', 'inf')],
244244
any_order=True)
245245

246+
class KeyBindingTest(MpvTestCase):
247+
def test_register_direct_cmd(self):
248+
self.m.register_key_binding('a', 'playlist-clear')
249+
self.assertEqual(self.m._key_binding_handlers, {})
250+
self.m.register_key_binding('Ctrl+Shift+a', 'playlist-clear')
251+
self.m.unregister_key_binding('a')
252+
self.m.unregister_key_binding('Ctrl+Shift+a')
253+
254+
def test_register_direct_fun(self):
255+
b = mpv.MPV._binding_name
256+
257+
def reg_test_fun(state, name):
258+
pass
259+
260+
self.m.register_key_binding('a', reg_test_fun)
261+
self.assertIn(b('a'), self.m._key_binding_handlers)
262+
self.assertEqual(self.m._key_binding_handlers[b('a')], reg_test_fun)
263+
264+
self.m.unregister_key_binding('a')
265+
self.assertNotIn(b('a'), self.m._key_binding_handlers)
266+
267+
def test_register_direct_bound_method(self):
268+
b = mpv.MPV._binding_name
269+
270+
class RegTestCls:
271+
def method(self, state, name):
272+
pass
273+
instance = RegTestCls()
274+
275+
self.m.register_key_binding('a', instance.method)
276+
self.assertIn(b('a'), self.m._key_binding_handlers)
277+
self.assertEqual(self.m._key_binding_handlers[b('a')], instance.method)
278+
279+
self.m.unregister_key_binding('a')
280+
self.assertNotIn(b('a'), self.m._key_binding_handlers)
281+
282+
def test_register_decorator_fun(self):
283+
b = mpv.MPV._binding_name
284+
285+
@self.m.key_binding('a')
286+
def reg_test_fun(state, name):
287+
pass
288+
self.assertEqual(reg_test_fun.mpv_key_bindings, ['a'])
289+
self.assertIn(b('a'), self.m._key_binding_handlers)
290+
self.assertEqual(self.m._key_binding_handlers[b('a')], reg_test_fun)
291+
292+
reg_test_fun.unregister_mpv_key_bindings()
293+
self.assertNotIn(b('a'), self.m._key_binding_handlers)
294+
295+
def test_register_decorator_fun_chaining(self):
296+
b = mpv.MPV._binding_name
297+
298+
@self.m.key_binding('a')
299+
@self.m.key_binding('b')
300+
def reg_test_fun(state, name):
301+
pass
302+
303+
@self.m.key_binding('c')
304+
def reg_test_fun_2_stay_intact(state, name):
305+
pass
306+
307+
self.assertEqual(reg_test_fun.mpv_key_bindings, ['b', 'a'])
308+
self.assertIn(b('a'), self.m._key_binding_handlers)
309+
self.assertIn(b('b'), self.m._key_binding_handlers)
310+
self.assertIn(b('c'), self.m._key_binding_handlers)
311+
self.assertEqual(self.m._key_binding_handlers[b('a')], reg_test_fun)
312+
self.assertEqual(self.m._key_binding_handlers[b('b')], reg_test_fun)
313+
314+
reg_test_fun.unregister_mpv_key_bindings()
315+
self.assertNotIn(b('a'), self.m._key_binding_handlers)
316+
self.assertNotIn(b('b'), self.m._key_binding_handlers)
317+
self.assertIn(b('c'), self.m._key_binding_handlers)
318+
246319
class TestLifecycle(unittest.TestCase):
247320
def test_create_destroy(self):
248321
thread_names = lambda: [ t.name for t in threading.enumerate() ]

mpv.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -837,8 +837,8 @@ def unregister_message_handler(self, target_or_handler):
837837
838838
You can also call the ```unregister_mpv_messages``` function attribute set on the handler function when it is
839839
registered. """
840-
if isinstance(target, str):
841-
del self._message_handlers[target]
840+
if isinstance(target_or_handler, str):
841+
del self._message_handlers[target_or_handler]
842842
else:
843843
for key, val in self._message_handlers.items():
844844
if val == target_or_handler:
@@ -945,10 +945,16 @@ def binding(state, name):
945945
this is completely fine--but, if you are about to pass untrusted input into this parameter, better double-check
946946
whether this is secure in your case. """
947947

948-
def wrapper(fun):
948+
def register(fun):
949+
fun.mpv_key_bindings = getattr(fun, 'mpv_key_bindings', []) + [keydef]
950+
def unregister_all():
951+
for keydef in fun.mpv_key_bindings:
952+
self.unregister_key_binding(keydef)
953+
fun.unregister_mpv_key_bindings = unregister_all
954+
949955
self.register_key_binding(keydef, fun, mode)
950956
return fun
951-
return wrapper
957+
return register
952958

953959
def register_key_binding(self, keydef, callback_or_cmd, mode='force'):
954960
""" Register a key binding. This takes an mpv keydef and either a string containing a mpv
@@ -959,11 +965,6 @@ def register_key_binding(self, keydef, callback_or_cmd, mode='force'):
959965
'symbolic name (as printed by --input-keylist')
960966
binding_name = MPV._binding_name(keydef)
961967
if callable(callback_or_cmd):
962-
callback_or_cmd.mpv_key_bindings = getattr(callback_or_cmd, 'mpv_key_bindings', []) + [keydef]
963-
def unregister_all():
964-
for keydef in callback_or_cmd.mpv_key_bindings:
965-
self.unregister_key_binding(keydef)
966-
callback_or_cmd.unregister_mpv_key_bindings = unregister_all
967968
self._key_binding_handlers[binding_name] = callback_or_cmd
968969
self.register_message_handler('key-binding', self._handle_key_binding_message)
969970
self.command('define-section',

0 commit comments

Comments
 (0)