Skip to content

Commit 2526fc7

Browse files
Fix mixed anyOf unions
1 parent 5d8d755 commit 2526fc7

File tree

6 files changed

+333
-23
lines changed

6 files changed

+333
-23
lines changed

src/replit_river/codegen/client.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,8 @@ def {_field_name}(
389389
type = original_type
390390
any_of: list[TypeExpression] = []
391391

392-
typeddict_encoder = []
392+
# Collect (type_check, encoder_expr) pairs for building ternary chain
393+
encoder_parts: list[tuple[str | None, str]] = []
393394
for i, t in enumerate(type.anyOf):
394395
type_name, _, contents, _ = encode_type(
395396
t,
@@ -403,34 +404,63 @@ def {_field_name}(
403404
chunks.extend(contents)
404405
if isinstance(t, RiverConcreteType):
405406
if t.type == "string":
406-
typeddict_encoder.extend(["x", " if isinstance(x, str) else "])
407-
else:
408-
# TODO(dstewart): This structure changed since we were incorrectly
409-
# leaking ListTypeExprs into codegen. This generated
410-
# code is probably wrong.
407+
encoder_parts.append(("isinstance(x, str)", "x"))
408+
elif t.type == "array":
411409
match type_name:
412410
case ListTypeExpr(inner_type_name):
413-
typeddict_encoder.append(
414-
f"encode_{render_literal_type(inner_type_name)}(x)"
411+
# Primitives don't need encoding
412+
inner_type_str = render_literal_type(inner_type_name)
413+
if inner_type_str in ("str", "int", "float", "bool", "Any"):
414+
encoder_parts.append(("isinstance(x, list)", "list(x)"))
415+
else:
416+
encoder_parts.append(
417+
(
418+
"isinstance(x, list)",
419+
f"[encode_{inner_type_str}(y) for y in x]",
420+
)
421+
)
422+
case _:
423+
encoder_parts.append(("isinstance(x, list)", "list(x)"))
424+
elif t.type == "object":
425+
match type_name:
426+
case TypeName(value):
427+
encoder_parts.append(
428+
("isinstance(x, dict)", f"encode_{value}(x)")
415429
)
430+
case _:
431+
encoder_parts.append(("isinstance(x, dict)", "dict(x)"))
432+
elif t.type in ("number", "integer"):
433+
match type_name:
416434
case LiteralTypeExpr(const):
417-
typeddict_encoder.append(repr(const))
435+
encoder_parts.append((f"x == {repr(const)}", repr(const)))
436+
case _:
437+
encoder_parts.append(("isinstance(x, (int, float))", "x"))
438+
elif t.type == "boolean":
439+
encoder_parts.append(("isinstance(x, bool)", "x"))
440+
elif t.type == "null" or t.type == "undefined":
441+
encoder_parts.append(("x is None", "None"))
442+
else:
443+
# Fallback for other types
444+
match type_name:
418445
case TypeName(value):
419-
typeddict_encoder.append(f"encode_{value}(x)")
446+
encoder_parts.append((None, f"encode_{value}(x)"))
447+
case LiteralTypeExpr(const):
448+
encoder_parts.append((None, repr(const)))
420449
case NoneTypeExpr():
421-
typeddict_encoder.append("None")
422-
case other:
423-
_o2: (
424-
DictTypeExpr
425-
| OpenUnionTypeExpr
426-
| UnionTypeExpr
427-
| LiteralType
428-
) = other
429-
raise ValueError(
430-
f"What does it mean to have {
431-
render_type_expr(_o2)
432-
} here?"
433-
)
450+
encoder_parts.append((None, "None"))
451+
case _:
452+
encoder_parts.append((None, "x"))
453+
454+
# Build the ternary chain from encoder_parts
455+
typeddict_encoder: list[str] = []
456+
for i, (type_check, encoder_expr) in enumerate(encoder_parts):
457+
is_last = i == len(encoder_parts) - 1
458+
if is_last or type_check is None:
459+
# Last item or no type check - just the expression
460+
typeddict_encoder.append(encoder_expr)
461+
else:
462+
# Add expression with type check
463+
typeddict_encoder.append(f"{encoder_expr} if {type_check} else")
434464
if permit_unknown_members:
435465
union = _make_open_union_type_expr(any_of)
436466
else:
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from pydantic import BaseModel
3+
from typing import Literal
4+
5+
import replit_river as river
6+
7+
8+
from .test_service import Test_ServiceService
9+
10+
11+
class AnyOfMixedClient:
12+
def __init__(self, client: river.Client[Literal[None]]):
13+
self.test_service = Test_ServiceService(client)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
from typing import Any
4+
import datetime
5+
6+
from pydantic import TypeAdapter
7+
8+
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
9+
import replit_river as river
10+
11+
12+
from .anyof_mixed_method import (
13+
Anyof_Mixed_MethodInput,
14+
Anyof_Mixed_MethodOutput,
15+
Anyof_Mixed_MethodOutputTypeAdapter,
16+
encode_Anyof_Mixed_MethodInput,
17+
encode_Anyof_Mixed_MethodInputNumber_Or_String,
18+
encode_Anyof_Mixed_MethodInputRun_Command,
19+
encode_Anyof_Mixed_MethodInputValue_Or_Null,
20+
)
21+
22+
23+
class Test_ServiceService:
24+
def __init__(self, client: river.Client[Any]):
25+
self.client = client
26+
27+
async def anyof_mixed_method(
28+
self,
29+
input: Anyof_Mixed_MethodInput,
30+
timeout: datetime.timedelta,
31+
) -> Anyof_Mixed_MethodOutput:
32+
return await self.client.send_rpc(
33+
"test_service",
34+
"anyof_mixed_method",
35+
input,
36+
encode_Anyof_Mixed_MethodInput,
37+
lambda x: Anyof_Mixed_MethodOutputTypeAdapter.validate_python(
38+
x # type: ignore[arg-type]
39+
),
40+
lambda x: RiverErrorTypeAdapter.validate_python(
41+
x # type: ignore[arg-type]
42+
),
43+
timeout,
44+
)
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
import datetime
4+
from typing import (
5+
Any,
6+
Literal,
7+
Mapping,
8+
NotRequired,
9+
TypedDict,
10+
)
11+
from typing_extensions import Annotated
12+
13+
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
14+
from replit_river.error_schema import RiverError
15+
from replit_river.client import (
16+
RiverUnknownError,
17+
translate_unknown_error,
18+
RiverUnknownValue,
19+
translate_unknown_value,
20+
)
21+
22+
import replit_river as river
23+
24+
25+
Anyof_Mixed_MethodInputNumber_Or_String = float | str
26+
27+
28+
def encode_Anyof_Mixed_MethodInputNumber_Or_String(
29+
x: "Anyof_Mixed_MethodInputNumber_Or_String",
30+
) -> Any:
31+
return x
32+
33+
34+
def encode_Anyof_Mixed_MethodInputRun_CommandAnyOf_0(
35+
x: "Anyof_Mixed_MethodInputRun_CommandAnyOf_0",
36+
) -> Any:
37+
return {
38+
k: v
39+
for (k, v) in (
40+
{
41+
"args": x.get("args"),
42+
"env": x.get("env"),
43+
}
44+
).items()
45+
if v is not None
46+
}
47+
48+
49+
class Anyof_Mixed_MethodInputRun_CommandAnyOf_0(TypedDict):
50+
args: list[str]
51+
env: NotRequired[dict[str, str] | None]
52+
53+
54+
Anyof_Mixed_MethodInputRun_Command = (
55+
Anyof_Mixed_MethodInputRun_CommandAnyOf_0 | str | list[str]
56+
)
57+
58+
59+
def encode_Anyof_Mixed_MethodInputRun_Command(
60+
x: "Anyof_Mixed_MethodInputRun_Command",
61+
) -> Any:
62+
return (
63+
encode_Anyof_Mixed_MethodInputRun_CommandAnyOf_0(x)
64+
if isinstance(x, dict)
65+
else x
66+
if isinstance(x, str)
67+
else list(x)
68+
)
69+
70+
71+
Anyof_Mixed_MethodInputValue_Or_Null = str | None
72+
73+
74+
def encode_Anyof_Mixed_MethodInputValue_Or_Null(
75+
x: "Anyof_Mixed_MethodInputValue_Or_Null",
76+
) -> Any:
77+
return x
78+
79+
80+
def encode_Anyof_Mixed_MethodInput(
81+
x: "Anyof_Mixed_MethodInput",
82+
) -> Any:
83+
return {
84+
k: v
85+
for (k, v) in (
86+
{
87+
"number_or_string": encode_Anyof_Mixed_MethodInputNumber_Or_String(
88+
x["number_or_string"]
89+
)
90+
if "number_or_string" in x and x["number_or_string"]
91+
else None,
92+
"run_command": encode_Anyof_Mixed_MethodInputRun_Command(
93+
x["run_command"]
94+
),
95+
"value_or_null": encode_Anyof_Mixed_MethodInputValue_Or_Null(
96+
x["value_or_null"]
97+
)
98+
if "value_or_null" in x and x["value_or_null"]
99+
else None,
100+
}
101+
).items()
102+
if v is not None
103+
}
104+
105+
106+
class Anyof_Mixed_MethodInput(TypedDict):
107+
number_or_string: NotRequired[Anyof_Mixed_MethodInputNumber_Or_String | None]
108+
run_command: Anyof_Mixed_MethodInputRun_Command
109+
value_or_null: NotRequired[Anyof_Mixed_MethodInputValue_Or_Null | None]
110+
111+
112+
class Anyof_Mixed_MethodOutput(BaseModel):
113+
success: bool
114+
115+
116+
Anyof_Mixed_MethodOutputTypeAdapter: TypeAdapter[Anyof_Mixed_MethodOutput] = (
117+
TypeAdapter(Anyof_Mixed_MethodOutput)
118+
)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from pytest_snapshot.plugin import Snapshot
2+
3+
from tests.fixtures.codegen_snapshot_fixtures import validate_codegen
4+
5+
6+
async def test_anyof_mixed_types(snapshot: Snapshot) -> None:
7+
"""Test codegen for anyOf unions with mixed types (object, string, array).
8+
9+
This tests the fix for the bug where non-discriminated anyOf unions
10+
with mixed types like [object, string, array] would generate malformed
11+
Python code with broken ternary expressions.
12+
"""
13+
validate_codegen(
14+
snapshot=snapshot,
15+
snapshot_dir="tests/v1/codegen/snapshot/snapshots",
16+
read_schema=lambda: open("tests/v1/codegen/types/anyof_mixed_schema.json"),
17+
target_path="test_anyof_mixed_types",
18+
client_name="AnyOfMixedClient",
19+
protocol_version="v1.1",
20+
)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
{
2+
"services": {
3+
"test_service": {
4+
"procedures": {
5+
"anyof_mixed_method": {
6+
"input": {
7+
"type": "object",
8+
"properties": {
9+
"run_command": {
10+
"description": "Command can be object with args, string, or array of strings",
11+
"anyOf": [
12+
{
13+
"type": "object",
14+
"properties": {
15+
"args": {
16+
"type": "array",
17+
"items": {
18+
"type": "string"
19+
}
20+
},
21+
"env": {
22+
"type": "object",
23+
"patternProperties": {
24+
"^(.*)$": {
25+
"type": "string"
26+
}
27+
}
28+
}
29+
},
30+
"required": ["args"]
31+
},
32+
{
33+
"type": "string"
34+
},
35+
{
36+
"type": "array",
37+
"items": {
38+
"type": "string"
39+
}
40+
}
41+
]
42+
},
43+
"value_or_null": {
44+
"description": "Value can be string or null",
45+
"anyOf": [
46+
{
47+
"type": "string"
48+
},
49+
{
50+
"type": "null"
51+
}
52+
]
53+
},
54+
"number_or_string": {
55+
"description": "Can be number or string",
56+
"anyOf": [
57+
{
58+
"type": "number"
59+
},
60+
{
61+
"type": "string"
62+
}
63+
]
64+
}
65+
},
66+
"required": ["run_command"]
67+
},
68+
"output": {
69+
"type": "object",
70+
"properties": {
71+
"success": {
72+
"type": "boolean"
73+
}
74+
},
75+
"required": ["success"]
76+
},
77+
"errors": {
78+
"not": {}
79+
},
80+
"type": "rpc"
81+
}
82+
}
83+
}
84+
}
85+
}

0 commit comments

Comments
 (0)