Skip to content

Commit

Permalink
Merge pull request #24 from simonsobs/plugin
Browse files Browse the repository at this point in the history
Pass around a context object to all plugin callback functions
  • Loading branch information
TaiSakuma authored Jan 19, 2024
2 parents a9ec151 + 96f0ae7 commit 082489e
Show file tree
Hide file tree
Showing 18 changed files with 221 additions and 316 deletions.
36 changes: 20 additions & 16 deletions nextline/fsm/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from logging import getLogger
from typing import Any, Optional

import apluggy
from transitions import EventData

from nextline.plugin import Context
from nextline.spawned import Command
from nextline.types import ResetOptions

Expand All @@ -14,8 +14,9 @@
class Machine:
'''The finite state machine of the nextline states.'''

def __init__(self, hook: apluggy.PluginManager) -> None:
self._hook = hook
def __init__(self, context: Context) -> None:
self._context = context
self._hook = context.hook
self._machine = build_state_machine(model=self)
self._machine.after_state_change = self.after_state_change.__name__ # type: ignore
assert self.state # type: ignore
Expand All @@ -28,39 +29,42 @@ async def after_state_change(self, event: EventData) -> None:
if not (event.transition and event.transition.dest):
# internal transition
return
await self._hook.ahook.on_change_state(state_name=self.state) # type: ignore
await self._hook.ahook.on_change_state(
context=self._context, state_name=self.state # type: ignore
)

async def on_exit_created(self, _: EventData) -> None:
await self._hook.ahook.start()
await self._hook.ahook.start(context=self._context)

async def on_enter_initialized(self, _: EventData) -> None:
self._run_arg = self._hook.hook.compose_run_arg()
await self._hook.ahook.on_initialize_run(run_arg=self._run_arg)
self._context.run_arg = self._hook.hook.compose_run_arg(context=self._context)
await self._hook.ahook.on_initialize_run(context=self._context)

async def on_enter_running(self, _: EventData) -> None:
self.run_finished = asyncio.Event()
run_started = asyncio.Event()

async def run() -> None:
async with self._hook.awith.run():
async with self._hook.awith.run(context=self._context):
run_started.set()
self._context.run_arg = None
await self.finish() # type: ignore
self.run_finished.set()

self._task = asyncio.create_task(run())
await run_started.wait()

async def send_command(self, command: Command) -> None:
await self._hook.ahook.send_command(command=command)
await self._hook.ahook.send_command(context=self._context, command=command)

async def interrupt(self) -> None:
await self._hook.ahook.interrupt()
await self._hook.ahook.interrupt(context=self._context)

async def terminate(self) -> None:
await self._hook.ahook.terminate()
await self._hook.ahook.terminate(context=self._context)

async def kill(self) -> None:
await self._hook.ahook.kill()
await self._hook.ahook.kill(context=self._context)

async def on_close_while_running(self, _: EventData) -> None:
await self.run_finished.wait()
Expand All @@ -72,13 +76,13 @@ async def on_exit_finished(self, _: EventData) -> None:
await self._task

def exception(self) -> Optional[BaseException]:
return self._hook.hook.exception()
return self._hook.hook.exception(context=self._context)

def result(self) -> Any:
return self._hook.hook.result()
return self._hook.hook.result(context=self._context)

async def on_enter_closed(self, _: EventData) -> None:
await self._hook.ahook.close()
await self._hook.ahook.close(context=self._context)

async def on_reset(self, event: EventData) -> None:
logger = getLogger(__name__)
Expand All @@ -88,7 +92,7 @@ async def on_reset(self, event: EventData) -> None:
reset_options: ResetOptions = kwargs.pop('reset_options')
if kwargs:
logger.warning(f'Unexpected kwargs: {kwargs!r}')
await self._hook.ahook.reset(reset_options=reset_options)
await self._hook.ahook.reset(context=self._context, reset_options=reset_options)

async def __aenter__(self) -> 'Machine':
await self.initialize() # type: ignore
Expand Down
12 changes: 4 additions & 8 deletions nextline/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .continuous import Continuous
from .fsm import Machine
from .plugin import build_hook
from .plugin import Context, build_hook
from .spawned import PdbCommand
from .types import (
InitOptions,
Expand Down Expand Up @@ -86,13 +86,9 @@ async def start(self) -> None:
self._started = True
logger = getLogger(__name__)
logger.debug(f'self._init_options: {self._init_options}')
self._hook.hook.init(
nextline=self,
hook=self._hook,
registry=self._pubsub,
init_options=self._init_options,
)
self._machine = Machine(hook=self._hook)
context = Context(nextline=self, hook=self._hook, pubsub=self._pubsub)
self._hook.hook.init(context=context, init_options=self._init_options)
self._machine = Machine(context=context)
await self._continuous.start()
await self._machine.initialize() # type: ignore

Expand Down
3 changes: 2 additions & 1 deletion nextline/plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__all__ = ['build_hook']
__all__ = ['build_hook', 'Context']

from .hook import build_hook
from .spec import Context
31 changes: 8 additions & 23 deletions nextline/plugin/plugins/argument.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,32 @@
from apluggy import PluginManager

from nextline.count import RunNoCounter
from nextline.plugin.spec import hookimpl
from nextline.plugin.spec import Context, hookimpl
from nextline.spawned import RunArg
from nextline.types import InitOptions, ResetOptions
from nextline.utils.pubsub.broker import PubSub

SCRIPT_FILE_NAME = "<string>"


class RunArgComposer:
@hookimpl
def init(
self,
hook: PluginManager,
registry: PubSub,
init_options: InitOptions,
) -> None:
self._hook = hook
self._registry = registry
def init(self, init_options: InitOptions) -> None:
self._run_no_count = RunNoCounter(init_options.run_no_start_from)
self._statement = init_options.statement
self._filename = SCRIPT_FILE_NAME
self._trace_threads = init_options.trace_threads
self._trace_modules = init_options.trace_modules

@hookimpl
async def start(self) -> None:
await self._hook.ahook.on_change_script(
script=self._statement,
filename=self._filename,
async def start(self, context: Context) -> None:
await context.hook.ahook.on_change_script(
context=context, script=self._statement, filename=self._filename
)

@hookimpl
async def reset(
self,
reset_options: ResetOptions,
) -> None:
async def reset(self, context: Context, reset_options: ResetOptions) -> None:
if (statement := reset_options.statement) is not None:
self._statement = statement
await self._hook.ahook.on_change_script(
script=self._statement,
filename=self._filename,
await context.hook.ahook.on_change_script(
context=context, script=self._statement, filename=self._filename
)
if (run_no_start_from := reset_options.run_no_start_from) is not None:
self._run_no_count = RunNoCounter(run_no_start_from)
Expand Down
59 changes: 24 additions & 35 deletions nextline/plugin/plugins/registrars/prompt_info.py
Original file line number Diff line number Diff line change
@@ -1,92 +1,81 @@
import asyncio
import dataclasses
from logging import getLogger
from typing import Optional

from nextline.plugin.spec import hookimpl
from nextline.plugin.spec import Context, hookimpl
from nextline.spawned import (
OnEndPrompt,
OnEndTrace,
OnEndTraceCall,
OnStartPrompt,
OnStartTrace,
OnStartTraceCall,
RunArg,
)
from nextline.types import PromptInfo, PromptNo, RunNo, TraceNo
from nextline.utils.pubsub.broker import PubSub
from nextline.types import PromptInfo, PromptNo, TraceNo


class PromptInfoRegistrar:
def __init__(self) -> None:
self._run_no: Optional[RunNo] = None
self._last_prompt_frame_map = dict[TraceNo, int]()
self._trace_call_map = dict[TraceNo, OnStartTraceCall]()
self._prompt_info_map = dict[PromptNo, PromptInfo]()
self._keys = set[str]()
self._logger = getLogger(__name__)

@hookimpl
def init(self, registry: PubSub) -> None:
self._registry = registry

@hookimpl
async def start(self) -> None:
self._lock = asyncio.Lock()
pass

@hookimpl
async def on_initialize_run(self, run_arg: RunArg) -> None:
self._run_no = run_arg.run_no
async def on_initialize_run(self) -> None:
self._last_prompt_frame_map.clear()
self._trace_call_map.clear()
self._prompt_info_map.clear()
self._keys.clear()

@hookimpl
async def on_end_run(self) -> None:
async def on_end_run(self, context: Context) -> None:
async with self._lock:
while self._keys:
# the process might have been killed.
key = self._keys.pop()
await self._registry.end(key)

self._run_no = None
await context.pubsub.end(key)

@hookimpl
async def on_start_trace(self, event: OnStartTrace) -> None:
assert self._run_no is not None
async def on_start_trace(self, context: Context, event: OnStartTrace) -> None:
assert context.run_arg
trace_no = event.trace_no

# TODO: Putting a prompt info for now because otherwise tests get stuck
# sometimes for an unknown reason. Need to investigate
prompt_info = PromptInfo(
run_no=self._run_no,
run_no=context.run_arg.run_no,
trace_no=trace_no,
prompt_no=PromptNo(-1),
open=False,
)
key = f"prompt_info_{trace_no}"
async with self._lock:
self._keys.add(key)
await self._registry.publish(key, prompt_info)
await context.pubsub.publish(key, prompt_info)

@hookimpl
async def on_end_trace(self, event: OnEndTrace) -> None:
async def on_end_trace(self, context: Context, event: OnEndTrace) -> None:
trace_no = event.trace_no
key = f"prompt_info_{trace_no}"
async with self._lock:
if key in self._keys:
self._keys.remove(key)
await self._registry.end(key)
await context.pubsub.end(key)

@hookimpl
async def on_start_trace_call(self, event: OnStartTraceCall) -> None:
self._trace_call_map[event.trace_no] = event

@hookimpl
async def on_end_trace_call(self, event: OnEndTraceCall) -> None:
assert self._run_no is not None
async def on_end_trace_call(self, context: Context, event: OnEndTraceCall) -> None:
assert context.run_arg
trace_no = event.trace_no
trace_call = self._trace_call_map.pop(event.trace_no, None)
if trace_call is None:
Expand All @@ -102,7 +91,7 @@ async def on_end_trace_call(self, event: OnEndTraceCall) -> None:
# prompt info.

prompt_info = PromptInfo(
run_no=self._run_no,
run_no=context.run_arg.run_no,
trace_no=trace_no,
prompt_no=PromptNo(-1),
open=False,
Expand All @@ -111,21 +100,21 @@ async def on_end_trace_call(self, event: OnEndTraceCall) -> None:
line_no=trace_call.line_no,
trace_call_end=True,
)
await self._registry.publish('prompt_info', prompt_info)
await context.pubsub.publish('prompt_info', prompt_info)

key = f"prompt_info_{trace_no}"
async with self._lock:
self._keys.add(key)
await self._registry.publish(key, prompt_info)
await context.pubsub.publish(key, prompt_info)

@hookimpl
async def on_start_prompt(self, event: OnStartPrompt) -> None:
assert self._run_no is not None
async def on_start_prompt(self, context: Context, event: OnStartPrompt) -> None:
assert context.run_arg
trace_no = event.trace_no
prompt_no = event.prompt_no
trace_call = self._trace_call_map[trace_no]
prompt_info = PromptInfo(
run_no=self._run_no,
run_no=context.run_arg.run_no,
trace_no=trace_no,
prompt_no=prompt_no,
open=True,
Expand All @@ -138,15 +127,15 @@ async def on_start_prompt(self, event: OnStartPrompt) -> None:
self._prompt_info_map[prompt_no] = prompt_info
self._last_prompt_frame_map[trace_no] = trace_call.frame_object_id

await self._registry.publish('prompt_info', prompt_info)
await context.pubsub.publish('prompt_info', prompt_info)

key = f"prompt_info_{trace_no}"
async with self._lock:
self._keys.add(key)
await self._registry.publish(key, prompt_info)
await context.pubsub.publish(key, prompt_info)

@hookimpl
async def on_end_prompt(self, event: OnEndPrompt) -> None:
async def on_end_prompt(self, context: Context, event: OnEndPrompt) -> None:
trace_no = event.trace_no
prompt_no = event.prompt_no
prompt_info = self._prompt_info_map.pop(prompt_no)
Expand All @@ -157,9 +146,9 @@ async def on_end_prompt(self, event: OnEndPrompt) -> None:
ended_at=event.ended_at,
)

await self._registry.publish('prompt_info', prompt_info_end)
await context.pubsub.publish('prompt_info', prompt_info_end)

key = f"prompt_info_{trace_no}"
async with self._lock:
self._keys.add(key)
await self._registry.publish(key, prompt_info_end)
await context.pubsub.publish(key, prompt_info_end)
Loading

0 comments on commit 082489e

Please sign in to comment.