Typing State for the class-based API does not work #400

Description

@mdrideout

Following the docs for action-level typing does not work for class-based actions.

ref: #386

Current behavior

First example action:

class SetInitialPromptAction(Action):
    @property
    def reads(self) -> list[str]:
        return []

    def run(self, state: ApplicationState, prompt: str) -> dict:
        return {"initial_prompt": prompt}

    @property
    def writes(self) -> list[str]:
        return ["initial_prompt"]

    def update(self, result: dict, state: ApplicationState) -> ApplicationState:
        prompt = result["initial_prompt"]
        logger.info(f"Saving prompt to state: {prompt}")
        state.initial_prompt = prompt
        return state

    @property
    def inputs(self) -> list[str]:
        return ["prompt"]

Second example action:

class ExtractSetAction(Action):
    @property
    def reads(self) -> list[str]:
        return ["initial_prompt"]

    def run(self, state: ApplicationState) -> dict:
        logger.info(f"ApplicationState: {state}")

        # Read prompt from state
        prompt = state.initial_prompt
        ...

Log output from the second action: ApplicationState: {'initial_prompt': None}
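A toy model of the symptom, to make the log and traceback easier to follow. It assumes, inferring only from the {'initial_prompt': None} log line and the AttributeError below, that class-based run()/update() receive a mapping-style state object rather than an ApplicationState instance, and that the next action sees a copy rebuilt from that mapping. ToyState is a hypothetical stand-in, not Burr's State class.

```python
from collections.abc import Mapping
from typing import Any, Iterator


class ToyState(Mapping):
    """Hypothetical stand-in for a mapping-backed, copy-on-write state.

    Not Burr's State class: it only mimics what the log and traceback show.
    """

    def __init__(self, values: dict[str, Any]) -> None:
        self._values = dict(values)

    # Mapping protocol: values are read with state["key"], not state.key
    def __getitem__(self, key: str) -> Any:
        return self._values[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._values)

    def __len__(self) -> int:
        return len(self._values)

    def __repr__(self) -> str:
        return repr(self._values)

    def update(self, **kwargs: Any) -> "ToyState":
        # Returns a new object with the given keys written into the mapping
        return ToyState({**self._values, **kwargs})

    def subset(self, *keys: str) -> "ToyState":
        # Rebuilt from the mapping only, so ad-hoc attributes are dropped
        return ToyState({k: self._values[k] for k in keys if k in self._values})


# First action: update() assigns an attribute instead of writing the mapping
state = ToyState({"initial_prompt": None, "set_from_prompt": None})
state.initial_prompt = "bicep curls"  # plain Python attribute on this object
print(state["initial_prompt"])  # None: the mapping was never touched

# Second action: receives a copy rebuilt from the mapping (its reads)
next_state = state.subset("initial_prompt")
print(next_state)  # {'initial_prompt': None}
print(hasattr(next_state, "initial_prompt"))  # False, so state.initial_prompt raises

# Writing through the mapping survives the rebuild
good = ToyState({"initial_prompt": None}).update(initial_prompt="bicep curls")
print(good.subset("initial_prompt")["initial_prompt"])  # bicep curls
```

If this model matches what Burr does internally, it would account for both symptoms: the first action's attribute assignment never reaches the mapping (hence None), and the rebuilt state has no attribute to read (hence the AttributeError). With Burr's untyped API the usual pattern is reading state["initial_prompt"] and returning state.update(initial_prompt=...) from update(), which gives up the typed model this issue is about.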

Stack Traces

api | ********************************************************************************
api | -------------------------------------------------------------------
api | Oh no an error! Need help with Burr?
api | Join our discord and ask for help! https://discord.gg/4FxBMyzW5n
api | -------------------------------------------------------------------
api | > Action: extract_set encountered an error!<
api | > State (at time of action):
api | {'__PRIOR_STEP': 'set_prompt',
api | '__SEQUENCE_ID': 1,
api | 'initial_prompt': None,
api | 'set_from_prompt': None}
api | > Inputs (at time of action):
api | {'prompt': 'bicep curls with 22 pound dumbells for 21 reps'}
api | ********************************************************************************
api | Traceback (most recent call last):
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 816, in _step
api | result = _run_function(
api | ^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 123, in _run_function
api | result = function.run(state_to_use, **inputs)
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/app/app/ai/actions/extract_set/action.py", line 93, in run
api | prompt = state.initial_prompt
api | ^^^^^^^^^^^^^^^^^^^^
api | AttributeError: 'State' object has no attribute 'initial_prompt'
api | INFO: 192.168.65.1:35000 - "GET /api/extract-set?prompt=bicep%20curls%20with%2022%20pound%20dumbells%20for%2021%20reps HTTP/1.1" 500 Internal Server Error
api | ERROR: Exception in ASGI application
api | + Exception Group Traceback (most recent call last):
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_utils.py", line 77, in collapse_excgroups
api | | yield
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 186, in call
api | | async with anyio.create_task_group() as task_group:
api | | File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 736, in aexit
api | | raise BaseExceptionGroup(
api | | ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
api | +-+---------------- 1 ----------------
api | | Traceback (most recent call last):
api | | File "/usr/local/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py", line 401, in run_asgi
api | | result = await app( # type: ignore[func-returns-value]
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in call
api | | return await self.app(scope, receive, send)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/fastapi/applications.py", line 1054, in call
api | | await super().call(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/applications.py", line 113, in call
api | | await self.middleware_stack(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 187, in call
api | | raise exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 165, in call
api | | await self.app(scope, receive, _send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 185, in call
api | | with collapse_excgroups():
api | | File "/usr/local/lib/python3.11/contextlib.py", line 158, in exit
api | | self.gen.throw(typ, value, traceback)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_utils.py", line 83, in collapse_excgroups
api | | raise exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 187, in call
api | | response = await self.dispatch_func(request, call_next)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/app/app/main.py", line 36, in log_requests
api | | response = await call_next(request)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 163, in call_next
api | | raise app_exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 149, in coro
api | | await self.app(scope, receive_or_disconnect, send_no_error)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 62, in call
api | | await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
api | | raise exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
api | | await app(scope, receive, sender)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 715, in call
api | | await self.middleware_stack(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 735, in app
api | | await route.handle(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 288, in handle
api | | await self.app(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 76, in app
api | | await wrap_app_handling_exceptions(app, request)(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
api | | raise exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
api | | await app(scope, receive, sender)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 73, in app
api | | response = await f(request)
api | | ^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 301, in app
api | | raw_response = await run_endpoint_function(
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 214, in run_endpoint_function
api | | return await run_in_threadpool(dependant.call, **values)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/starlette/concurrency.py", line 39, in run_in_threadpool
api | | return await anyio.to_thread.run_sync(func, *args)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
api | | return await get_async_backend().run_sync_in_worker_thread(
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 2405, in run_sync_in_worker_thread
api | | return await future
api | | ^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 914, in run
api | | result = context.run(func, *args)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/app/app/api/routes.py", line 53, in extract_set
api | | action, result, state = application.run(
api | | ^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/telemetry.py", line 276, in wrapped_fn
api | | return call_fn(*args, **kwargs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
api | | return fn(app_self, *args, **kwargs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1168, in run
api | | next(gen)
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1111, in iterate
api | | prior_action, result, state = self.step(inputs=inputs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
api | | return fn(app_self, *args, **kwargs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 773, in step
api | | out = self._step(inputs=inputs, _run_hooks=True)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 826, in _step
api | | raise e
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 816, in _step
api | | result = _run_function(
api | | ^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 123, in _run_function
api | | result = function.run(state_to_use, **inputs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/app/app/ai/actions/extract_set/action.py", line 93, in run
api | | prompt = state.initial_prompt
api | | ^^^^^^^^^^^^^^^^^^^^
api | | AttributeError: 'State' object has no attribute 'initial_prompt'
api | +------------------------------------

Library & System Information

  • Python 3.11
  • Debian bookworm slim
  • Burr library:
      burr = { extras = [
        "graphviz",
        "hamilton",
        "streamlit",
        "tracking-client",
        "tracking-server",
      ], version = "^0.31.1" }

Expected behavior

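Class-based actions should work the same as function-based actions: run() and update() should receive the typed ApplicationState, so that state.initial_prompt resolves.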
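One possible stopgap is to do the conversion by hand: turn the state mapping into the typed model, let the action body work on the model, then copy only the declared writes back. The sketch below shows that round trip with a stdlib dataclass standing in for the pydantic ApplicationState, so it runs without Burr or pydantic installed. to_model, run_typed and set_prompt are hypothetical helpers, not Burr APIs, and the model's fields are assumed from the state dump in the log.

```python
from dataclasses import asdict, dataclass, fields
from typing import Any, Callable, Optional


@dataclass
class ApplicationState:
    # Assumed fields, copied from the state dump in the log; types are guesses
    initial_prompt: Optional[str] = None
    set_from_prompt: Any = None


def to_model(values: dict[str, Any]) -> ApplicationState:
    """Build the typed model from the matching keys of a state mapping."""
    names = {f.name for f in fields(ApplicationState)}
    return ApplicationState(**{k: v for k, v in values.items() if k in names})


def run_typed(
    values: dict[str, Any],
    writes: list[str],
    body: Callable[[ApplicationState], ApplicationState],
) -> dict[str, Any]:
    """Hand `body` a typed model, then copy only its declared writes back."""
    model = body(to_model(values))
    data = asdict(model)
    # Keys outside the model (e.g. __SEQUENCE_ID) pass through untouched
    return {**values, **{k: data[k] for k in writes}}


def set_prompt(state: ApplicationState) -> ApplicationState:
    # Same shape as SetInitialPromptAction.update, written against the model
    state.initial_prompt = "bicep curls with 22 pound dumbells for 21 reps"
    return state


after = run_typed(
    {"initial_prompt": None, "set_from_prompt": None, "__SEQUENCE_ID": 1},
    ["initial_prompt"],
    set_prompt,
)
print(to_model(after).initial_prompt)
```

Inside a real class-based action the same round trip would have to happen in run()/update(), reading from Burr's State by key and writing back with state.update(...); limiting the copy-back to the declared writes mirrors the reads/writes contract the actions above already declare.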

Metadata

Labels

  • area/core: Application, State, Graph, Actions
  • area/typing: Mypy, type hints, pydantic
  • help wanted: Contributors wanted!
  • kind/bug: Something is broken
  • priority/high: Affects many users, needs action within weeks
