Skip to content

Commit

Permalink
refactor(sessions): add dataloaders for session queries (#5222)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang committed Oct 29, 2024
1 parent 579a855 commit fddd0d1
Show file tree
Hide file tree
Showing 14 changed files with 348 additions and 137 deletions.
19 changes: 10 additions & 9 deletions scripts/fixtures/multi-turn_chat_sessions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install -Uqqq datasets openinference-instrumentation-openai openai-responses openai tiktoken langchain langchain-openai llama-index llama-index-llms-openai faker mdgen"
"%pip install -Uqqq datasets openinference-semantic-conventions openinference-instrumentation-openai faker openai-responses openai tiktoken"
]
},
{
Expand All @@ -25,7 +25,6 @@
"import pandas as pd\n",
"from datasets import load_dataset\n",
"from faker import Faker\n",
"from mdgen import MarkdownPostProvider\n",
"from openai_responses import OpenAIMock\n",
"from openinference.instrumentation import using_session, using_user\n",
"from openinference.instrumentation.openai import OpenAIInstrumentor\n",
Expand All @@ -39,8 +38,7 @@
"import phoenix as px\n",
"from phoenix.trace.span_evaluations import SpanEvaluations\n",
"\n",
"fake = Faker(\"ja_JP\")\n",
"fake.add_provider(MarkdownPostProvider)"
"fake = Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\", \"th_TH\", \"bn_BD\"])"
]
},
{
Expand Down Expand Up @@ -104,7 +102,7 @@
" if p < 0.1:\n",
" return \":\" * randint(1, 5)\n",
" if p < 0.9:\n",
" return Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\"]).address()\n",
" return fake.address()\n",
" return int(abs(random()) * 1_000_000_000)\n",
"\n",
"\n",
Expand All @@ -113,15 +111,17 @@
" if p < 0.1:\n",
" return \":\" * randint(1, 5)\n",
" if p < 0.9:\n",
" return Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\"]).name()\n",
" return fake.name()\n",
" return int(abs(random()) * 1_000_000_000)\n",
"\n",
"\n",
"def export_spans():\n",
"def export_spans(prob_drop_root):\n",
" \"\"\"Export spans in random order for receiver testing\"\"\"\n",
" spans = list(in_memory_span_exporter.get_finished_spans())\n",
" shuffle(spans)\n",
" for span in spans:\n",
" if span.parent is None and random() < prob_drop_root:\n",
" continue\n",
" otlp_span_exporter.export([span])\n",
" in_memory_span_exporter.clear()\n",
" session_count = len({id_ for span in spans if (id_ := span.attributes.get(\"session.id\"))})\n",
Expand All @@ -147,7 +147,7 @@
" return\n",
" has_yielded = False\n",
" with tracer.start_as_current_span(\n",
" Faker(\"ja_JP\").kana_name(),\n",
" fake.city(),\n",
" attributes=dict(rand_span_kind()),\n",
" end_on_exit=False,\n",
" ) as root:\n",
Expand Down Expand Up @@ -185,6 +185,7 @@
"source": [
"session_count = randint(5, 10)\n",
"tree_complexity = 4 # set to 0 for single span under root\n",
"prob_drop_root = 0.0 # probability that a root span gets dropped\n",
"\n",
"\n",
"def simulate_openai():\n",
Expand Down Expand Up @@ -237,7 +238,7 @@
" simulate_openai()\n",
"finally:\n",
" OpenAIInstrumentor().uninstrument()\n",
"spans = export_spans()\n",
"spans = export_spans(prob_drop_root)\n",
"\n",
"# Annotate root spans\n",
"root_span_ids = pd.Series(\n",
Expand Down
9 changes: 9 additions & 0 deletions src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@
MinStartOrMaxEndTimeDataLoader,
ProjectByNameDataLoader,
RecordCountDataLoader,
SessionIODataLoader,
SessionNumTracesDataLoader,
SessionTokenUsagesDataLoader,
SpanAnnotationsDataLoader,
SpanDatasetExamplesDataLoader,
SpanDescendantsDataLoader,
SpanProjectsDataLoader,
TokenCountDataLoader,
TraceByTraceIdsDataLoader,
TraceRootSpansDataLoader,
UserRolesDataLoader,
UsersDataLoader,
)
Expand Down Expand Up @@ -68,12 +72,17 @@ class DataLoaders:
latency_ms_quantile: LatencyMsQuantileDataLoader
min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
record_counts: RecordCountDataLoader
session_first_inputs: SessionIODataLoader
session_last_outputs: SessionIODataLoader
session_num_traces: SessionNumTracesDataLoader
session_token_usages: SessionTokenUsagesDataLoader
span_annotations: SpanAnnotationsDataLoader
span_dataset_examples: SpanDatasetExamplesDataLoader
span_descendants: SpanDescendantsDataLoader
span_projects: SpanProjectsDataLoader
token_counts: TokenCountDataLoader
trace_by_trace_ids: TraceByTraceIdsDataLoader
trace_root_spans: TraceRootSpansDataLoader
project_by_name: ProjectByNameDataLoader
users: UsersDataLoader
user_roles: UserRolesDataLoader
Expand Down
8 changes: 8 additions & 0 deletions src/phoenix/server/api/dataloaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
from .project_by_name import ProjectByNameDataLoader
from .record_counts import RecordCountCache, RecordCountDataLoader
from .session_io import SessionIODataLoader
from .session_num_traces import SessionNumTracesDataLoader
from .session_token_usages import SessionTokenUsagesDataLoader
from .span_annotations import SpanAnnotationsDataLoader
from .span_dataset_examples import SpanDatasetExamplesDataLoader
from .span_descendants import SpanDescendantsDataLoader
from .span_projects import SpanProjectsDataLoader
from .token_counts import TokenCountCache, TokenCountDataLoader
from .trace_by_trace_ids import TraceByTraceIdsDataLoader
from .trace_root_spans import TraceRootSpansDataLoader
from .user_roles import UserRolesDataLoader
from .users import UsersDataLoader

Expand All @@ -45,11 +49,15 @@
"LatencyMsQuantileDataLoader",
"MinStartOrMaxEndTimeDataLoader",
"RecordCountDataLoader",
"SessionIODataLoader",
"SessionNumTracesDataLoader",
"SessionTokenUsagesDataLoader",
"SpanDatasetExamplesDataLoader",
"SpanDescendantsDataLoader",
"SpanProjectsDataLoader",
"TokenCountDataLoader",
"TraceByTraceIdsDataLoader",
"TraceRootSpansDataLoader",
"ProjectByNameDataLoader",
"SpanAnnotationsDataLoader",
"UsersDataLoader",
Expand Down
75 changes: 75 additions & 0 deletions src/phoenix/server/api/dataloaders/session_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from functools import cached_property
from typing import Literal, Optional, cast

from openinference.semconv.trace import SpanAttributes
from sqlalchemy import Select, func, select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias, assert_never

from phoenix.db import models
from phoenix.server.types import DbSessionFactory
from phoenix.trace.schemas import MimeType, SpanIOValue

Key: TypeAlias = int
Result: TypeAlias = Optional[SpanIOValue]

Kind = Literal["first_input", "last_output"]


class SessionIODataLoader(DataLoader[Key, Result]):
    """Batch-load the first input (or last output) value of each session.

    A session's I/O comes from the root spans (``parent_id IS NULL``) of its
    traces.  ``kind`` selects whether the loader returns the earliest trace's
    input or the latest trace's output.
    """

    def __init__(self, db: DbSessionFactory, kind: Kind) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db
        self._kind = kind

    @cached_property
    def _subq(self) -> Select[tuple[Optional[int], str, str, int]]:
        # Rank each session's root spans by trace start time (ties broken by
        # trace id).  "first_input" ranks ascending so rank 1 is the earliest
        # trace; "last_output" ranks descending so rank 1 is the latest.
        if self._kind == "first_input":
            value_path, mime_path = INPUT_VALUE, INPUT_MIME_TYPE
            ordering = [models.Trace.start_time.asc(), models.Trace.id.asc()]
        elif self._kind == "last_output":
            value_path, mime_path = OUTPUT_VALUE, OUTPUT_MIME_TYPE
            ordering = [models.Trace.start_time.desc(), models.Trace.id.desc()]
        else:
            assert_never(self._kind)
        stmt = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                models.Span.attributes[value_path].label("value"),
                models.Span.attributes[mime_path].label("mime_type"),
                func.row_number()
                .over(
                    partition_by=models.Trace.project_session_rowid,
                    order_by=ordering,
                )
                .label("rank"),
            )
            .join_from(models.Span, models.Trace)
            .where(models.Span.parent_id.is_(None))
        )
        return cast(Select[tuple[Optional[int], str, str, int]], stmt)

    def _stmt(self, *keys: Key) -> Select[tuple[int, str, str]]:
        # Restrict the ranked subquery to the requested sessions, then keep
        # only the rank-1 row per session.
        ranked = self._subq.where(models.Trace.project_session_rowid.in_(keys)).subquery()
        return select(ranked.c.id_, ranked.c.value, ranked.c.mime_type).filter_by(rank=1)

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        results: dict[Key, SpanIOValue] = {}
        async with self._db() as session:
            rows = await session.stream(self._stmt(*keys))
            async for id_, value, mime_type in rows:
                if id_ is None:
                    continue
                results[id_] = SpanIOValue(value=value, mime_type=MimeType(mime_type))
        # Sessions without a matching root span yield None.
        return [results.get(key) for key in keys]


# Dotted semantic-convention names split into path segments (e.g.
# ["input", "value"]) so they can index into the Span.attributes JSON
# column via SQLAlchemy's subscript operator.
INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".")
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE.split(".")
30 changes: 30 additions & 0 deletions src/phoenix/server/api/dataloaders/session_num_traces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from sqlalchemy import func, select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

Key: TypeAlias = int
Result: TypeAlias = int


class SessionNumTracesDataLoader(DataLoader[Key, Result]):
    """Batch-load the number of traces belonging to each project session."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # A single grouped COUNT query covers every requested session rowid.
        stmt = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                func.count(models.Trace.id).label("value"),
            )
            .where(models.Trace.project_session_rowid.in_(keys))
            .group_by(models.Trace.project_session_rowid)
        )
        counts: dict[Key, int] = {}
        async with self._db() as session:
            async for id_, value in await session.stream(stmt):
                if id_ is not None:
                    counts[id_] = value
        # Sessions with no traces default to 0.
        return [counts.get(key, 0) for key in keys]
41 changes: 41 additions & 0 deletions src/phoenix/server/api/dataloaders/session_token_usages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from sqlalchemy import func, select
from sqlalchemy.sql.functions import coalesce
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory
from phoenix.trace.schemas import TokenUsage

Key: TypeAlias = int
Result: TypeAlias = TokenUsage


class SessionTokenUsagesDataLoader(DataLoader[Key, Result]):
    """Batch-load cumulative prompt/completion token usage per session.

    Token counts are summed over each session's root spans
    (``parent_id IS NULL``), whose cumulative counters already include
    their descendants.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        prompt_sum = func.sum(coalesce(models.Span.cumulative_llm_token_count_prompt, 0))
        completion_sum = func.sum(coalesce(models.Span.cumulative_llm_token_count_completion, 0))
        stmt = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                prompt_sum.label("prompt"),
                completion_sum.label("completion"),
            )
            .join_from(models.Span, models.Trace)
            .where(models.Span.parent_id.is_(None))
            .where(models.Trace.project_session_rowid.in_(keys))
            .group_by(models.Trace.project_session_rowid)
        )
        usages: dict[Key, TokenUsage] = {}
        async with self._db() as session:
            async for id_, prompt, completion in await session.stream(stmt):
                if id_ is not None:
                    usages[id_] = TokenUsage(prompt=prompt, completion=completion)
        # Sessions with no root spans fall back to an empty TokenUsage.
        return [usages.get(key, TokenUsage()) for key in keys]
9 changes: 4 additions & 5 deletions src/phoenix/server/api/dataloaders/trace_by_trace_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
from phoenix.db import models
from phoenix.server.types import DbSessionFactory

TraceId: TypeAlias = str
Key: TypeAlias = TraceId
TraceRowId: TypeAlias = int
ProjectRowId: TypeAlias = int
Key: TypeAlias = str
Result: TypeAlias = Optional[models.Trace]


Expand All @@ -22,5 +19,7 @@ def __init__(self, db: DbSessionFactory) -> None:
async def _load_fn(self, keys: List[Key]) -> List[Result]:
    """Fetch traces by their (string) trace_id in a single query.

    Returns results in the same order as ``keys``; ids with no matching
    trace map to None.
    """
    stmt = select(models.Trace).where(models.Trace.trace_id.in_(keys))
    async with self._db() as session:
        # Flattened diff residue left a duplicate eager `session.scalars(...)`
        # fetch immediately before this; keep only the streaming query so the
        # database is hit once and rows are not materialized twice.
        result: dict[Key, models.Trace] = {
            trace.trace_id: trace async for trace in await session.stream_scalars(stmt)
        }
    return [result.get(trace_id) for trace_id in keys]
32 changes: 32 additions & 0 deletions src/phoenix/server/api/dataloaders/trace_root_spans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import List, Optional

from sqlalchemy import select
from sqlalchemy.orm import contains_eager
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

Key: TypeAlias = int
Result: TypeAlias = Optional[models.Span]


class TraceRootSpansDataLoader(DataLoader[Key, Result]):
    """Batch-load the root span (``parent_id IS NULL``) of each trace rowid."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        # Eagerly populate Span.trace with only the trace_id column so the
        # resolver can read it without an extra lazy load.
        stmt = (
            select(models.Span)
            .join(models.Trace)
            .where(models.Trace.id.in_(keys))
            .where(models.Span.parent_id.is_(None))
            .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
        )
        root_spans: dict[Key, models.Span] = {}
        async with self._db() as session:
            async for span in await session.stream_scalars(stmt):
                root_spans[span.trace_rowid] = span
        # Traces without a root span map to None.
        return [root_spans.get(key) for key in keys]
Loading

0 comments on commit fddd0d1

Please sign in to comment.