diff --git a/api_docs_and_examples.py b/api_docs_and_examples.py index 837d25f1c..593ddabb2 100644 --- a/api_docs_and_examples.py +++ b/api_docs_and_examples.py @@ -5,7 +5,8 @@ import docutils.core -from nicegui import globals, ui +from nicegui import ui +from nicegui.auto_context import Context from nicegui.task_logger import create_task REGEX_H4 = re.compile(r'(.*?)') @@ -815,7 +816,7 @@ def turn_off(): ui.notify('Turning off that line plot to save resources on our live demo server. 😎') line_checkbox.value = msg.value if msg.value: - with globals.within_view(line_checkbox.view): + with Context(line_checkbox.view): ui.timer(10.0, turn_off, once=True) line_checkbox.update() return False diff --git a/nicegui/auto_context.py b/nicegui/auto_context.py new file mode 100644 index 000000000..b6b8881fd --- /dev/null +++ b/nicegui/auto_context.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any, Coroutine, Generator, List + +from . import globals +from .task_logger import create_task + +if TYPE_CHECKING: + import justpy as jp + + +def get_task_id() -> int: + return id(asyncio.current_task()) if globals.loop and globals.loop.is_running() else 0 + + +def get_view_stack() -> List['jp.HTMLBaseComponent']: + task_id = get_task_id() + if task_id not in globals.view_stacks: + globals.view_stacks[task_id] = [] + return globals.view_stacks[task_id] + + +def prune_view_stack() -> None: + task_id = get_task_id() + if not globals.view_stacks[task_id]: + del globals.view_stacks[task_id] + + +class Context: + + def __init__(self, view: 'jp.HTMLBaseComponent') -> None: + self.view = view + + def __enter__(self): + self.child_count = len(self.view) + get_view_stack().append(self.view) + return self + + def __exit__(self, type, value, traceback): + get_view_stack().pop() + prune_view_stack() + self.lazy_update() + + def lazy_update(self) -> None: + if len(self.view) != self.child_count: + self.child_count = len(self.view) + create_task(self.view.update()) + + def watch_asyncs(self, coro: Coroutine) -> AutoUpdaterForAsyncs: + return AutoUpdaterForAsyncs(coro, self) + + +class AutoUpdaterForAsyncs: + + def __init__(self, coro: Coroutine, context: Context) -> None: + self.coro = coro + self.context = context + self.context.lazy_update() + + def __await__(self) -> Generator[Any, None, Any]: + coro_iter = self.coro.__await__() + iter_send, iter_throw = coro_iter.send, coro_iter.throw + send, message = iter_send, None + while True: + try: + signal = send(message) + self.context.lazy_update() + except StopIteration as err: + return err.value + else: + send = iter_send + try: + message = yield signal + except BaseException as err: + send, message = iter_throw, err diff --git a/nicegui/elements/group.py b/nicegui/elements/group.py index 837ab88ef..c85152cd3 100644 --- a/nicegui/elements/group.py +++ b/nicegui/elements/group.py @@ -4,7 +4,7 @@ import justpy as jp -from .. import globals +from ..auto_context import get_view_stack from ..binding import active_links, bindable_properties, bindings from .element import Element @@ -13,11 +13,11 @@ class Group(Element): def __enter__(self): self._child_count_on_enter = len(self.view) - globals.get_view_stack().append(self.view) + get_view_stack().append(self.view) return self def __exit__(self, *_): - globals.get_view_stack().pop() + get_view_stack().pop() if self._child_count_on_enter != len(self.view): self.update() diff --git a/nicegui/elements/scene.py b/nicegui/elements/scene.py index fff17c711..741f4f069 100644 --- a/nicegui/elements/scene.py +++ b/nicegui/elements/scene.py @@ -5,7 +5,7 @@ import websockets from justpy import WebPage -from .. import globals +from ..auto_context import get_view_stack from ..events import handle_event from ..page import Page from ..routes import add_dependencies @@ -111,14 +111,14 @@ def __init__(self, width: int = 400, height: int = 300, on_click: Optional[Calla super().__init__(SceneView(width=width, height=height, on_click=on_click)) def __enter__(self): - globals.get_view_stack().append(self.view) + get_view_stack().append(self.view) scene = self.view.objects.get('scene', SceneObject(self.view, self.page)) Object3D.stack.clear() Object3D.stack.append(scene) return self def __exit__(self, *_): - globals.get_view_stack().pop() + get_view_stack().pop() def move_camera(self, x: Optional[float] = None, diff --git a/nicegui/events.py b/nicegui/events.py index 1c97b6944..6064ff3a4 100644 --- a/nicegui/events.py +++ b/nicegui/events.py @@ -6,6 +6,7 @@ from starlette.websockets import WebSocket from . import globals +from .auto_context import Context from .helpers import is_coroutine from .lifecycle import on_startup from .task_logger import create_task @@ -238,12 +239,12 @@ def handle_event(handler: Optional[Callable], arguments: EventArguments) -> Opti if handler is None: return False no_arguments = not signature(handler).parameters - with globals.within_view(arguments.sender.parent_view): + with Context(arguments.sender.parent_view): result = handler() if no_arguments else handler(arguments) if is_coroutine(handler): async def wait_for_result(): - with globals.within_view(arguments.sender.parent_view): - await result + with Context(arguments.sender.parent_view) as context: + await context.watch_asyncs(result) if globals.loop and globals.loop.is_running(): create_task(wait_for_result(), name=str(handler)) else: diff --git a/nicegui/globals.py b/nicegui/globals.py index 337ea3bf1..b2131a55d 100644 --- a/nicegui/globals.py +++ b/nicegui/globals.py @@ -2,15 +2,13 @@ import asyncio import logging -from contextlib import contextmanager from enum import Enum -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Generator, List, Optional, Union +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Union from starlette.applications import Starlette from uvicorn import Server from .config import Config -from .task_logger import create_task if TYPE_CHECKING: import justpy as jp @@ -45,31 +43,3 @@ def find_route(function: Callable) -> str: if not routes: raise ValueError(f'Invalid page function {function}') return routes[0] - - -def get_task_id() -> int: - return id(asyncio.current_task()) if loop and loop.is_running() else 0 - - -def get_view_stack() -> List['jp.HTMLBaseComponent']: - task_id = get_task_id() - if task_id not in view_stacks: - view_stacks[task_id] = [] - return view_stacks[task_id] - - -def prune_view_stack() -> None: - task_id = get_task_id() - if not view_stacks[task_id]: - del view_stacks[task_id] - - -@contextmanager -def within_view(view: 'jp.HTMLBaseComponent') -> Generator[None, None, None]: - child_count = len(view) - get_view_stack().append(view) - yield - get_view_stack().pop() - prune_view_stack() - if len(view) != child_count: - create_task(view.update()) diff --git a/nicegui/page.py b/nicegui/page.py index 07c503253..0bcbd15de 100644 --- a/nicegui/page.py +++ b/nicegui/page.py @@ -17,6 +17,7 @@ from starlette.websockets import WebSocket from . import globals +from .auto_context import Context, get_view_stack from .events import PageEvent from .helpers import is_coroutine from .page_builder import PageBuilder @@ -74,7 +75,7 @@ def set_favicon(self, favicon: Optional[str]) -> None: self.favicon = f'_favicon/{favicon}' async def _route_function(self, request: Request) -> Page: - with globals.within_view(self.view): + with Context(self.view): for handler in globals.connect_handlers + ([self.connect_handler] if self.connect_handler else []): arg_count = len(inspect.signature(handler).parameters) is_coro = is_coroutine(handler) @@ -87,11 +88,11 @@ async def _route_function(self, request: Request) -> Page: return self async def handle_page_ready(self, msg: AdDict) -> bool: - with globals.within_view(self.view): + with Context(self.view) as context: try: if self.page_ready_generator is not None: if isinstance(self.page_ready_generator, types.AsyncGeneratorType): - await self.page_ready_generator.asend(PageEvent(msg.websocket)) + await context.watch_asyncs(self.page_ready_generator.asend(PageEvent(msg.websocket))) elif isinstance(self.page_ready_generator, types.GeneratorType): self.page_ready_generator.send(PageEvent(msg.websocket)) except (StopIteration, StopAsyncIteration): @@ -103,17 +104,19 @@ async def handle_page_ready(self, msg: AdDict) -> bool: arg_count = len(inspect.signature(self.page_ready_handler).parameters) is_coro = is_coroutine(self.page_ready_handler) if arg_count == 1: - await self.page_ready_handler(msg.websocket) if is_coro else self.page_ready_handler(msg.websocket) + result = self.page_ready_handler(msg.websocket) elif arg_count == 0: - await self.page_ready_handler() if is_coro else self.page_ready_handler() + result = self.page_ready_handler() else: raise ValueError(f'invalid number of arguments (0 or 1 allowed, got {arg_count})') + if is_coro: + await context.watch_asyncs(result) except: globals.log.exception('Failed to execute page-ready') return False async def on_disconnect(self, websocket: Optional[WebSocket] = None) -> None: - with globals.within_view(self.view): + with Context(self.view): for handler in globals.disconnect_handlers + ([self.disconnect_handler] if self.disconnect_handler else[]): arg_count = len(inspect.signature(handler).parameters) is_coro = is_coroutine(handler) @@ -209,7 +212,7 @@ def __init__( self.page: Optional[Page] = None *_, self.converters = compile_path(route) - def __call__(self, func, **kwargs) -> Callable: + def __call__(self, func: Callable, **kwargs) -> Callable: @wraps(func) async def decorated(request: Optional[Request] = None) -> Page: self.page = Page( @@ -224,7 +227,7 @@ async def decorated(request: Optional[Request] = None) -> Page: shared=self.shared, ) try: - with globals.within_view(self.page.view): + with Context(self.page.view): if 'request' in inspect.signature(func).parameters: if self.shared: raise RuntimeError('Cannot use `request` argument in shared page') @@ -263,7 +266,7 @@ async def after_content(self) -> None: def find_parent_view() -> jp.HTMLBaseComponent: - view_stack = globals.get_view_stack() + view_stack = get_view_stack() if not view_stack: if globals.loop and globals.loop.is_running(): raise RuntimeError('cannot find parent view, view stack is empty') diff --git a/nicegui/timer.py b/nicegui/timer.py index 4a99337e3..3d60f779d 100644 --- a/nicegui/timer.py +++ b/nicegui/timer.py @@ -7,6 +7,7 @@ from starlette.websockets import WebSocket from . import globals +from .auto_context import Context from .binding import BindableProperty from .helpers import is_coroutine from .page import Page, find_parent_page, find_parent_view @@ -41,10 +42,10 @@ def __init__(self, interval: float, callback: Callable, *, active: bool = True, async def do_callback(): try: - with globals.within_view(self.parent_view): + with Context(self.parent_view) as context: result = callback() if is_coroutine(callback): - await result + await context.watch_asyncs(result) except Exception: traceback.print_exc() diff --git a/tests/screen.py b/tests/screen.py index 077835c69..b6c5ab1a2 100644 --- a/tests/screen.py +++ b/tests/screen.py @@ -159,6 +159,16 @@ def get_attributes(self, tag: str, attribute: str) -> List[str]: def wait(self, t: float) -> None: time.sleep(t) + def wait_for(self, text: str, *, timeout: float = 1.0) -> None: + deadline = time.time() + timeout + while time.time() < deadline: + try: + self.find(text) + return + except: + self.wait(0.1) + raise TimeoutError() + def shot(self, name: str) -> None: os.makedirs(self.SCREENSHOT_DIR, exist_ok=True) filename = f'{self.SCREENSHOT_DIR}/{name}.png' diff --git a/tests/test_auto_context.py b/tests/test_auto_context.py index d06a68bc3..9888834bb 100644 --- a/tests/test_auto_context.py +++ b/tests/test_auto_context.py @@ -1,6 +1,8 @@ import asyncio +from typing import Generator from nicegui import ui +from nicegui.events import PageEvent from .screen import Screen @@ -51,3 +53,74 @@ def test_adding_elements_during_onconnect(screen: Screen): screen.open('/') screen.should_contain('Label 2') + + +def test_autoupdate_on_async_page_after_yield(screen: Screen): + @ui.page('/') + async def page() -> Generator[None, PageEvent, None]: + ui.label('before page is ready') + yield + ui.label('page ready') + await asyncio.sleep(1) + ui.label('one') + await asyncio.sleep(1) + ui.label('two') + await asyncio.sleep(1) + ui.label('three') + + screen.open('/') + screen.should_contain('before page is ready') + screen.should_contain('page ready') + screen.should_not_contain('one') + screen.wait_for('one') + screen.should_not_contain('two') + screen.wait_for('two') + screen.should_not_contain('three') + screen.wait_for('three') + + +def test_autoupdate_on_async_page_ready_callback(screen: Screen): + async def ready(): + ui.label('page ready') + await asyncio.sleep(1) + ui.label('after delay') + + @ui.page('/', on_page_ready=ready) + def page() -> Generator[None, PageEvent, None]: + ui.label('before page is ready') + + screen.open('/') + screen.should_contain('before page is ready') + screen.should_contain('page ready') + screen.should_not_contain('after delay') + screen.wait_for('after delay') + + +def test_autoupdate_on_async_event_handler(screen: Screen): + async def open(): + with ui.dialog() as dialog, ui.card(): + l = ui.label('This should be visible') + dialog.open() + await asyncio.sleep(1) + l.text = 'New text after 1 second' + ui.button('Dialog', on_click=open) + + screen.open('/') + screen.click('Dialog') + screen.should_contain('This should be visible') + screen.should_not_contain('New text after 1 second') + screen.wait_for('New text after 1 second') + + +def test_autoupdate_on_async_timer_callback(screen: Screen): + async def update(): + ui.label('1') + await asyncio.sleep(1.0) + ui.label('2') + ui.timer(2.0, update, once=True) + + screen.open('/') + screen.should_not_contain('1') + screen.wait_for('1') + screen.should_not_contain('2') + screen.wait_for('2') diff --git a/tests/test_events.py b/tests/test_events.py new file mode 100644 index 000000000..ec0411d26 --- /dev/null +++ b/tests/test_events.py @@ -0,0 +1,22 @@ +import asyncio + +from nicegui import ui + +from .screen import Screen + + +def test_event_with_update_before_await(screen: Screen): + @ui.page('/') + def page(): + async def update(): + ui.label('1') + await asyncio.sleep(1.0) + ui.label('2') + + ui.button('update', on_click=update) + + screen.open('/') + screen.click('update') + screen.wait_for('1') + screen.should_not_contain('2') + screen.wait_for('2')