forked from AUTOMATIC1111/stable-diffusion-webui
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path patches.py
64 lines (41 loc) · 1.79 KB
/
patches.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from collections import defaultdict
def patch(key, obj, field, replacement):
    """Replaces a function in a module or a class.

    The original function is stored in this module and can later be retrieved
    via original(key, obj, field). If the function has already been replaced by
    this caller (key), a RuntimeError is raised -- call undo() before patching again.

    Arguments:
        key: identifying information for who is doing the replacement. You can use __name__.
        obj: the module or the class
        field: name of the function as a string
        replacement: the new function

    Returns:
        the original function

    Raises:
        RuntimeError: if this key has already patched `field` on `obj`.
        AttributeError: if `obj` has no attribute `field` (raised by getattr).
        Any exception raised by setattr (e.g. TypeError for immutable built-in
        types); in that case nothing is recorded.
    """
    patch_key = (obj, field)
    if patch_key in originals[key]:
        raise RuntimeError(f"patch for {field} is already applied")

    # NOTE(review): getattr() returns the resolved value, so for a classmethod or
    # staticmethod on a class this saves the bound method / plain function rather
    # than the descriptor itself -- confirm callers never patch such attributes.
    original_func = getattr(obj, field)

    # Install the replacement before recording it: if setattr fails, no stale record
    # is left behind that would make later patch()/undo() calls believe the patch
    # is active.
    setattr(obj, field, replacement)
    originals[key][patch_key] = original_func

    return original_func
def undo(key, obj, field):
    """Undoes the replacement made by patch().

    If the function is not replaced, raises an exception.

    Arguments:
        key: identifying information for who is doing the replacement. You can use __name__.
        obj: the module or the class
        field: name of the function as a string

    Returns:
        Always None

    Raises:
        RuntimeError: if `key` has no recorded patch for `field` on `obj`.
        Any exception raised by setattr; in that case the record is kept so the
        patch is still reported as active and undo() can be retried.
    """
    patch_key = (obj, field)
    if patch_key not in originals[key]:
        raise RuntimeError(f"there is no patch for {field} to undo")

    # Restore first, forget second, so a failing setattr does not lose the record.
    setattr(obj, field, originals[key][patch_key])
    del originals[key][patch_key]

    return None
def original(key, obj, field):
    """Returns the original function for the patch created by the patch() function.

    Arguments:
        key: identifying information for who did the replacement.
        obj: the module or the class
        field: name of the function as a string

    Returns:
        the saved original function, or None if `key` has no patch for `field` on `obj`.
    """
    # Use .get() on the outer mapping too: indexing a defaultdict would insert an
    # empty entry for an unknown key, mutating the registry on a read-only query.
    return originals.get(key, {}).get((obj, field))
# Registry of saved originals, written by patch() and read by undo()/original():
#   originals[key][(obj, field)] -> the value getattr(obj, field) returned before patching.
# The inner dict for a key is created automatically on first indexed access.
originals: defaultdict = defaultdict(dict)