From e220ed221c0c0676601d6cef11fcfd7eec6b9248 Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:44:51 -0700 Subject: [PATCH] refactor(sessions): add dataloaders for session queries (#5222) --- .../fixtures/multi-turn_chat_sessions.ipynb | 19 ++-- src/phoenix/server/api/context.py | 9 ++ .../server/api/dataloaders/__init__.py | 8 ++ .../server/api/dataloaders/session_io.py | 75 +++++++++++++ .../api/dataloaders/session_num_traces.py | 30 +++++ .../api/dataloaders/session_token_usages.py | 41 +++++++ .../api/dataloaders/trace_by_trace_ids.py | 9 +- .../api/dataloaders/trace_root_spans.py | 32 ++++++ .../server/api/types/ProjectSession.py | 70 ++---------- src/phoenix/server/api/types/Trace.py | 11 +- src/phoenix/server/app.py | 9 ++ src/phoenix/trace/schemas.py | 16 +++ .../server/api/types/test_ProjectSession.py | 103 +++++++++++------- tests/unit/server/api/types/test_Trace.py | 53 ++++++--- 14 files changed, 348 insertions(+), 137 deletions(-) create mode 100644 src/phoenix/server/api/dataloaders/session_io.py create mode 100644 src/phoenix/server/api/dataloaders/session_num_traces.py create mode 100644 src/phoenix/server/api/dataloaders/session_token_usages.py create mode 100644 src/phoenix/server/api/dataloaders/trace_root_spans.py diff --git a/scripts/fixtures/multi-turn_chat_sessions.ipynb b/scripts/fixtures/multi-turn_chat_sessions.ipynb index a1f82ceb04..a40b06e809 100644 --- a/scripts/fixtures/multi-turn_chat_sessions.ipynb +++ b/scripts/fixtures/multi-turn_chat_sessions.ipynb @@ -7,7 +7,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install -Uqqq datasets openinference-instrumentation-openai openai-responses openai tiktoken langchain langchain-openai llama-index llama-index-llms-openai faker mdgen" + "%pip install -Uqqq datasets openinference-semantic-conventions openinference-instrumentation-openai faker openai-responses openai tiktoken" ] }, { @@ -25,7 +25,6 @@ "import pandas as pd\n", "from datasets import load_dataset\n", "from faker import Faker\n", - "from mdgen import MarkdownPostProvider\n", "from openai_responses import OpenAIMock\n", "from openinference.instrumentation import using_session, using_user\n", "from openinference.instrumentation.openai import OpenAIInstrumentor\n", @@ -39,8 +38,7 @@ "import phoenix as px\n", "from phoenix.trace.span_evaluations import SpanEvaluations\n", "\n", - "fake = Faker(\"ja_JP\")\n", - "fake.add_provider(MarkdownPostProvider)" + "fake = Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\", \"th_TH\", \"bn_BD\"])" ] }, { @@ -104,7 +102,7 @@ " if p < 0.1:\n", " return \":\" * randint(1, 5)\n", " if p < 0.9:\n", - " return Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\"]).address()\n", + " return fake.address()\n", " return int(abs(random()) * 1_000_000_000)\n", "\n", "\n", @@ -113,15 +111,17 @@ " if p < 0.1:\n", " return \":\" * randint(1, 5)\n", " if p < 0.9:\n", - " return Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\"]).name()\n", + " return fake.name()\n", " return int(abs(random()) * 1_000_000_000)\n", "\n", "\n", - "def export_spans():\n", + "def export_spans(prob_drop_root):\n", " \"\"\"Export spans in random order for receiver testing\"\"\"\n", " spans = list(in_memory_span_exporter.get_finished_spans())\n", " shuffle(spans)\n", " for span in spans:\n", + " if span.parent is None and random() < prob_drop_root:\n", + " continue\n", " otlp_span_exporter.export([span])\n", " in_memory_span_exporter.clear()\n", " session_count = len({id_ for span in spans if (id_ := span.attributes.get(\"session.id\"))})\n", @@ -147,7 +147,7 @@ " return\n", " has_yielded = False\n", " with tracer.start_as_current_span(\n", - " Faker(\"ja_JP\").kana_name(),\n", + " fake.city(),\n", " attributes=dict(rand_span_kind()),\n", " end_on_exit=False,\n", " ) as root:\n", @@ -185,6 +185,7 @@ "source": [ "session_count = randint(5, 10)\n", "tree_complexity = 4 # set to 0 for single span under root\n", + "prob_drop_root = 0.0 # probability that a root span gets dropped\n", "\n", "\n", "def simulate_openai():\n", @@ -237,7 +238,7 @@ " simulate_openai()\n", "finally:\n", " OpenAIInstrumentor().uninstrument()\n", - "spans = export_spans()\n", + "spans = export_spans(prob_drop_root)\n", "\n", "# Annotate root spans\n", "root_span_ids = pd.Series(\n", diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index e777a9ef28..bb2fade92b 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -31,12 +31,16 @@ MinStartOrMaxEndTimeDataLoader, ProjectByNameDataLoader, RecordCountDataLoader, + SessionIODataLoader, + SessionNumTracesDataLoader, + SessionTokenUsagesDataLoader, SpanAnnotationsDataLoader, SpanDatasetExamplesDataLoader, SpanDescendantsDataLoader, SpanProjectsDataLoader, TokenCountDataLoader, TraceByTraceIdsDataLoader, + TraceRootSpansDataLoader, UserRolesDataLoader, UsersDataLoader, ) @@ -68,12 +72,17 @@ class DataLoaders: latency_ms_quantile: LatencyMsQuantileDataLoader min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader record_counts: RecordCountDataLoader + session_first_inputs: SessionIODataLoader + session_last_outputs: SessionIODataLoader + session_num_traces: SessionNumTracesDataLoader + session_token_usages: SessionTokenUsagesDataLoader span_annotations: SpanAnnotationsDataLoader span_dataset_examples: SpanDatasetExamplesDataLoader span_descendants: SpanDescendantsDataLoader span_projects: SpanProjectsDataLoader token_counts: TokenCountDataLoader trace_by_trace_ids: TraceByTraceIdsDataLoader + trace_root_spans: TraceRootSpansDataLoader project_by_name: ProjectByNameDataLoader users: UsersDataLoader user_roles: UserRolesDataLoader diff --git a/src/phoenix/server/api/dataloaders/__init__.py b/src/phoenix/server/api/dataloaders/__init__.py index 8e33ee97b9..9dd189816a 100644 --- a/src/phoenix/server/api/dataloaders/__init__.py +++ b/src/phoenix/server/api/dataloaders/__init__.py @@ -19,12 +19,16 @@ from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader from .project_by_name import ProjectByNameDataLoader from .record_counts import RecordCountCache, RecordCountDataLoader +from .session_io import SessionIODataLoader +from .session_num_traces import SessionNumTracesDataLoader +from .session_token_usages import SessionTokenUsagesDataLoader from .span_annotations import SpanAnnotationsDataLoader from .span_dataset_examples import SpanDatasetExamplesDataLoader from .span_descendants import SpanDescendantsDataLoader from .span_projects import SpanProjectsDataLoader from .token_counts import TokenCountCache, TokenCountDataLoader from .trace_by_trace_ids import TraceByTraceIdsDataLoader +from .trace_root_spans import TraceRootSpansDataLoader from .user_roles import UserRolesDataLoader from .users import UsersDataLoader @@ -45,11 +49,15 @@ "LatencyMsQuantileDataLoader", "MinStartOrMaxEndTimeDataLoader", "RecordCountDataLoader", + "SessionIODataLoader", + "SessionNumTracesDataLoader", + "SessionTokenUsagesDataLoader", "SpanDatasetExamplesDataLoader", "SpanDescendantsDataLoader", "SpanProjectsDataLoader", "TokenCountDataLoader", "TraceByTraceIdsDataLoader", + "TraceRootSpansDataLoader", "ProjectByNameDataLoader", "SpanAnnotationsDataLoader", "UsersDataLoader", diff --git a/src/phoenix/server/api/dataloaders/session_io.py b/src/phoenix/server/api/dataloaders/session_io.py new file mode 100644 index 0000000000..10094642f9 --- /dev/null +++ b/src/phoenix/server/api/dataloaders/session_io.py @@ -0,0 +1,75 @@ +from functools import cached_property +from typing import Literal, Optional, cast + +from openinference.semconv.trace import SpanAttributes +from sqlalchemy import Select, func, select +from strawberry.dataloader import DataLoader +from typing_extensions import TypeAlias, assert_never + +from phoenix.db import models +from phoenix.server.types import DbSessionFactory +from phoenix.trace.schemas import MimeType, SpanIOValue + +Key: TypeAlias = int +Result: TypeAlias = Optional[SpanIOValue] + +Kind = Literal["first_input", "last_output"] + + +class SessionIODataLoader(DataLoader[Key, Result]): + def __init__(self, db: DbSessionFactory, kind: Kind) -> None: + super().__init__(load_fn=self._load_fn) + self._db = db + self._kind = kind + + @cached_property + def _subq(self) -> Select[tuple[Optional[int], str, str, int]]: + stmt = ( + select(models.Trace.project_session_rowid.label("id_")) + .join_from(models.Span, models.Trace) + .where(models.Span.parent_id.is_(None)) + ) + if self._kind == "first_input": + stmt = stmt.add_columns( + models.Span.attributes[INPUT_VALUE].label("value"), + models.Span.attributes[INPUT_MIME_TYPE].label("mime_type"), + func.row_number() + .over( + partition_by=models.Trace.project_session_rowid, + order_by=[models.Trace.start_time.asc(), models.Trace.id.asc()], + ) + .label("rank"), + ) + elif self._kind == "last_output": + stmt = stmt.add_columns( + models.Span.attributes[OUTPUT_VALUE].label("value"), + models.Span.attributes[OUTPUT_MIME_TYPE].label("mime_type"), + func.row_number() + .over( + partition_by=models.Trace.project_session_rowid, + order_by=[models.Trace.start_time.desc(), models.Trace.id.desc()], + ) + .label("rank"), + ) + else: + assert_never(self._kind) + return cast(Select[tuple[Optional[int], str, str, int]], stmt) + + def _stmt(self, *keys: Key) -> Select[tuple[int, str, str]]: + subq = self._subq.where(models.Trace.project_session_rowid.in_(keys)).subquery() + return select(subq.c.id_, subq.c.value, subq.c.mime_type).filter_by(rank=1) + + async def _load_fn(self, keys: list[Key]) -> list[Result]: + async with self._db() as session: + result: dict[Key, SpanIOValue] = { + id_: SpanIOValue(value=value, mime_type=MimeType(mime_type)) + async for id_, value, mime_type in await session.stream(self._stmt(*keys)) + if id_ is not None + } + return [result.get(key) for key in keys] + + +INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".") +INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".") +OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".") +OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE.split(".") diff --git a/src/phoenix/server/api/dataloaders/session_num_traces.py b/src/phoenix/server/api/dataloaders/session_num_traces.py new file mode 100644 index 0000000000..6614429a51 --- /dev/null +++ b/src/phoenix/server/api/dataloaders/session_num_traces.py @@ -0,0 +1,30 @@ +from sqlalchemy import func, select +from strawberry.dataloader import DataLoader +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.server.types import DbSessionFactory + +Key: TypeAlias = int +Result: TypeAlias = int + + +class SessionNumTracesDataLoader(DataLoader[Key, Result]): + def __init__(self, db: DbSessionFactory) -> None: + super().__init__(load_fn=self._load_fn) + self._db = db + + async def _load_fn(self, keys: list[Key]) -> list[Result]: + stmt = ( + select( + models.Trace.project_session_rowid.label("id_"), + func.count(models.Trace.id).label("value"), + ) + .group_by(models.Trace.project_session_rowid) + .where(models.Trace.project_session_rowid.in_(keys)) + ) + async with self._db() as session: + result: dict[Key, int] = { + id_: value async for id_, value in await session.stream(stmt) if id_ is not None + } + return [result.get(key, 0) for key in keys] diff --git a/src/phoenix/server/api/dataloaders/session_token_usages.py b/src/phoenix/server/api/dataloaders/session_token_usages.py new file mode 100644 index 0000000000..68978d82d8 --- /dev/null +++ b/src/phoenix/server/api/dataloaders/session_token_usages.py @@ -0,0 +1,41 @@ +from sqlalchemy import func, select +from sqlalchemy.sql.functions import coalesce +from strawberry.dataloader import DataLoader +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.server.types import DbSessionFactory +from phoenix.trace.schemas import TokenUsage + +Key: TypeAlias = int +Result: TypeAlias = TokenUsage + + +class SessionTokenUsagesDataLoader(DataLoader[Key, Result]): + def __init__(self, db: DbSessionFactory) -> None: + super().__init__(load_fn=self._load_fn) + self._db = db + + async def _load_fn(self, keys: list[Key]) -> list[Result]: + stmt = ( + select( + models.Trace.project_session_rowid.label("id_"), + func.sum(coalesce(models.Span.cumulative_llm_token_count_prompt, 0)).label( + "prompt" + ), + func.sum(coalesce(models.Span.cumulative_llm_token_count_completion, 0)).label( + "completion" + ), + ) + .join_from(models.Span, models.Trace) + .where(models.Span.parent_id.is_(None)) + .where(models.Trace.project_session_rowid.in_(keys)) + .group_by(models.Trace.project_session_rowid) + ) + async with self._db() as session: + result: dict[Key, TokenUsage] = { + id_: TokenUsage(prompt=prompt, completion=completion) + async for id_, prompt, completion in await session.stream(stmt) + if id_ is not None + } + return [result.get(key, TokenUsage()) for key in keys] diff --git a/src/phoenix/server/api/dataloaders/trace_by_trace_ids.py b/src/phoenix/server/api/dataloaders/trace_by_trace_ids.py index e8d2fe6326..50f26aec25 100644 --- a/src/phoenix/server/api/dataloaders/trace_by_trace_ids.py +++ b/src/phoenix/server/api/dataloaders/trace_by_trace_ids.py @@ -7,10 +7,7 @@ from phoenix.db import models from phoenix.server.types import DbSessionFactory -TraceId: TypeAlias = str -Key: TypeAlias = TraceId -TraceRowId: TypeAlias = int -ProjectRowId: TypeAlias = int +Key: TypeAlias = str Result: TypeAlias = Optional[models.Trace] @@ -22,5 +19,7 @@ def __init__(self, db: DbSessionFactory) -> None: async def _load_fn(self, keys: List[Key]) -> List[Result]: stmt = select(models.Trace).where(models.Trace.trace_id.in_(keys)) async with self._db() as session: - result = {trace.trace_id: trace for trace in await session.scalars(stmt)} + result: dict[Key, models.Trace] = { + trace.trace_id: trace async for trace in await session.stream_scalars(stmt) + } return [result.get(trace_id) for trace_id in keys] diff --git a/src/phoenix/server/api/dataloaders/trace_root_spans.py b/src/phoenix/server/api/dataloaders/trace_root_spans.py new file mode 100644 index 0000000000..fc5716cb61 --- /dev/null +++ b/src/phoenix/server/api/dataloaders/trace_root_spans.py @@ -0,0 +1,32 @@ +from typing import List, Optional + +from sqlalchemy import select +from sqlalchemy.orm import contains_eager +from strawberry.dataloader import DataLoader +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.server.types import DbSessionFactory + +Key: TypeAlias = int +Result: TypeAlias = Optional[models.Span] + + +class TraceRootSpansDataLoader(DataLoader[Key, Result]): + def __init__(self, db: DbSessionFactory) -> None: + super().__init__(load_fn=self._load_fn) + self._db = db + + async def _load_fn(self, keys: List[Key]) -> List[Result]: + stmt = ( + select(models.Span) + .join(models.Trace) + .where(models.Span.parent_id.is_(None)) + .where(models.Trace.id.in_(keys)) + .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id)) + ) + async with self._db() as session: + result: dict[Key, models.Span] = { + span.trace_rowid: span async for span in await session.stream_scalars(stmt) + } + return [result.get(key) for key in keys] diff --git a/src/phoenix/server/api/types/ProjectSession.py b/src/phoenix/server/api/types/ProjectSession.py index 666937fb66..1b0a3efefe 100644 --- a/src/phoenix/server/api/types/ProjectSession.py +++ b/src/phoenix/server/api/types/ProjectSession.py @@ -3,8 +3,7 @@ import strawberry from openinference.semconv.trace import SpanAttributes -from sqlalchemy import func, select -from sqlalchemy.sql.functions import coalesce +from sqlalchemy import select from strawberry import UNSET, Info, lazy from strawberry.relay import Connection, Node, NodeID @@ -37,32 +36,18 @@ async def num_traces( self, info: Info[Context, None], ) -> int: - stmt = select(func.count(models.Trace.id)).filter_by(project_session_rowid=self.id_attr) - async with info.context.db() as session: - return await session.scalar(stmt) or 0 + return await info.context.data_loaders.session_num_traces.load(self.id_attr) @strawberry.field async def first_input( self, info: Info[Context, None], ) -> Optional[SpanIOValue]: - stmt = ( - select( - models.Span.attributes[INPUT_VALUE].label("value"), - models.Span.attributes[INPUT_MIME_TYPE].label("mime_type"), - ) - .join(models.Trace) - .filter_by(project_session_rowid=self.id_attr) - .where(models.Span.parent_id.is_(None)) - .order_by(models.Trace.start_time.asc()) - .limit(1) - ) - async with info.context.db() as session: - record = (await session.execute(stmt)).first() - if record is None or record.value is None: + record = await info.context.data_loaders.session_first_inputs.load(self.id_attr) + if record is None: return None return SpanIOValue( - mime_type=MimeType(record.mime_type), + mime_type=MimeType(record.mime_type.value), value=record.value, ) @@ -71,23 +56,11 @@ async def last_output( self, info: Info[Context, None], ) -> Optional[SpanIOValue]: - stmt = ( - select( - models.Span.attributes[OUTPUT_VALUE].label("value"), - models.Span.attributes[OUTPUT_MIME_TYPE].label("mime_type"), - ) - .join(models.Trace) - .filter_by(project_session_rowid=self.id_attr) - .where(models.Span.parent_id.is_(None)) - .order_by(models.Trace.start_time.desc()) - .limit(1) - ) - async with info.context.db() as session: - record = (await session.execute(stmt)).first() - if record is None or record.value is None: + record = await info.context.data_loaders.session_last_outputs.load(self.id_attr) + if record is None: return None return SpanIOValue( - mime_type=MimeType(record.mime_type), + mime_type=MimeType(record.mime_type.value), value=record.value, ) @@ -96,29 +69,10 @@ async def token_usage( self, info: Info[Context, None], ) -> TokenUsage: - stmt = ( - select( - func.sum(coalesce(models.Span.cumulative_llm_token_count_prompt, 0)).label( - "prompt" - ), - func.sum(coalesce(models.Span.cumulative_llm_token_count_completion, 0)).label( - "completion" - ), - ) - .join(models.Trace) - .filter_by(project_session_rowid=self.id_attr) - .where(models.Span.parent_id.is_(None)) - .limit(1) - ) - async with info.context.db() as session: - usage = (await session.execute(stmt)).first() - return ( - TokenUsage( - prompt=usage.prompt or 0, - completion=usage.completion or 0, - ) - if usage - else TokenUsage() + usage = await info.context.data_loaders.session_token_usages.load(self.id_attr) + return TokenUsage( + prompt=usage.prompt, + completion=usage.completion, ) @strawberry.field diff --git a/src/phoenix/server/api/types/Trace.py b/src/phoenix/server/api/types/Trace.py index 4d0e90a850..cdbc9692e5 100644 --- a/src/phoenix/server/api/types/Trace.py +++ b/src/phoenix/server/api/types/Trace.py @@ -71,16 +71,7 @@ async def root_span( self, info: Info[Context, None], ) -> Optional[Span]: - stmt = ( - select(models.Span) - .join(models.Trace) - .where(models.Trace.id == self.id_attr) - .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id)) - .where(models.Span.parent_id.is_(None)) - .limit(1) - ) - async with info.context.db() as session: - span = await session.scalar(stmt) + span = await info.context.data_loaders.trace_root_spans.load(self.id_attr) if span is None: return None return to_gql_span(span) diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 7cfa36df2a..5307dcbaf2 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -91,12 +91,16 @@ MinStartOrMaxEndTimeDataLoader, ProjectByNameDataLoader, RecordCountDataLoader, + SessionIODataLoader, + SessionNumTracesDataLoader, + SessionTokenUsagesDataLoader, SpanAnnotationsDataLoader, SpanDatasetExamplesDataLoader, SpanDescendantsDataLoader, SpanProjectsDataLoader, TokenCountDataLoader, TraceByTraceIdsDataLoader, + TraceRootSpansDataLoader, UserRolesDataLoader, UsersDataLoader, ) @@ -555,6 +559,10 @@ def get_context() -> Context: db, cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None, ), + session_first_inputs=SessionIODataLoader(db, "first_input"), + session_last_outputs=SessionIODataLoader(db, "last_output"), + session_num_traces=SessionNumTracesDataLoader(db), + session_token_usages=SessionTokenUsagesDataLoader(db), span_annotations=SpanAnnotationsDataLoader(db), span_dataset_examples=SpanDatasetExamplesDataLoader(db), span_descendants=SpanDescendantsDataLoader(db), @@ -564,6 +572,7 @@ def get_context() -> Context: cache_map=cache_for_dataloaders.token_count if cache_for_dataloaders else None, ), trace_by_trace_ids=TraceByTraceIdsDataLoader(db), + trace_root_spans=TraceRootSpansDataLoader(db), project_by_name=ProjectByNameDataLoader(db), users=UsersDataLoader(db), user_roles=UserRolesDataLoader(db), diff --git a/src/phoenix/trace/schemas.py b/src/phoenix/trace/schemas.py index 342c8fab82..41100da3ae 100644 --- a/src/phoenix/trace/schemas.py +++ b/src/phoenix/trace/schemas.py @@ -187,6 +187,22 @@ def _missing_(cls, v: Any) -> Optional["MimeType"]: return None if v else cls.TEXT +@dataclass(frozen=True) +class SpanIOValue: + value: str + mime_type: MimeType = MimeType.TEXT + + +@dataclass(frozen=True) +class TokenUsage: + prompt: int = 0 + completion: int = 0 + + def __post_init__(self) -> None: + assert self.prompt >= 0, "prompt must be non-negative" + assert self.completion >= 0, "completion must be non-negative" + + ATTRIBUTE_PREFIX = "attributes." CONTEXT_PREFIX = "context." COMPUTED_PREFIX = "__computed__." diff --git a/tests/unit/server/api/types/test_ProjectSession.py b/tests/unit/server/api/types/test_ProjectSession.py index 13a0a7939d..4769725a91 100644 --- a/tests/unit/server/api/types/test_ProjectSession.py +++ b/tests/unit/server/api/types/test_ProjectSession.py @@ -1,11 +1,10 @@ from datetime import datetime, timedelta, timezone -from typing import Any +from typing import Any, NamedTuple import httpx import pytest from faker import Faker from strawberry.relay import GlobalID -from typing_extensions import TypeAlias from phoenix.db import models from phoenix.server.api.types.ProjectSession import ProjectSession @@ -14,11 +13,12 @@ from ...._helpers import _add_project, _add_project_session, _add_span, _add_trace, _node -_Data: TypeAlias = tuple[ - list[models.ProjectSession], - list[models.Trace], - list[models.Project], -] + +class _Data(NamedTuple): + spans: list[models.Span] + traces: list[models.Trace] + project_sessions: list[models.ProjectSession] + projects: list[models.Project] class TestProjectSession: @@ -43,6 +43,7 @@ async def _data( ) -> _Data: project_sessions = [] traces = [] + spans = [] async with db() as session: project = await _add_project(session) start_time = datetime.now(timezone.utc) @@ -62,12 +63,14 @@ async def _data( start_time=start_time, ) ) - await _add_span( - session, - traces[-1], - attributes={"input": {"value": "123"}, "output": {"value": "321"}}, - cumulative_llm_token_count_prompt=1, - cumulative_llm_token_count_completion=2, + spans.append( + await _add_span( + session, + traces[-1], + attributes={"input": {"value": "123"}, "output": {"value": "321"}}, + cumulative_llm_token_count_prompt=1, + cumulative_llm_token_count_completion=2, + ) ) traces.append( await _add_trace( @@ -77,61 +80,77 @@ async def _data( start_time=start_time + timedelta(seconds=1), ) ) - await _add_span( - session, - traces[-1], - attributes={"input": {"value": "1234"}, "output": {"value": "4321"}}, - cumulative_llm_token_count_prompt=3, - cumulative_llm_token_count_completion=4, + spans.append( + await _add_span( + session, + traces[-1], + attributes={"input": {"value": "1234"}, "output": {"value": "4321"}}, + cumulative_llm_token_count_prompt=3, + cumulative_llm_token_count_completion=4, + ) ) project_sessions.append(await _add_project_session(session, project)) - return project_sessions, traces, [project] + return _Data( + spans=spans, + traces=traces, + project_sessions=project_sessions, + projects=[project], + ) async def test_session_user( self, _data: _Data, httpx_client: httpx.AsyncClient, ) -> None: - assert await self._node("sessionUser", _data[0][0], httpx_client) == "xyz" - assert await self._node("sessionUser", _data[0][1], httpx_client) is None + project_sessions = _data.project_sessions + field = "sessionUser" + assert await self._node(field, project_sessions[0], httpx_client) == "xyz" + assert await self._node(field, project_sessions[1], httpx_client) is None async def test_num_traces( self, _data: _Data, httpx_client: httpx.AsyncClient, ) -> None: - assert await self._node("numTraces", _data[0][0], httpx_client) == 2 + project_session = _data.project_sessions[0] + field = "numTraces" + assert await self._node(field, project_session, httpx_client) == 2 async def test_first_input( self, _data: _Data, httpx_client: httpx.AsyncClient, ) -> None: - assert await self._node( - "firstInput{value mimeType}", - _data[0][0], - httpx_client, - ) == {"value": "123", "mimeType": "text"} + project_session = _data.project_sessions[0] + field = "firstInput{value mimeType}" + assert await self._node(field, project_session, httpx_client) == { + "value": "123", + "mimeType": "text", + } async def test_last_output( self, _data: _Data, httpx_client: httpx.AsyncClient, ) -> None: - assert await self._node( - "lastOutput{value mimeType}", - _data[0][0], - httpx_client, - ) == {"value": "4321", "mimeType": "text"} + project_session = _data.project_sessions[0] + field = "lastOutput{value mimeType}" + assert await self._node(field, project_session, httpx_client) == { + "value": "4321", + "mimeType": "text", + } async def test_traces( self, _data: _Data, httpx_client: httpx.AsyncClient, ) -> None: - traces = await self._node("traces{edges{node{id}}}", _data[0][0], httpx_client) - assert {edge["node"]["id"] for edge in traces["edges"]} == { - str(GlobalID(Trace.__name__, str(trace.id))) for trace in _data[1] + project_session = _data.project_sessions[0] + field = "traces{edges{node{id traceId}}}" + traces = await self._node(field, project_session, httpx_client) + assert traces["edges"] + assert {(edge["node"]["id"], edge["node"]["traceId"]) for edge in traces["edges"]} == { + (str(GlobalID(Trace.__name__, str(trace.id))), trace.trace_id) for trace in _data.traces } async def test_token_usage( @@ -139,8 +158,10 @@ async def test_token_usage( _data: _Data, httpx_client: httpx.AsyncClient, ) -> None: - assert await self._node( - "tokenUsage{prompt completion total}", - _data[0][0], - httpx_client, - ) == {"prompt": 4, "completion": 6, "total": 10} + project_sessions = _data.project_sessions + field = "tokenUsage{prompt completion total}" + assert await self._node(field, project_sessions[0], httpx_client) == { + "prompt": 4, + "completion": 6, + "total": 10, + } diff --git a/tests/unit/server/api/types/test_Trace.py b/tests/unit/server/api/types/test_Trace.py index 1fb817d85f..689393e1e5 100644 --- a/tests/unit/server/api/types/test_Trace.py +++ b/tests/unit/server/api/types/test_Trace.py @@ -1,22 +1,23 @@ -from typing import Any +from typing import Any, NamedTuple import httpx import pytest from strawberry.relay import GlobalID -from typing_extensions import TypeAlias from phoenix.db import models from phoenix.server.api.types.ProjectSession import ProjectSession +from phoenix.server.api.types.Span import Span from phoenix.server.api.types.Trace import Trace from phoenix.server.types import DbSessionFactory -from ...._helpers import _add_project, _add_project_session, _add_trace, _node +from ...._helpers import _add_project, _add_project_session, _add_span, _add_trace, _node -_Data: TypeAlias = tuple[ - list[models.Trace], - list[models.ProjectSession], - list[models.Project], -] + +class _Data(NamedTuple): + spans: list[models.Span] + traces: list[models.Trace] + project_sessions: list[models.ProjectSession] + projects: list[models.Project] class TestTrace: @@ -36,21 +37,45 @@ async def _node( @pytest.fixture async def _data(self, db: DbSessionFactory) -> _Data: traces = [] + spans = [] async with db() as session: project = await _add_project(session) project_session = await _add_project_session(session, project) traces.append(await _add_trace(session, project)) traces.append(await _add_trace(session, project, project_session)) - return traces, [project_session], [project] + spans.append(await _add_span(session, traces[-1])) + spans.append(await _add_span(session, traces[-1], parent_span=spans[-1])) + return _Data( + spans=spans, + traces=traces, + project_sessions=[project_session], + projects=[project], + ) async def test_session( self, _data: _Data, httpx_client: httpx.AsyncClient, ) -> None: - traces = _data[0] - project_session = _data[1][0] - assert await self._node("session{id}", traces[0], httpx_client) is None - assert await self._node("session{id}", traces[1], httpx_client) == { - "id": str(GlobalID(ProjectSession.__name__, str(project_session.id))) + traces = _data.traces + project_session = _data.project_sessions[0] + field = "session{id sessionId}" + assert await self._node(field, traces[0], httpx_client) is None + assert await self._node(field, traces[1], httpx_client) == { + "id": str(GlobalID(ProjectSession.__name__, str(project_session.id))), + "sessionId": project_session.session_id, + } + + async def test_root_span( + self, + _data: _Data, + httpx_client: httpx.AsyncClient, + ) -> None: + traces = _data.traces + span = _data.spans[0] + field = "rootSpan{id name}" + assert await self._node(field, traces[0], httpx_client) is None + assert await self._node(field, traces[1], httpx_client) == { + "id": str(GlobalID(Span.__name__, str(span.id))), + "name": span.name, }