Skip to content

Commit b50a86b

Browse files
authored
Merge pull request #1 from python-trio/initial-push
Commit of old main code. Untested and unmodified from my PR in Feb.
2 parents 785be19 + 164a616 commit b50a86b

File tree

1 file changed

+381
-0
lines changed

1 file changed

+381
-0
lines changed

trio_monitor/monitor.py

Lines changed: 381 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,381 @@
1+
import inspect
2+
import os
3+
import random
4+
import signal
5+
import string
6+
import sys
7+
import traceback
8+
from types import FunctionType
9+
10+
from async_generator._impl import ANextIter
11+
12+
from trio import Queue, WouldBlock, BrokenStreamError
13+
from trio._highlevel_serve_listeners import _run_handler
14+
from ._version import __version__
15+
from trio.abc import Instrument
16+
from trio.hazmat import current_task, Task
17+
18+
19+
# inspiration: https://github.com/python-trio/trio/blob/master/notes-to-self/print-task-tree.py
20+
# additional credit: python-trio/trio#429 (You are being ratelimited)
21+
22+
23+
def walk_coro_stack(coro):
24+
while coro is not None:
25+
if hasattr(coro, "cr_frame"):
26+
# A real coroutine
27+
yield coro.cr_frame, coro.cr_frame.f_lineno
28+
coro = coro.cr_await
29+
elif isinstance(coro, ANextIter):
30+
# black hole
31+
return
32+
else:
33+
# A generator decorated with @types.coroutine
34+
yield coro.gi_frame, coro.gi_frame.f_lineno
35+
coro = coro.gi_yieldfrom
36+
37+
38+
class Monitor(Instrument):
39+
"""Represents a monitor; a simple way of monitoring the health of your
40+
Trio application using the Trio instrumentation API.
41+
42+
The monitor protocol is a simple text-based protocol, accessible from a
43+
telnet client, for example.
44+
"""
45+
MID_PREFIX = "|─ "
46+
MID_CONTINUE = "| "
47+
END_PREFIX = "\─ "
48+
END_CONTINUE = " " * len(END_PREFIX)
49+
50+
def __init__(self):
51+
self.authenticated = False
52+
53+
# authentication for running code in a frame
54+
rand = random.SystemRandom()
55+
self.auth_pin = ''.join(
56+
rand.choice(string.ascii_letters) for x in range(0, 8)
57+
)
58+
59+
self._is_monitoring = False
60+
# semi-arbitrary size, because otherwise we'll be dropping events
61+
# no clue how to make this better, alas.
62+
self._monitoring_queue = Queue(capacity=100)
63+
64+
@staticmethod
65+
def get_root_task() -> Task:
66+
"""Gets the current root task."""
67+
task = current_task()
68+
while task.parent_nursery is not None:
69+
task = task.parent_nursery.parent_task
70+
return task
71+
72+
@staticmethod
73+
def flatten_tasks():
74+
"""Gets a list of all tasks."""
75+
root = Monitor.get_root_task()
76+
tasks = [root]
77+
for child in root.child_nurseries:
78+
tasks.extend(Monitor.recursively_get_tasks(child))
79+
80+
return tasks
81+
82+
@staticmethod
83+
def recursively_get_tasks(nursery):
84+
"""Recursively gets all tasks from a nursery."""
85+
tasks = []
86+
for task in nursery.child_tasks:
87+
tasks.append(task)
88+
for nursery in task.child_nurseries:
89+
tasks.extend(Monitor.recursively_get_tasks(nursery))
90+
91+
return tasks
92+
93+
@staticmethod
94+
def get_task_by_id(taskid: int):
95+
"""Gets a task by ID."""
96+
for task in Monitor.flatten_tasks():
97+
if id(task) == taskid:
98+
return task
99+
100+
# specific overrides
101+
def before_io_wait(self, timeout):
102+
if timeout == 0:
103+
return
104+
105+
self._add_to_monitoring_queue(("before_io_wait", timeout))
106+
107+
def after_io_wait(self, timeout):
108+
if timeout == 0:
109+
return
110+
111+
self._add_to_monitoring_queue(("after_io_wait", timeout))
112+
113+
def _add_to_monitoring_queue(self, item):
114+
if not self._is_monitoring:
115+
return
116+
117+
if 'task' in item[0]:
118+
task = item[1]
119+
# idk how to make this better.
120+
if task.coro.cr_code == _run_handler.__code__:
121+
if task.coro.cr_frame is not None:
122+
loc = task.coro.cr_frame.f_locals
123+
124+
# if it's our own handler, skip it!
125+
if loc['handler'] == self.listen_on_stream:
126+
return
127+
128+
try:
129+
self._monitoring_queue.put_nowait(item)
130+
except WouldBlock:
131+
return
132+
133+
async def listen_on_stream(self, stream):
134+
"""Makes the monitor server listen on a stream.
135+
Use this as a callback from a listener.
136+
"""
137+
return await self.main_loop(stream)
138+
139+
async def main_loop(self, stream):
140+
"""Runs the main loop of the monitor.
141+
"""
142+
# send the banner
143+
version = __version__
144+
await stream.send_all(
145+
b"Connected to the Trio monitor, using "
146+
b"trio " + version.encode(encoding="ascii") + b"\n"
147+
)
148+
149+
while True:
150+
await stream.send_all(b"trio> ")
151+
command = await stream.receive_some(2048)
152+
if command == b"":
153+
return
154+
155+
command = command.decode("ascii").rstrip("\n").rstrip("\r")
156+
name, *args = command.split(" ")
157+
158+
# special handling for closing
159+
if name in ["exit", "ex", "quit", "q", ":q"]:
160+
return await stream.aclose()
161+
162+
# special handling for monitor
163+
if name in ["monitor", "mon", "m", "feed"]:
164+
try:
165+
self._is_monitoring = True
166+
return await self.do_monitor(stream)
167+
finally:
168+
self._is_monitoring = False
169+
# empty out the queue
170+
self._monitoring_queue = Queue(capacity=100)
171+
172+
try:
173+
fn = getattr(self, "command_{}".format(name))
174+
except AttributeError:
175+
await stream.send_all(
176+
b"No such command: " + name.encode() + b"\n"
177+
)
178+
continue
179+
180+
try:
181+
lines = await fn(*args)
182+
except Exception as e:
183+
messages = ["takes at most", "required positional argument"]
184+
185+
if isinstance(e, TypeError) and \
186+
any(x in ' '.join(e.args) for x in messages):
187+
# hacky, but idk what else to do
188+
await stream.send_all(
189+
' '.join(e.args).encode("ascii") + b"\n"
190+
)
191+
continue
192+
193+
errormessage = type(e).__name__ + ": " + ' '.join(e.args)
194+
await stream.send_all(b"Error: " + errormessage.encode())
195+
raise
196+
197+
await stream.send_all(
198+
"\n".join(lines).encode(encoding="ascii") + b"\n"
199+
)
200+
201+
# monitor feed
202+
async def do_monitor(self, stream):
203+
"""Livefeeds information about the running program."""
204+
prefix = "[FEED] "
205+
async for item in self._monitoring_queue:
206+
key = item[0]
207+
208+
if key == "task_spawned":
209+
task = item[1]
210+
message = "Task spawned: {} ({})".format(task.name, id(task))
211+
212+
elif key == "task_scheduled":
213+
task = item[1]
214+
message = "Task scheduled: {} ({})".format(task.name, id(task))
215+
elif key == "task_exited":
216+
task = item[1]
217+
message = "Task exited: {} ({})".format(task.name, id(task))
218+
219+
elif key == "before_io_wait":
220+
timeout = item[1]
221+
message = "Waiting for IO (timeout: {:.3f})".format(timeout)
222+
223+
elif key == "after_io_wait":
224+
timeout = item[1]
225+
message = "Done waiting for IO (timeout: {:.3f})" \
226+
.format(timeout)
227+
228+
elif key == "before_task_step":
229+
task = item[1]
230+
message = "Task stepping: {} ({})".format(task.name, id(task))
231+
232+
elif key == "after_task_step":
233+
task = item[1]
234+
message = "Task finished stepping: {} ({})".format(
235+
task.name, id(task)
236+
)
237+
238+
else:
239+
message = "Unknown event: {}".format(key)
240+
241+
message = prefix + message
242+
243+
try:
244+
await stream.send_all(message.encode("ascii") + b'\n')
245+
except BrokenStreamError: # client disconnected on us
246+
return
247+
248+
# command definitions
249+
async def command_help(self):
250+
"""Sends help."""
251+
name_rpad = 12
252+
253+
def pred(i):
254+
return hasattr(i, "__name__") \
255+
and i.__name__.startswith("command_")
256+
257+
commands = inspect.getmembers(self, predicate=pred)
258+
lines = ["Commands:"]
259+
for name, command in commands:
260+
doc = inspect.getdoc(command).splitlines(keepends=False)[0]
261+
name = name.split("_", 1)[1]
262+
lines.append(name.ljust(name_rpad) + doc)
263+
264+
return lines
265+
266+
async def command_signal(self, signame: str):
267+
"""Sends a signal to the server process."""
268+
signame = signame.upper()
269+
if not signame.startswith("SIG"):
270+
signame = "SIG{}".format(signame)
271+
272+
try:
273+
tosend = getattr(signal, signame)
274+
except AttributeError:
275+
return ["Invalid signal: {}".format(signame)]
276+
277+
os.kill(os.getpid(), tosend)
278+
return ["Signal sent successfully"]
279+
280+
async def command_ps(self):
281+
"""Gets the current list of tasks."""
282+
lines = []
283+
headers = ('ID', 'Name', 'Parent')
284+
widths = (15, 50, 15)
285+
header_line = []
286+
287+
for name, width in zip(headers, widths):
288+
header_line.append(name.ljust(width))
289+
290+
lines.append(' '.join(header_line))
291+
lines.append("-" * sum(widths))
292+
293+
for task in self.flatten_tasks():
294+
if len(task.name) >= 50:
295+
name = task.name[:46] + "..."
296+
else:
297+
name = task.name
298+
299+
if task.parent_nursery is None:
300+
parent = "N/A"
301+
else:
302+
parent = str(id(task.parent_nursery.parent_task))
303+
304+
result = [str(id(task)).ljust(widths[0]), name.ljust(49), parent]
305+
306+
lines.append(' '.join(result))
307+
308+
return lines
309+
310+
async def command_where(self, taskid):
311+
"""Shows the stack frames of a task."""
312+
try:
313+
taskid = int(taskid)
314+
except ValueError:
315+
return ["Invalid task ID"]
316+
317+
task = self.get_task_by_id(taskid)
318+
if task is None:
319+
return ["Invalid task ID"]
320+
321+
summary = traceback.StackSummary.extract(walk_coro_stack(task.coro))
322+
lines = summary.format()
323+
lines = [line.rstrip('\n') for line in lines]
324+
return lines
325+
326+
# stub commands
327+
async def command_monitor(self, *args):
328+
"""Starts a live monitor feed."""
329+
return ["You shouldn't see this"]
330+
331+
async def command_exit(self, *args):
332+
"""Exits the monitor."""
333+
return ["You shouldn't see this"]
334+
335+
336+
def _patch_monitor():
337+
def pred(i):
338+
return isinstance(i, FunctionType) and not i.__name__.startswith("__")
339+
340+
for name, _ in inspect.getmembers(Instrument, predicate=pred):
341+
# 99% sure this is needed to bind the right name
342+
# otherwise it always uses the last name
343+
def bind(fname):
344+
def magic(self, *args):
345+
return self._add_to_monitoring_queue((fname, *args))
346+
347+
return magic
348+
349+
if getattr(Monitor, name) == getattr(Instrument, name):
350+
setattr(Monitor, name, bind(name))
351+
352+
353+
_patch_monitor()
354+
del _patch_monitor
355+
356+
357+
def main():
358+
import argparse
359+
import telnetlib
360+
361+
parser = argparse.ArgumentParser()
362+
parser.add_argument(
363+
"-a",
364+
"--address",
365+
default="127.0.0.1",
366+
help="The address to connect to"
367+
)
368+
parser.add_argument(
369+
"-p", "--port", default=14761, help="The port to connect to"
370+
)
371+
372+
args = parser.parse_args()
373+
# TODO: Potentially wrap sys.stdin for better readline
374+
client = telnetlib.Telnet(host=args.address, port=args.port)
375+
client.interact()
376+
377+
return 0
378+
379+
380+
if __name__ == "__main__":
381+
sys.exit(main())

0 commit comments

Comments
 (0)