Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(sessions): add dataloaders for session queries #5222

Merged
merged 10 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions scripts/fixtures/multi-turn_chat_sessions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install -Uqqq datasets openinference-instrumentation-openai openai-responses openai tiktoken langchain langchain-openai llama-index llama-index-llms-openai faker mdgen"
"%pip install -Uqqq datasets openinference-instrumentation-openai openai-responses openai tiktoken langchain langchain-openai llama-index llama-index-llms-openai faker"
]
},
{
Expand All @@ -25,7 +25,6 @@
"import pandas as pd\n",
"from datasets import load_dataset\n",
"from faker import Faker\n",
"from mdgen import MarkdownPostProvider\n",
"from openai_responses import OpenAIMock\n",
"from openinference.instrumentation import using_session, using_user\n",
"from openinference.instrumentation.openai import OpenAIInstrumentor\n",
Expand All @@ -39,8 +38,7 @@
"import phoenix as px\n",
"from phoenix.trace.span_evaluations import SpanEvaluations\n",
"\n",
"fake = Faker(\"ja_JP\")\n",
"fake.add_provider(MarkdownPostProvider)"
"fake = Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\", \"th_TH\", \"bn_BD\"])"
]
},
{
Expand Down Expand Up @@ -104,7 +102,7 @@
" if p < 0.1:\n",
" return \":\" * randint(1, 5)\n",
" if p < 0.9:\n",
" return Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\"]).address()\n",
" return fake.address()\n",
" return int(abs(random()) * 1_000_000_000)\n",
"\n",
"\n",
Expand All @@ -113,15 +111,17 @@
" if p < 0.1:\n",
" return \":\" * randint(1, 5)\n",
" if p < 0.9:\n",
" return Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\"]).name()\n",
" return fake.name()\n",
" return int(abs(random()) * 1_000_000_000)\n",
"\n",
"\n",
"def export_spans():\n",
"def export_spans(prob_drop_root):\n",
" \"\"\"Export spans in random order for receiver testing\"\"\"\n",
" spans = list(in_memory_span_exporter.get_finished_spans())\n",
" shuffle(spans)\n",
" for span in spans:\n",
" if span.parent is None and random() < prob_drop_root:\n",
" continue\n",
" otlp_span_exporter.export([span])\n",
" in_memory_span_exporter.clear()\n",
" session_count = len({id_ for span in spans if (id_ := span.attributes.get(\"session.id\"))})\n",
Expand All @@ -147,7 +147,7 @@
" return\n",
" has_yielded = False\n",
" with tracer.start_as_current_span(\n",
" Faker(\"ja_JP\").kana_name(),\n",
" fake.city(),\n",
" attributes=dict(rand_span_kind()),\n",
" end_on_exit=False,\n",
" ) as root:\n",
Expand Down Expand Up @@ -185,6 +185,7 @@
"source": [
"session_count = randint(5, 10)\n",
"tree_complexity = 4 # set to 0 for single span under root\n",
"prob_drop_root = 0.0 # probability that a root span gets dropped\n",
"\n",
"\n",
"def simulate_openai():\n",
Expand Down Expand Up @@ -237,7 +238,7 @@
" simulate_openai()\n",
"finally:\n",
" OpenAIInstrumentor().uninstrument()\n",
"spans = export_spans()\n",
"spans = export_spans(prob_drop_root)\n",
"\n",
"# Annotate root spans\n",
"root_span_ids = pd.Series(\n",
Expand Down
9 changes: 9 additions & 0 deletions src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@
MinStartOrMaxEndTimeDataLoader,
ProjectByNameDataLoader,
RecordCountDataLoader,
SessionFirstInputLastOutputsDataLoader,
SessionNumTracesDataLoader,
SessionTokenUsagesDataLoader,
SpanAnnotationsDataLoader,
SpanDatasetExamplesDataLoader,
SpanDescendantsDataLoader,
SpanProjectsDataLoader,
TokenCountDataLoader,
TraceByTraceIdsDataLoader,
TraceRootSpansDataLoader,
UserRolesDataLoader,
UsersDataLoader,
)
Expand Down Expand Up @@ -68,12 +72,17 @@ class DataLoaders:
latency_ms_quantile: LatencyMsQuantileDataLoader
min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
record_counts: RecordCountDataLoader
session_first_inputs: SessionFirstInputLastOutputsDataLoader
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
session_first_inputs: SessionFirstInputLastOutputsDataLoader
session_inputs: SessionIODataLoader

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be ambiguous, since it can imply a list of inputs per session, whereas `first_input` matches the GraphQL field, which is more explicit.

session_last_outputs: SessionFirstInputLastOutputsDataLoader
session_num_traces: SessionNumTracesDataLoader
session_token_usages: SessionTokenUsagesDataLoader
span_annotations: SpanAnnotationsDataLoader
span_dataset_examples: SpanDatasetExamplesDataLoader
span_descendants: SpanDescendantsDataLoader
span_projects: SpanProjectsDataLoader
token_counts: TokenCountDataLoader
trace_by_trace_ids: TraceByTraceIdsDataLoader
trace_root_spans: TraceRootSpansDataLoader
project_by_name: ProjectByNameDataLoader
users: UsersDataLoader
user_roles: UserRolesDataLoader
Expand Down
8 changes: 8 additions & 0 deletions src/phoenix/server/api/dataloaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
from .project_by_name import ProjectByNameDataLoader
from .record_counts import RecordCountCache, RecordCountDataLoader
from .session_first_input_last_outputs import SessionFirstInputLastOutputsDataLoader
from .session_num_traces import SessionNumTracesDataLoader
from .session_token_usages import SessionTokenUsagesDataLoader
from .span_annotations import SpanAnnotationsDataLoader
from .span_dataset_examples import SpanDatasetExamplesDataLoader
from .span_descendants import SpanDescendantsDataLoader
from .span_projects import SpanProjectsDataLoader
from .token_counts import TokenCountCache, TokenCountDataLoader
from .trace_by_trace_ids import TraceByTraceIdsDataLoader
from .trace_root_spans import TraceRootSpansDataLoader
from .user_roles import UserRolesDataLoader
from .users import UsersDataLoader

Expand All @@ -45,11 +49,15 @@
"LatencyMsQuantileDataLoader",
"MinStartOrMaxEndTimeDataLoader",
"RecordCountDataLoader",
"SessionFirstInputLastOutputsDataLoader",
"SessionNumTracesDataLoader",
"SessionTokenUsagesDataLoader",
"SpanDatasetExamplesDataLoader",
"SpanDescendantsDataLoader",
"SpanProjectsDataLoader",
"TokenCountDataLoader",
"TraceByTraceIdsDataLoader",
"TraceRootSpansDataLoader",
"ProjectByNameDataLoader",
"SpanAnnotationsDataLoader",
"UsersDataLoader",
Expand Down
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does the query plan look with the row number/ filter on rank approach?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here's the query plan from postgres

Screenshot 2024-10-29 at 3 19 35 PM

Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from functools import cached_property
from typing import Literal, Optional, cast

from openinference.semconv.trace import SpanAttributes
from sqlalchemy import Select, func, select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias, assert_never

from phoenix.db import models
from phoenix.server.types import DbSessionFactory
from phoenix.trace.schemas import MimeType, SpanIOValue

# Dataloader key: a project session rowid (models.ProjectSession primary key).
Key: TypeAlias = int
# Result: the root span's IO value for the session, or None when absent.
Result: TypeAlias = Optional[SpanIOValue]

# Direction selector fixed at construction time: the first trace's input
# or the last trace's output within each session.
Kind = Literal["first_input", "last_output"]


class SessionFirstInputLastOutputsDataLoader(DataLoader[Key, Result]):
    """Batch-load, per project session, either the first trace's root-span
    input value or the last trace's root-span output value.

    Keys are ``project_session_rowid`` values. The direction (``first_input``
    vs ``last_output``) is fixed at construction time via ``kind``; sessions
    with no matching root span resolve to ``None``.
    """

    def __init__(self, db: DbSessionFactory, kind: Kind) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db
        self._kind = kind

    @cached_property
    def _subq(self) -> Select[tuple[int, str, str, int]]:
        # Root spans only (parent_id IS NULL); traces without a session are
        # excluded. NOTE(review): multiple root spans in a single trace are
        # not disambiguated — punted for now, since there's no good way to
        # display them.
        stmt = (
            select(models.Trace.project_session_rowid.label("id_"))
            .join_from(models.Span, models.Trace)
            .where(models.Span.parent_id.is_(None))
            .where(models.Trace.project_session_rowid.isnot(None))
        )
        if self._kind == "first_input":
            # Rank traces within each session by ascending start time, so
            # rank 1 is the session's earliest trace.
            stmt = stmt.add_columns(
                models.Span.attributes[INPUT_VALUE].label("value"),
                models.Span.attributes[INPUT_MIME_TYPE].label("mime_type"),
                func.row_number()
                .over(
                    partition_by=models.Trace.project_session_rowid,
                    order_by=models.Trace.start_time.asc(),
                )
                .label("rank"),
            )
        elif self._kind == "last_output":
            # Descending start time, so rank 1 is the session's latest trace.
            stmt = stmt.add_columns(
                models.Span.attributes[OUTPUT_VALUE].label("value"),
                models.Span.attributes[OUTPUT_MIME_TYPE].label("mime_type"),
                func.row_number()
                .over(
                    partition_by=models.Trace.project_session_rowid,
                    order_by=models.Trace.start_time.desc(),
                )
                .label("rank"),
            )
        else:
            assert_never(self._kind)
        return cast(Select[tuple[int, str, str, int]], stmt)

    def _stmt(self, *keys: Key) -> Select[tuple[int, str, str]]:
        # Restrict the ranked rows to the requested sessions, then keep only
        # the rank-1 row (first or last trace, per self._kind) per session.
        subq = self._subq.where(models.Trace.project_session_rowid.in_(keys)).subquery()
        return select(subq.c.id_, subq.c.value, subq.c.mime_type).filter_by(rank=1)

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        async with self._db() as session:
            result: dict[Key, SpanIOValue] = {
                id_: SpanIOValue(value=value, mime_type=MimeType(mime_type))
                async for id_, value, mime_type in await session.stream(self._stmt(*keys))
                if id_ is not None
            }
        # Preserve input order; absent sessions map to None.
        return [result.get(key) for key in keys]


# Dotted OpenInference semantic-convention attribute names, split into path
# segments for indexing into the JSON Span.attributes column.
INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".")
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE.split(".")
37 changes: 37 additions & 0 deletions src/phoenix/server/api/dataloaders/session_num_traces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from functools import cached_property
from typing import Optional

from sqlalchemy import Select, func, select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

Key: TypeAlias = int
Result: TypeAlias = int


class SessionNumTracesDataLoader(DataLoader[Key, Result]):
    """Batch-load the number of traces in each project session.

    Keys are ``project_session_rowid`` values; sessions not found in the
    database resolve to 0.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    @cached_property
    def _stmt(self) -> Select[tuple[Optional[int], int]]:
        # Cached because the statement is key-independent; the per-call
        # key filter is appended in _load_fn.
        return (
            select(
                models.Trace.project_session_rowid.label("id_"),
                func.count(models.Trace.id).label("value"),
            )
            .group_by(models.Trace.project_session_rowid)
            .where(models.Trace.project_session_rowid.isnot(None))
        )

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = self._stmt.where(models.Trace.project_session_rowid.in_(keys))
        async with self._db() as session:
            result: dict[Key, int] = {
                id_: value async for id_, value in await session.stream(stmt) if id_ is not None
            }
        # Preserve input order; sessions with no rows count as 0.
        return [result.get(key, 0) for key in keys]
48 changes: 48 additions & 0 deletions src/phoenix/server/api/dataloaders/session_token_usages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from functools import cached_property
from typing import Optional

from sqlalchemy import Select, func, select
from sqlalchemy.sql.functions import coalesce
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory
from phoenix.trace.schemas import TokenUsage

Key: TypeAlias = int
Result: TypeAlias = TokenUsage


class SessionTokenUsagesDataLoader(DataLoader[Key, Result]):
    """Batch-load aggregate LLM token usage per project session.

    Prompt and completion counts are summed over the cumulative counts of
    the root spans (parent_id IS NULL) of each session's traces. Keys are
    ``project_session_rowid`` values; absent sessions resolve to an empty
    ``TokenUsage()``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    @cached_property
    def _stmt(self) -> Select[tuple[Optional[int], int, int]]:
        # Key-independent aggregate; the per-call key filter is added later.
        session_id = models.Trace.project_session_rowid.label("id_")
        prompt_total = func.sum(
            coalesce(models.Span.cumulative_llm_token_count_prompt, 0)
        ).label("prompt")
        completion_total = func.sum(
            coalesce(models.Span.cumulative_llm_token_count_completion, 0)
        ).label("completion")
        return (
            select(session_id, prompt_total, completion_total)
            .join_from(models.Span, models.Trace)
            .where(models.Span.parent_id.is_(None))
            .where(models.Trace.project_session_rowid.isnot(None))
            .group_by(models.Trace.project_session_rowid)
        )

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = self._stmt.where(models.Trace.project_session_rowid.in_(keys))
        usage_by_id: dict[Key, TokenUsage] = {}
        async with self._db() as session:
            async for id_, prompt, completion in await session.stream(stmt):
                if id_ is not None:
                    usage_by_id[id_] = TokenUsage(prompt=prompt, completion=completion)
        return [usage_by_id.get(key, TokenUsage()) for key in keys]
9 changes: 4 additions & 5 deletions src/phoenix/server/api/dataloaders/trace_by_trace_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
from phoenix.db import models
from phoenix.server.types import DbSessionFactory

TraceId: TypeAlias = str
Key: TypeAlias = TraceId
TraceRowId: TypeAlias = int
ProjectRowId: TypeAlias = int
Key: TypeAlias = str
Result: TypeAlias = Optional[models.Trace]


Expand All @@ -22,5 +19,7 @@ def __init__(self, db: DbSessionFactory) -> None:
async def _load_fn(self, keys: List[Key]) -> List[Result]:
    """Fetch traces by OTLP trace_id, preserving input order.

    Streams rows rather than materializing the full scalar result, and
    returns None for any trace_id not found.
    """
    stmt = select(models.Trace).where(models.Trace.trace_id.in_(keys))
    async with self._db() as session:
        result: dict[Key, models.Trace] = {
            trace.trace_id: trace async for trace in await session.stream_scalars(stmt)
        }
    return [result.get(trace_id) for trace_id in keys]
36 changes: 36 additions & 0 deletions src/phoenix/server/api/dataloaders/trace_root_spans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from functools import cached_property
from typing import List, Optional

from sqlalchemy import Select, select
from sqlalchemy.orm import contains_eager
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

Key: TypeAlias = int
Result: TypeAlias = Optional[models.Span]


class TraceRootSpansDataLoader(DataLoader[Key, Result]):
    """Batch-load the root span (parent_id IS NULL) for each trace rowid.

    Keys are ``models.Trace.id`` values; traces without a root span resolve
    to ``None``. The joined Trace is eagerly attached to each Span, loading
    only its ``trace_id`` column.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    @cached_property
    def _stmt(self) -> Select[tuple[models.Span]]:
        # Key-independent base statement; the per-call key filter is
        # appended in _load_fn.
        return (
            select(models.Span)
            .join(models.Trace)
            .where(models.Span.parent_id.is_(None))
            .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
        )

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        stmt = self._stmt.where(models.Trace.id.in_(keys))
        spans_by_trace: dict[Key, models.Span] = {}
        async with self._db() as session:
            async for span in await session.stream_scalars(stmt):
                spans_by_trace[span.trace_rowid] = span
        return [spans_by_trace.get(key) for key in keys]
Loading
Loading