-
Notifications
You must be signed in to change notification settings - Fork 292
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor(sessions): add dataloaders for session queries #5222
Changes from 5 commits
6b0bfea
993a809
9e5be85
14488e8
80e37ce
b35d76f
6ad6e7a
655eb53
679dddf
115e7cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does the query plan look with the row number/ filter on rank approach? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from functools import cached_property | ||
from typing import Literal, Optional, cast | ||
|
||
from openinference.semconv.trace import SpanAttributes | ||
from sqlalchemy import Select, func, select | ||
from strawberry.dataloader import DataLoader | ||
from typing_extensions import TypeAlias, assert_never | ||
|
||
from phoenix.db import models | ||
from phoenix.server.types import DbSessionFactory | ||
from phoenix.trace.schemas import MimeType, SpanIOValue | ||
|
||
Key: TypeAlias = int | ||
Result: TypeAlias = Optional[SpanIOValue] | ||
|
||
Kind = Literal["first_input", "last_output"] | ||
|
||
|
||
class SessionFirstInputLastOutputsDataLoader(DataLoader[Key, Result]): | ||
RogerHYang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __init__(self, db: DbSessionFactory, kind: Kind) -> None: | ||
super().__init__(load_fn=self._load_fn) | ||
self._db = db | ||
self._kind = kind | ||
|
||
@cached_property | ||
def _subq(self) -> Select[tuple[int, str, str, int]]: | ||
stmt = ( | ||
select(models.Trace.project_session_rowid.label("id_")) | ||
.join_from(models.Span, models.Trace) | ||
.where(models.Span.parent_id.is_(None)) | ||
.where(models.Trace.project_session_rowid.isnot(None)) | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to handle the case of multiple root spans in a single trace? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we're punting on that for now, since there' no good way to display them |
||
if self._kind == "first_input": | ||
stmt = stmt.add_columns( | ||
models.Span.attributes[INPUT_VALUE].label("value"), | ||
models.Span.attributes[INPUT_MIME_TYPE].label("mime_type"), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same non-blocking comment as before. |
||
func.row_number() | ||
.over( | ||
partition_by=models.Trace.project_session_rowid, | ||
order_by=models.Trace.start_time.asc(), | ||
) | ||
.label("rank"), | ||
) | ||
elif self._kind == "last_output": | ||
stmt = stmt.add_columns( | ||
models.Span.attributes[OUTPUT_VALUE].label("value"), | ||
models.Span.attributes[OUTPUT_MIME_TYPE].label("mime_type"), | ||
func.row_number() | ||
.over( | ||
partition_by=models.Trace.project_session_rowid, | ||
order_by=models.Trace.start_time.desc(), | ||
) | ||
.label("rank"), | ||
) | ||
else: | ||
assert_never(self._kind) | ||
return cast(Select[tuple[int, str, str, int]], stmt) | ||
|
||
def _stmt(self, *keys: Key) -> Select[tuple[int, str, str]]: | ||
subq = self._subq.where(models.Trace.project_session_rowid.in_(keys)).subquery() | ||
return select(subq.c.id_, subq.c.value, subq.c.mime_type).filter_by(rank=1) | ||
|
||
async def _load_fn(self, keys: list[Key]) -> list[Result]: | ||
async with self._db() as session: | ||
result: dict[Key, SpanIOValue] = { | ||
id_: SpanIOValue(value=value, mime_type=MimeType(mime_type)) | ||
async for id_, value, mime_type in await session.stream(self._stmt(*keys)) | ||
if id_ is not None | ||
} | ||
return [result.get(key) for key in keys] | ||
|
||
|
||
INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".") | ||
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".") | ||
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".") | ||
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE.split(".") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from functools import cached_property | ||
from typing import Optional | ||
|
||
from sqlalchemy import Select, func, select | ||
from strawberry.dataloader import DataLoader | ||
from typing_extensions import TypeAlias | ||
|
||
from phoenix.db import models | ||
from phoenix.server.types import DbSessionFactory | ||
|
||
Key: TypeAlias = int | ||
Result: TypeAlias = int | ||
|
||
|
||
class SessionNumTracesDataLoader(DataLoader[Key, Result]): | ||
def __init__(self, db: DbSessionFactory) -> None: | ||
super().__init__(load_fn=self._load_fn) | ||
self._db = db | ||
|
||
@cached_property | ||
def _stmt(self) -> Select[tuple[Optional[int], int]]: | ||
RogerHYang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return ( | ||
select( | ||
models.Trace.project_session_rowid.label("id_"), | ||
func.count(models.Trace.id).label("value"), | ||
) | ||
.group_by(models.Trace.project_session_rowid) | ||
.where(models.Trace.project_session_rowid.isnot(None)) | ||
RogerHYang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
|
||
async def _load_fn(self, keys: list[Key]) -> list[Result]: | ||
stmt = self._stmt.where(models.Trace.project_session_rowid.in_(keys)) | ||
async with self._db() as session: | ||
result: dict[Key, int] = { | ||
id_: value async for id_, value in await session.stream(stmt) if id_ is not None | ||
} | ||
return [result.get(key, 0) for key in keys] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from functools import cached_property | ||
from typing import Optional | ||
|
||
from sqlalchemy import Select, func, select | ||
from sqlalchemy.sql.functions import coalesce | ||
from strawberry.dataloader import DataLoader | ||
from typing_extensions import TypeAlias | ||
|
||
from phoenix.db import models | ||
from phoenix.server.types import DbSessionFactory | ||
from phoenix.trace.schemas import TokenUsage | ||
|
||
Key: TypeAlias = int | ||
Result: TypeAlias = TokenUsage | ||
|
||
|
||
class SessionTokenUsagesDataLoader(DataLoader[Key, Result]): | ||
def __init__(self, db: DbSessionFactory) -> None: | ||
super().__init__(load_fn=self._load_fn) | ||
self._db = db | ||
|
||
@cached_property | ||
def _stmt(self) -> Select[tuple[Optional[int], int, int]]: | ||
return ( | ||
select( | ||
models.Trace.project_session_rowid.label("id_"), | ||
func.sum(coalesce(models.Span.cumulative_llm_token_count_prompt, 0)).label( | ||
"prompt" | ||
), | ||
func.sum(coalesce(models.Span.cumulative_llm_token_count_completion, 0)).label( | ||
"completion" | ||
), | ||
) | ||
.join_from(models.Span, models.Trace) | ||
.where(models.Span.parent_id.is_(None)) | ||
.where(models.Trace.project_session_rowid.isnot(None)) | ||
.group_by(models.Trace.project_session_rowid) | ||
) | ||
|
||
async def _load_fn(self, keys: list[Key]) -> list[Result]: | ||
stmt = self._stmt.where(models.Trace.project_session_rowid.in_(keys)) | ||
async with self._db() as session: | ||
result: dict[Key, TokenUsage] = { | ||
id_: TokenUsage(prompt=prompt, completion=completion) | ||
async for id_, prompt, completion in await session.stream(stmt) | ||
if id_ is not None | ||
} | ||
return [result.get(key, TokenUsage()) for key in keys] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from functools import cached_property | ||
from typing import List, Optional | ||
|
||
from sqlalchemy import Select, select | ||
from sqlalchemy.orm import contains_eager | ||
from strawberry.dataloader import DataLoader | ||
from typing_extensions import TypeAlias | ||
|
||
from phoenix.db import models | ||
from phoenix.server.types import DbSessionFactory | ||
|
||
Key: TypeAlias = int | ||
Result: TypeAlias = Optional[models.Span] | ||
|
||
|
||
class TraceRootSpansDataLoader(DataLoader[Key, Result]): | ||
def __init__(self, db: DbSessionFactory) -> None: | ||
super().__init__(load_fn=self._load_fn) | ||
self._db = db | ||
|
||
@cached_property | ||
def _stmt(self) -> Select[tuple[models.Span]]: | ||
return ( | ||
select(models.Span) | ||
.join(models.Trace) | ||
.where(models.Span.parent_id.is_(None)) | ||
.options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id)) | ||
) | ||
|
||
async def _load_fn(self, keys: List[Key]) -> List[Result]: | ||
stmt = self._stmt.where(models.Trace.id.in_(keys)) | ||
async with self._db() as session: | ||
result: dict[Key, models.Span] = { | ||
span.trace_rowid: span async for span in await session.stream_scalars(stmt) | ||
} | ||
return [result.get(key) for key in keys] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this can be ambiguous, since it can imply a list of inputs per session, whereas first_input matches the graphql field, which is more explicit