Skip to content

Commit 16ddd55

Browse files
committed
feat: recursive types
1 parent 407bb47 commit 16ddd55

File tree

6 files changed

+189
-0
lines changed

6 files changed

+189
-0
lines changed

src/replit_river/codegen/client.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ class RiverConcreteType(BaseModel):
101101
items: "RiverType | None" = Field(default=None)
102102
const: str | int | None = Field(default=None)
103103
patternProperties: dict[str, "RiverType"] = Field(default_factory=lambda: dict())
104+
id_: str | None = Field(default=None, alias="$id")
105+
ref: str | None = Field(default=None, alias="$ref")
104106

105107

106108
class RiverUnionType(BaseModel):
@@ -158,7 +160,12 @@ def encode_type(
158160
base_model: str,
159161
in_module: list[ModuleName],
160162
permit_unknown_members: bool,
163+
type_registry: dict[str, TypeName] | None = None,
161164
) -> tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]:
165+
# Registry to track $id -> TypeName mappings for resolving $ref
166+
if type_registry is None:
167+
type_registry = {}
168+
162169
def _make_open_union_type_expr(one_of: list[TypeExpression]) -> OpenUnionTypeExpr:
163170
if base_model == "RiverError":
164171
return OpenUnionTypeExpr(
@@ -175,6 +182,17 @@ def _make_open_union_type_expr(one_of: list[TypeExpression]) -> OpenUnionTypeExp
175182

176183
encoder_name: TypeName | None = None # defining this up here to placate mypy
177184
chunks: list[FileContents] = []
185+
186+
# Handle $ref - return a forward reference to the registered type
187+
if isinstance(type, RiverConcreteType) and type.ref is not None:
188+
ref_id = type.ref
189+
if ref_id in type_registry:
190+
# Use forward reference string for the type
191+
return (TypeName(f'"{type_registry[ref_id].value}"'), [], [], set())
192+
else:
193+
# Unknown ref, fall back to Any
194+
return (TypeName("Any"), [], [], set())
195+
178196
if isinstance(type, RiverNotType):
179197
return (NoneTypeExpr(), [], [], set())
180198
elif isinstance(type, RiverUnionType):
@@ -269,6 +287,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
269287
base_model,
270288
in_module,
271289
permit_unknown_members=permit_unknown_members,
290+
type_registry=type_registry,
272291
)
273292
one_of.append(type_name)
274293
chunks.extend(contents)
@@ -304,6 +323,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
304323
base_model,
305324
in_module,
306325
permit_unknown_members=permit_unknown_members,
326+
type_registry=type_registry,
307327
)
308328
one_of.append(type_name)
309329
chunks.extend(contents)
@@ -377,6 +397,7 @@ def {_field_name}(
377397
base_model,
378398
in_module,
379399
permit_unknown_members=permit_unknown_members,
400+
type_registry=type_registry,
380401
)
381402
any_of.append(type_name)
382403
chunks.extend(contents)
@@ -462,6 +483,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
462483
base_model,
463484
in_module,
464485
permit_unknown_members=permit_unknown_members,
486+
type_registry=type_registry,
465487
)
466488
elif isinstance(type, RiverConcreteType):
467489
typeddict_encoder = list[str]()
@@ -509,6 +531,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
509531
base_model,
510532
in_module,
511533
permit_unknown_members=permit_unknown_members,
534+
type_registry=type_registry,
512535
)
513536
typeddict_encoder.append("TODO: dstewart")
514537
return (ListTypeExpr(type_name), module_info, type_chunks, encoder_names)
@@ -523,10 +546,15 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
523546
base_model,
524547
in_module,
525548
permit_unknown_members=permit_unknown_members,
549+
type_registry=type_registry,
526550
)
527551
return (DictTypeExpr(type_name), module_info, type_chunks, encoder_names)
528552
assert type.type == "object", type.type
529553

554+
# Register $id for this type so $ref can resolve to it
555+
if type.id_ is not None:
556+
type_registry[type.id_] = prefix
557+
530558
current_chunks: list[str] = [
531559
f"class {render_literal_type(prefix)}({base_model}):"
532560
]
@@ -551,6 +579,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
551579
"BaseModel" if base_model == "RiverError" else base_model,
552580
in_module,
553581
permit_unknown_members=permit_unknown_members,
582+
type_registry=type_registry,
554583
)
555584
encoder_name = None
556585
chunks.extend(contents)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from pydantic import BaseModel
3+
from typing import Literal
4+
5+
import replit_river as river
6+
7+
8+
from .recursiveService import RecursiveserviceService
9+
10+
11+
class RecursiveClient:
12+
def __init__(self, client: river.Client[Literal[None]]):
13+
self.recursiveService = RecursiveserviceService(client)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
from typing import Any
4+
import datetime
5+
6+
from pydantic import TypeAdapter
7+
8+
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
9+
import replit_river as river
10+
11+
12+
from .getTree import (
13+
GettreeInput,
14+
GettreeOutput,
15+
GettreeOutputTypeAdapter,
16+
encode_GettreeInput,
17+
)
18+
19+
20+
class RecursiveserviceService:
21+
def __init__(self, client: river.Client[Any]):
22+
self.client = client
23+
24+
async def getTree(
25+
self,
26+
input: GettreeInput,
27+
timeout: datetime.timedelta,
28+
) -> GettreeOutput:
29+
return await self.client.send_rpc(
30+
"recursiveService",
31+
"getTree",
32+
input,
33+
encode_GettreeInput,
34+
lambda x: GettreeOutputTypeAdapter.validate_python(
35+
x # type: ignore[arg-type]
36+
),
37+
lambda x: RiverErrorTypeAdapter.validate_python(
38+
x # type: ignore[arg-type]
39+
),
40+
timeout,
41+
)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
import datetime
4+
from typing import (
5+
Any,
6+
Literal,
7+
Mapping,
8+
NotRequired,
9+
TypedDict,
10+
)
11+
from typing_extensions import Annotated
12+
13+
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
14+
from replit_river.error_schema import RiverError
15+
from replit_river.client import (
16+
RiverUnknownError,
17+
translate_unknown_error,
18+
RiverUnknownValue,
19+
translate_unknown_value,
20+
)
21+
22+
import replit_river as river
23+
24+
25+
def encode_GettreeInput(
26+
x: "GettreeInput",
27+
) -> Any:
28+
return {
29+
k: v
30+
for (k, v) in (
31+
{
32+
"rootId": x.get("rootId"),
33+
}
34+
).items()
35+
if v is not None
36+
}
37+
38+
39+
class GettreeInput(TypedDict):
40+
rootId: str
41+
42+
43+
class GettreeOutput(BaseModel):
44+
children: list["GettreeOutput"]
45+
id: str
46+
name: str
47+
48+
49+
GettreeOutputTypeAdapter: TypeAdapter[GettreeOutput] = TypeAdapter(GettreeOutput)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from pytest_snapshot.plugin import Snapshot
2+
3+
from tests.fixtures.codegen_snapshot_fixtures import validate_codegen
4+
5+
6+
async def test_recursive_types(snapshot: Snapshot) -> None:
7+
"""Test that recursive types using $id/$ref are generated correctly."""
8+
validate_codegen(
9+
snapshot=snapshot,
10+
snapshot_dir="tests/v1/codegen/snapshot/snapshots",
11+
read_schema=lambda: open("tests/v1/codegen/types/recursive_schema.json"),
12+
target_path="test_recursive_types",
13+
client_name="RecursiveClient",
14+
protocol_version="v1.1",
15+
)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
{
2+
"services": {
3+
"recursiveService": {
4+
"procedures": {
5+
"getTree": {
6+
"type": "rpc",
7+
"input": {
8+
"type": "object",
9+
"properties": {
10+
"rootId": {
11+
"type": "string"
12+
}
13+
},
14+
"required": ["rootId"]
15+
},
16+
"output": {
17+
"$id": "TreeNode",
18+
"type": "object",
19+
"properties": {
20+
"id": {
21+
"type": "string"
22+
},
23+
"name": {
24+
"type": "string"
25+
},
26+
"children": {
27+
"type": "array",
28+
"items": {
29+
"$ref": "TreeNode"
30+
}
31+
}
32+
},
33+
"required": ["id", "name", "children"]
34+
},
35+
"errors": {
36+
"not": {}
37+
}
38+
}
39+
}
40+
}
41+
}
42+
}

0 commit comments

Comments
 (0)