-
Notifications
You must be signed in to change notification settings - Fork 292
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(sessions): add dataloaders for session queries (#5222)
- Loading branch information
1 parent
579a855
commit fddd0d1
Showing
14 changed files
with
348 additions
and
137 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from functools import cached_property | ||
from typing import Literal, Optional, cast | ||
|
||
from openinference.semconv.trace import SpanAttributes | ||
from sqlalchemy import Select, func, select | ||
from strawberry.dataloader import DataLoader | ||
from typing_extensions import TypeAlias, assert_never | ||
|
||
from phoenix.db import models | ||
from phoenix.server.types import DbSessionFactory | ||
from phoenix.trace.schemas import MimeType, SpanIOValue | ||
|
||
Key: TypeAlias = int | ||
Result: TypeAlias = Optional[SpanIOValue] | ||
|
||
Kind = Literal["first_input", "last_output"] | ||
|
||
|
||
class SessionIODataLoader(DataLoader[Key, Result]): | ||
def __init__(self, db: DbSessionFactory, kind: Kind) -> None: | ||
super().__init__(load_fn=self._load_fn) | ||
self._db = db | ||
self._kind = kind | ||
|
||
@cached_property | ||
def _subq(self) -> Select[tuple[Optional[int], str, str, int]]: | ||
stmt = ( | ||
select(models.Trace.project_session_rowid.label("id_")) | ||
.join_from(models.Span, models.Trace) | ||
.where(models.Span.parent_id.is_(None)) | ||
) | ||
if self._kind == "first_input": | ||
stmt = stmt.add_columns( | ||
models.Span.attributes[INPUT_VALUE].label("value"), | ||
models.Span.attributes[INPUT_MIME_TYPE].label("mime_type"), | ||
func.row_number() | ||
.over( | ||
partition_by=models.Trace.project_session_rowid, | ||
order_by=[models.Trace.start_time.asc(), models.Trace.id.asc()], | ||
) | ||
.label("rank"), | ||
) | ||
elif self._kind == "last_output": | ||
stmt = stmt.add_columns( | ||
models.Span.attributes[OUTPUT_VALUE].label("value"), | ||
models.Span.attributes[OUTPUT_MIME_TYPE].label("mime_type"), | ||
func.row_number() | ||
.over( | ||
partition_by=models.Trace.project_session_rowid, | ||
order_by=[models.Trace.start_time.desc(), models.Trace.id.desc()], | ||
) | ||
.label("rank"), | ||
) | ||
else: | ||
assert_never(self._kind) | ||
return cast(Select[tuple[Optional[int], str, str, int]], stmt) | ||
|
||
def _stmt(self, *keys: Key) -> Select[tuple[int, str, str]]: | ||
subq = self._subq.where(models.Trace.project_session_rowid.in_(keys)).subquery() | ||
return select(subq.c.id_, subq.c.value, subq.c.mime_type).filter_by(rank=1) | ||
|
||
async def _load_fn(self, keys: list[Key]) -> list[Result]: | ||
async with self._db() as session: | ||
result: dict[Key, SpanIOValue] = { | ||
id_: SpanIOValue(value=value, mime_type=MimeType(mime_type)) | ||
async for id_, value, mime_type in await session.stream(self._stmt(*keys)) | ||
if id_ is not None | ||
} | ||
return [result.get(key) for key in keys] | ||
|
||
|
||
INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".") | ||
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".") | ||
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".") | ||
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE.split(".") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from sqlalchemy import func, select | ||
from strawberry.dataloader import DataLoader | ||
from typing_extensions import TypeAlias | ||
|
||
from phoenix.db import models | ||
from phoenix.server.types import DbSessionFactory | ||
|
||
Key: TypeAlias = int | ||
Result: TypeAlias = int | ||
|
||
|
||
class SessionNumTracesDataLoader(DataLoader[Key, Result]): | ||
def __init__(self, db: DbSessionFactory) -> None: | ||
super().__init__(load_fn=self._load_fn) | ||
self._db = db | ||
|
||
async def _load_fn(self, keys: list[Key]) -> list[Result]: | ||
stmt = ( | ||
select( | ||
models.Trace.project_session_rowid.label("id_"), | ||
func.count(models.Trace.id).label("value"), | ||
) | ||
.group_by(models.Trace.project_session_rowid) | ||
.where(models.Trace.project_session_rowid.in_(keys)) | ||
) | ||
async with self._db() as session: | ||
result: dict[Key, int] = { | ||
id_: value async for id_, value in await session.stream(stmt) if id_ is not None | ||
} | ||
return [result.get(key, 0) for key in keys] |
41 changes: 41 additions & 0 deletions
41
src/phoenix/server/api/dataloaders/session_token_usages.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from sqlalchemy import func, select | ||
from sqlalchemy.sql.functions import coalesce | ||
from strawberry.dataloader import DataLoader | ||
from typing_extensions import TypeAlias | ||
|
||
from phoenix.db import models | ||
from phoenix.server.types import DbSessionFactory | ||
from phoenix.trace.schemas import TokenUsage | ||
|
||
Key: TypeAlias = int | ||
Result: TypeAlias = TokenUsage | ||
|
||
|
||
class SessionTokenUsagesDataLoader(DataLoader[Key, Result]): | ||
def __init__(self, db: DbSessionFactory) -> None: | ||
super().__init__(load_fn=self._load_fn) | ||
self._db = db | ||
|
||
async def _load_fn(self, keys: list[Key]) -> list[Result]: | ||
stmt = ( | ||
select( | ||
models.Trace.project_session_rowid.label("id_"), | ||
func.sum(coalesce(models.Span.cumulative_llm_token_count_prompt, 0)).label( | ||
"prompt" | ||
), | ||
func.sum(coalesce(models.Span.cumulative_llm_token_count_completion, 0)).label( | ||
"completion" | ||
), | ||
) | ||
.join_from(models.Span, models.Trace) | ||
.where(models.Span.parent_id.is_(None)) | ||
.where(models.Trace.project_session_rowid.in_(keys)) | ||
.group_by(models.Trace.project_session_rowid) | ||
) | ||
async with self._db() as session: | ||
result: dict[Key, TokenUsage] = { | ||
id_: TokenUsage(prompt=prompt, completion=completion) | ||
async for id_, prompt, completion in await session.stream(stmt) | ||
if id_ is not None | ||
} | ||
return [result.get(key, TokenUsage()) for key in keys] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from typing import List, Optional | ||
|
||
from sqlalchemy import select | ||
from sqlalchemy.orm import contains_eager | ||
from strawberry.dataloader import DataLoader | ||
from typing_extensions import TypeAlias | ||
|
||
from phoenix.db import models | ||
from phoenix.server.types import DbSessionFactory | ||
|
||
Key: TypeAlias = int | ||
Result: TypeAlias = Optional[models.Span] | ||
|
||
|
||
class TraceRootSpansDataLoader(DataLoader[Key, Result]): | ||
def __init__(self, db: DbSessionFactory) -> None: | ||
super().__init__(load_fn=self._load_fn) | ||
self._db = db | ||
|
||
async def _load_fn(self, keys: List[Key]) -> List[Result]: | ||
stmt = ( | ||
select(models.Span) | ||
.join(models.Trace) | ||
.where(models.Span.parent_id.is_(None)) | ||
.where(models.Trace.id.in_(keys)) | ||
.options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id)) | ||
) | ||
async with self._db() as session: | ||
result: dict[Key, models.Span] = { | ||
span.trace_rowid: span async for span in await session.stream_scalars(stmt) | ||
} | ||
return [result.get(key) for key in keys] |
Oops, something went wrong.