diff --git a/api_docs_and_examples.py b/api_docs_and_examples.py
index 837d25f1c..593ddabb2 100644
--- a/api_docs_and_examples.py
+++ b/api_docs_and_examples.py
@@ -5,7 +5,8 @@
import docutils.core
-from nicegui import globals, ui
+from nicegui import ui
+from nicegui.auto_context import Context
from nicegui.task_logger import create_task
REGEX_H4 = re.compile(r'
(.*?)')
@@ -815,7 +816,7 @@ def turn_off():
ui.notify('Turning off that line plot to save resources on our live demo server. 😎')
line_checkbox.value = msg.value
if msg.value:
- with globals.within_view(line_checkbox.view):
+ with Context(line_checkbox.view):
ui.timer(10.0, turn_off, once=True)
line_checkbox.update()
return False
diff --git a/nicegui/auto_context.py b/nicegui/auto_context.py
new file mode 100644
index 000000000..b6b8881fd
--- /dev/null
+++ b/nicegui/auto_context.py
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+import asyncio
+from typing import TYPE_CHECKING, Any, Coroutine, Generator, List
+
+from . import globals
+from .task_logger import create_task
+
+if TYPE_CHECKING:
+ import justpy as jp
+
+
+def get_task_id() -> int:
+ return id(asyncio.current_task()) if globals.loop and globals.loop.is_running() else 0
+
+
+def get_view_stack() -> List['jp.HTMLBaseComponent']:
+ task_id = get_task_id()
+ if task_id not in globals.view_stacks:
+ globals.view_stacks[task_id] = []
+ return globals.view_stacks[task_id]
+
+
+def prune_view_stack() -> None:
+ task_id = get_task_id()
+ if not globals.view_stacks[task_id]:
+ del globals.view_stacks[task_id]
+
+
+class Context:
+
+ def __init__(self, view: 'jp.HTMLBaseComponent') -> None:
+ self.view = view
+
+ def __enter__(self):
+ self.child_count = len(self.view)
+ get_view_stack().append(self.view)
+ return self
+
+ def __exit__(self, type, value, traceback):
+ get_view_stack().pop()
+ prune_view_stack()
+ self.lazy_update()
+
+ def lazy_update(self) -> None:
+ if len(self.view) != self.child_count:
+ self.child_count = len(self.view)
+ create_task(self.view.update())
+
+ def watch_asyncs(self, coro: Coroutine) -> AutoUpdaterForAsyncs:
+ return AutoUpdaterForAsyncs(coro, self)
+
+
+class AutoUpdaterForAsyncs:
+
+ def __init__(self, coro: Coroutine, context: Context) -> None:
+ self.coro = coro
+ self.context = context
+ self.context.lazy_update()
+
+ def __await__(self) -> Generator[Any, None, Any]:
+ coro_iter = self.coro.__await__()
+ iter_send, iter_throw = coro_iter.send, coro_iter.throw
+ send, message = iter_send, None
+ while True:
+ try:
+ signal = send(message)
+ self.context.lazy_update()
+ except StopIteration as err:
+ return err.value
+ else:
+ send = iter_send
+ try:
+ message = yield signal
+ except BaseException as err:
+ send, message = iter_throw, err
diff --git a/nicegui/elements/group.py b/nicegui/elements/group.py
index 837ab88ef..c85152cd3 100644
--- a/nicegui/elements/group.py
+++ b/nicegui/elements/group.py
@@ -4,7 +4,7 @@
import justpy as jp
-from .. import globals
+from ..auto_context import get_view_stack
from ..binding import active_links, bindable_properties, bindings
from .element import Element
@@ -13,11 +13,11 @@ class Group(Element):
def __enter__(self):
self._child_count_on_enter = len(self.view)
- globals.get_view_stack().append(self.view)
+ get_view_stack().append(self.view)
return self
def __exit__(self, *_):
- globals.get_view_stack().pop()
+ get_view_stack().pop()
if self._child_count_on_enter != len(self.view):
self.update()
diff --git a/nicegui/elements/scene.py b/nicegui/elements/scene.py
index fff17c711..741f4f069 100644
--- a/nicegui/elements/scene.py
+++ b/nicegui/elements/scene.py
@@ -5,7 +5,7 @@
import websockets
from justpy import WebPage
-from .. import globals
+from ..auto_context import get_view_stack
from ..events import handle_event
from ..page import Page
from ..routes import add_dependencies
@@ -111,14 +111,14 @@ def __init__(self, width: int = 400, height: int = 300, on_click: Optional[Calla
super().__init__(SceneView(width=width, height=height, on_click=on_click))
def __enter__(self):
- globals.get_view_stack().append(self.view)
+ get_view_stack().append(self.view)
scene = self.view.objects.get('scene', SceneObject(self.view, self.page))
Object3D.stack.clear()
Object3D.stack.append(scene)
return self
def __exit__(self, *_):
- globals.get_view_stack().pop()
+ get_view_stack().pop()
def move_camera(self,
x: Optional[float] = None,
diff --git a/nicegui/events.py b/nicegui/events.py
index 1c97b6944..6064ff3a4 100644
--- a/nicegui/events.py
+++ b/nicegui/events.py
@@ -6,6 +6,7 @@
from starlette.websockets import WebSocket
from . import globals
+from .auto_context import Context
from .helpers import is_coroutine
from .lifecycle import on_startup
from .task_logger import create_task
@@ -238,12 +239,12 @@ def handle_event(handler: Optional[Callable], arguments: EventArguments) -> Opti
if handler is None:
return False
no_arguments = not signature(handler).parameters
- with globals.within_view(arguments.sender.parent_view):
+ with Context(arguments.sender.parent_view):
result = handler() if no_arguments else handler(arguments)
if is_coroutine(handler):
async def wait_for_result():
- with globals.within_view(arguments.sender.parent_view):
- await result
+ with Context(arguments.sender.parent_view) as context:
+ await context.watch_asyncs(result)
if globals.loop and globals.loop.is_running():
create_task(wait_for_result(), name=str(handler))
else:
diff --git a/nicegui/globals.py b/nicegui/globals.py
index 337ea3bf1..b2131a55d 100644
--- a/nicegui/globals.py
+++ b/nicegui/globals.py
@@ -2,15 +2,13 @@
import asyncio
import logging
-from contextlib import contextmanager
from enum import Enum
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Generator, List, Optional, Union
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Union
from starlette.applications import Starlette
from uvicorn import Server
from .config import Config
-from .task_logger import create_task
if TYPE_CHECKING:
import justpy as jp
@@ -45,31 +43,3 @@ def find_route(function: Callable) -> str:
if not routes:
raise ValueError(f'Invalid page function {function}')
return routes[0]
-
-
-def get_task_id() -> int:
- return id(asyncio.current_task()) if loop and loop.is_running() else 0
-
-
-def get_view_stack() -> List['jp.HTMLBaseComponent']:
- task_id = get_task_id()
- if task_id not in view_stacks:
- view_stacks[task_id] = []
- return view_stacks[task_id]
-
-
-def prune_view_stack() -> None:
- task_id = get_task_id()
- if not view_stacks[task_id]:
- del view_stacks[task_id]
-
-
-@contextmanager
-def within_view(view: 'jp.HTMLBaseComponent') -> Generator[None, None, None]:
- child_count = len(view)
- get_view_stack().append(view)
- yield
- get_view_stack().pop()
- prune_view_stack()
- if len(view) != child_count:
- create_task(view.update())
diff --git a/nicegui/page.py b/nicegui/page.py
index 07c503253..0bcbd15de 100644
--- a/nicegui/page.py
+++ b/nicegui/page.py
@@ -17,6 +17,7 @@
from starlette.websockets import WebSocket
from . import globals
+from .auto_context import Context, get_view_stack
from .events import PageEvent
from .helpers import is_coroutine
from .page_builder import PageBuilder
@@ -74,7 +75,7 @@ def set_favicon(self, favicon: Optional[str]) -> None:
self.favicon = f'_favicon/{favicon}'
async def _route_function(self, request: Request) -> Page:
- with globals.within_view(self.view):
+ with Context(self.view):
for handler in globals.connect_handlers + ([self.connect_handler] if self.connect_handler else []):
arg_count = len(inspect.signature(handler).parameters)
is_coro = is_coroutine(handler)
@@ -87,11 +88,11 @@ async def _route_function(self, request: Request) -> Page:
return self
async def handle_page_ready(self, msg: AdDict) -> bool:
- with globals.within_view(self.view):
+ with Context(self.view) as context:
try:
if self.page_ready_generator is not None:
if isinstance(self.page_ready_generator, types.AsyncGeneratorType):
- await self.page_ready_generator.asend(PageEvent(msg.websocket))
+ await context.watch_asyncs(self.page_ready_generator.asend(PageEvent(msg.websocket)))
elif isinstance(self.page_ready_generator, types.GeneratorType):
self.page_ready_generator.send(PageEvent(msg.websocket))
except (StopIteration, StopAsyncIteration):
@@ -103,17 +104,19 @@ async def handle_page_ready(self, msg: AdDict) -> bool:
arg_count = len(inspect.signature(self.page_ready_handler).parameters)
is_coro = is_coroutine(self.page_ready_handler)
if arg_count == 1:
- await self.page_ready_handler(msg.websocket) if is_coro else self.page_ready_handler(msg.websocket)
+ result = self.page_ready_handler(msg.websocket)
elif arg_count == 0:
- await self.page_ready_handler() if is_coro else self.page_ready_handler()
+ result = self.page_ready_handler()
else:
raise ValueError(f'invalid number of arguments (0 or 1 allowed, got {arg_count})')
+ if is_coro:
+ await context.watch_asyncs(result)
except:
globals.log.exception('Failed to execute page-ready')
return False
async def on_disconnect(self, websocket: Optional[WebSocket] = None) -> None:
- with globals.within_view(self.view):
+ with Context(self.view):
for handler in globals.disconnect_handlers + ([self.disconnect_handler] if self.disconnect_handler else[]):
arg_count = len(inspect.signature(handler).parameters)
is_coro = is_coroutine(handler)
@@ -209,7 +212,7 @@ def __init__(
self.page: Optional[Page] = None
*_, self.converters = compile_path(route)
- def __call__(self, func, **kwargs) -> Callable:
+ def __call__(self, func: Callable, **kwargs) -> Callable:
@wraps(func)
async def decorated(request: Optional[Request] = None) -> Page:
self.page = Page(
@@ -224,7 +227,7 @@ async def decorated(request: Optional[Request] = None) -> Page:
shared=self.shared,
)
try:
- with globals.within_view(self.page.view):
+ with Context(self.page.view):
if 'request' in inspect.signature(func).parameters:
if self.shared:
raise RuntimeError('Cannot use `request` argument in shared page')
@@ -263,7 +266,7 @@ async def after_content(self) -> None:
def find_parent_view() -> jp.HTMLBaseComponent:
- view_stack = globals.get_view_stack()
+ view_stack = get_view_stack()
if not view_stack:
if globals.loop and globals.loop.is_running():
raise RuntimeError('cannot find parent view, view stack is empty')
diff --git a/nicegui/timer.py b/nicegui/timer.py
index 4a99337e3..3d60f779d 100644
--- a/nicegui/timer.py
+++ b/nicegui/timer.py
@@ -7,6 +7,7 @@
from starlette.websockets import WebSocket
from . import globals
+from .auto_context import Context
from .binding import BindableProperty
from .helpers import is_coroutine
from .page import Page, find_parent_page, find_parent_view
@@ -41,10 +42,10 @@ def __init__(self, interval: float, callback: Callable, *, active: bool = True,
async def do_callback():
try:
- with globals.within_view(self.parent_view):
+ with Context(self.parent_view) as context:
result = callback()
if is_coroutine(callback):
- await result
+ await context.watch_asyncs(result)
except Exception:
traceback.print_exc()
diff --git a/tests/screen.py b/tests/screen.py
index 077835c69..b6c5ab1a2 100644
--- a/tests/screen.py
+++ b/tests/screen.py
@@ -159,6 +159,16 @@ def get_attributes(self, tag: str, attribute: str) -> List[str]:
def wait(self, t: float) -> None:
time.sleep(t)
+ def wait_for(self, text: str, *, timeout: float = 1.0) -> None:
+ deadline = time.time() + timeout
+ while time.time() < deadline:
+ try:
+ self.find(text)
+ return
+ except:
+ self.wait(0.1)
+ raise TimeoutError()
+
def shot(self, name: str) -> None:
os.makedirs(self.SCREENSHOT_DIR, exist_ok=True)
filename = f'{self.SCREENSHOT_DIR}/{name}.png'
diff --git a/tests/test_auto_context.py b/tests/test_auto_context.py
index d06a68bc3..9888834bb 100644
--- a/tests/test_auto_context.py
+++ b/tests/test_auto_context.py
@@ -1,6 +1,8 @@
import asyncio
+from typing import Generator
from nicegui import ui
+from nicegui.events import PageEvent
from .screen import Screen
@@ -51,3 +53,74 @@ def test_adding_elements_during_onconnect(screen: Screen):
screen.open('/')
screen.should_contain('Label 2')
+
+
+def test_autoupdate_on_async_page_after_yield(screen: Screen):
+ @ui.page('/')
+ async def page() -> Generator[None, PageEvent, None]:
+ ui.label('before page is ready')
+ yield
+ ui.label('page ready')
+ await asyncio.sleep(1)
+ ui.label('one')
+ await asyncio.sleep(1)
+ ui.label('two')
+ await asyncio.sleep(1)
+ ui.label('three')
+
+ screen.open('/')
+ screen.should_contain('before page is ready')
+ screen.should_contain('page ready')
+ screen.should_not_contain('one')
+ screen.wait_for('one')
+ screen.should_not_contain('two')
+ screen.wait_for('two')
+ screen.should_not_contain('three')
+ screen.wait_for('three')
+
+
+def test_autoupdate_on_async_page_ready_callback(screen: Screen):
+ async def ready():
+ ui.label('page ready')
+ await asyncio.sleep(1)
+ ui.label('after delay')
+
+ @ui.page('/', on_page_ready=ready)
+ def page() -> Generator[None, PageEvent, None]:
+ ui.label('before page is ready')
+
+ screen.open('/')
+ screen.should_contain('before page is ready')
+ screen.should_contain('page ready')
+ screen.should_not_contain('after delay')
+ screen.wait_for('after delay')
+
+
+def test_autoupdate_on_async_event_handler(screen: Screen):
+ async def open():
+ with ui.dialog() as dialog, ui.card():
+ l = ui.label('This should be visible')
+ dialog.open()
+ await asyncio.sleep(1)
+ l.text = 'New text after 1 second'
+ ui.button('Dialog', on_click=open)
+
+ screen.open('/')
+ screen.click('Dialog')
+ screen.should_contain('This should be visible')
+ screen.should_not_contain('New text after 1 second')
+ screen.wait_for('New text after 1 second')
+
+
+def test_autoupdate_on_async_timer_callback(screen: Screen):
+ async def update():
+ ui.label('1')
+ await asyncio.sleep(1.0)
+ ui.label('2')
+ ui.timer(2.0, update, once=True)
+
+ screen.open('/')
+ screen.should_not_contain('1')
+ screen.wait_for('1')
+ screen.should_not_contain('2')
+ screen.wait_for('2')
diff --git a/tests/test_events.py b/tests/test_events.py
new file mode 100644
index 000000000..ec0411d26
--- /dev/null
+++ b/tests/test_events.py
@@ -0,0 +1,22 @@
+import asyncio
+
+from nicegui import ui
+
+from .screen import Screen
+
+
+def test_event_with_update_before_await(screen: Screen):
+ @ui.page('/')
+ def page():
+ async def update():
+ ui.label('1')
+ await asyncio.sleep(1.0)
+ ui.label('2')
+
+ ui.button('update', on_click=update)
+
+ screen.open('/')
+ screen.click('update')
+ screen.wait_for('1')
+ screen.should_not_contain('2')
+ screen.wait_for('2')