29
29
'''
30
30
31
31
from contextlib import contextmanager , closing
32
- from inspect import isclass
32
+ from inspect import isclass , ismethod
33
33
import threading
34
34
import os
35
35
@@ -59,8 +59,36 @@ def format_output(self, level, returnval):
59
59
class Tracer (threading .local ):
60
60
def __init__ (self , formatter = None ):
61
61
self .level = 0
62
+ self .max_depth = None
62
63
self .formatter = formatter or Formatter ()
63
64
65
+ def trace (self , f , args , kwargs , additional_depth = None ):
66
+ prev_max = self .max_depth
67
+ try :
68
+ if additional_depth is not None : # None means unlimited
69
+ total_depth = self .level + additional_depth
70
+ if self .max_depth is not None :
71
+ self .max_depth = min (self .max_depth , total_depth )
72
+ else :
73
+ self .max_depth = total_depth
74
+ if (self .max_depth is None or (self .level < self .max_depth )):
75
+ self .trace_in (f , args , kwargs )
76
+ self .level += 1
77
+
78
+ try :
79
+ r = f (* args , ** kwargs )
80
+ except Exception as e :
81
+ r = e # print the exception as the return val
82
+ raise
83
+ finally :
84
+ self .level -= 1
85
+ self .trace_out (r )
86
+ return r
87
+ else :
88
+ return f (* args , ** kwargs )
89
+ finally :
90
+ self .max_depth = prev_max
91
+
64
92
def close (self ):
65
93
pass
66
94
@@ -69,18 +97,11 @@ class StdoutTracer(Tracer):
69
97
def __init__ (self ):
70
98
super (StdoutTracer , self ).__init__ ()
71
99
72
- def trace (self , f , * args , ** kwargs ):
100
+ def trace_in (self , f , args , kwargs ):
73
101
print self .formatter .format_input (self .level , f , args , kwargs )
74
- self .level += 1
75
- try :
76
- r = f (* args , ** kwargs )
77
- except Exception as e :
78
- r = e # print the exception as the return val
79
- raise
80
- finally :
81
- self .level -= 1
82
- print self .formatter .format_output (self .level , r )
83
- return r
102
+
103
+ def trace_out (self , r ):
104
+ print self .formatter .format_output (self .level , r )
84
105
85
106
86
107
class PerThreadFileTracer (Tracer ):
@@ -91,26 +112,19 @@ def __init__(self, filename=None):
91
112
os .makedirs (d )
92
113
self .outputfile = open (filename , 'w' )
93
114
94
- def trace (self , f , * args , ** kwargs ):
115
+ def trace_in (self , f , * args , ** kwargs ):
95
116
self .outputfile .write (self .formatter .format_input (self .level , f , args , kwargs ) + "\n " )
96
- self .level += 1
97
- try :
98
- r = f (* args , ** kwargs )
99
- except Exception as e :
100
- r = e # print the exception as the return val
101
- raise
102
- finally :
103
- self .level -= 1
104
- self .outputfile .write (self .formatter .format_output (self .level , r ) + "\n " )
105
- return r
117
+
118
+ def trace_out (self , r ):
119
+ self .outputfile .write (self .formatter .format_output (self .level , r ) + "\n " )
106
120
107
121
def close (self ):
108
122
self .outputfile .close ()
109
123
110
124
111
- def add_trace (f , tracer ):
125
+ def add_trace (f , tracer , depth = None ):
112
126
def traced_fn (* args , ** kwargs ):
113
- return tracer .trace (f , * args , ** kwargs )
127
+ return tracer .trace (f , args , kwargs , additional_depth = depth )
114
128
traced_fn .trace = True # set flag so that we don't add trace more than once
115
129
return traced_fn
116
130
@@ -125,22 +139,37 @@ def traceable(f):
125
139
and not getattr (f , 'trace' , None ) # already being traced
126
140
127
141
142
+ def _get_func (m ):
143
+ '''Returns function given a function or method'''
144
+ if ismethod (m ):
145
+ return m .im_func
146
+ else :
147
+ return m
148
+
149
+
128
150
@contextmanager
129
- def trace_on (objs , include_hidden = False , tracer = None , skip = None ):
151
+ def trace_on (objs , include_hidden = False , tracer = None , depths = None ):
130
152
tracer = tracer or StdoutTracer ()
131
153
origs = {}
132
- skip = skip or []
154
+ depths = depths or {}
155
+
156
+ # converts methods to functions, since that is what's in __dict__
157
+ f_depths = {}
158
+ for (k , v ) in depths .items ():
159
+ f_depths [_get_func (k )] = v
160
+ depths = f_depths
161
+
133
162
for o in objs :
134
163
replacements = {}
135
164
for k in o .__dict__ .keys ():
136
165
v = o .__dict__ [k ]
137
166
if traceable (v ) and getattr (v , '__name__' , None ) is not '__repr__' \
138
- and v not in skip \
167
+ and ( v not in depths or depths [ v ] >= 0 ) \
139
168
and (include_hidden or
140
169
not (include_hidden or k .startswith ("_" ))):
141
170
replacements [k ] = v
142
- # print "Replacing: " + k
143
- setattr (o , k , add_trace (v , tracer ))
171
+ # print "Replacing: %s %s , depth %s" % (k, v, depths.get(v, None))
172
+ setattr (o , k , add_trace (v , tracer , depth = depths . get ( v , None ) ))
144
173
origs [o ] = replacements
145
174
# print origs
146
175
with closing (tracer ):
0 commit comments