Skip to content

Commit f255585

Browse files
[bug] Forgot type render call (#136)
Why === Ended up leaking a raw `TypeExpression` into generated code by forgetting this render call. Also made the rest of these dataclasses throw when used incorrectly. What changed ============ - Fixed the bug - Made it impossible to happen again Test plan ========= _Describe what you did to test this change to a level of detail that allows your reviewer to test it_
1 parent e2eba79 commit f255585

File tree

4 files changed

+116
-48
lines changed

4 files changed

+116
-48
lines changed

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ dependencies = [
3333
"opentelemetry-api>=1.28.2",
3434
]
3535

36+
[project.scripts]
37+
lint = "lint:main"
38+
format = "lint:main"
39+
3640
[tool.uv]
3741
dev-dependencies = [
3842
"deptry>=0.14.0",

src/lint.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/usr/bin/env python
2+
3+
import os
4+
import sys
5+
6+
7+
def raise_err(code: int) -> None:
8+
if code > 0:
9+
sys.exit(1)
10+
11+
12+
def main() -> None:
13+
fix = ["--fix"] if "--fix" in sys.argv else []
14+
raise_err(os.system(" ".join(["ruff", "check", "src"] + fix)))
15+
raise_err(os.system("ruff format src"))
16+
raise_err(os.system("mypy src"))
17+
raise_err(os.system("pyright src"))

src/replit_river/codegen/client.py

Lines changed: 59 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
TypeExpression,
3636
TypeName,
3737
UnionTypeExpr,
38-
ensure_literal_type,
3938
extract_inner_type,
39+
render_literal_type,
4040
render_type_expr,
4141
)
4242

@@ -167,7 +167,7 @@ def encode_type(
167167
in_module: list[ModuleName],
168168
permit_unknown_members: bool,
169169
) -> Tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]:
170-
encoder_name: Optional[str] = None # defining this up here to placate mypy
170+
encoder_name: TypeName | None = None # defining this up here to placate mypy
171171
chunks: List[FileContents] = []
172172
if isinstance(type, RiverNotType):
173173
return (TypeName("None"), [], [], set())
@@ -234,7 +234,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
234234
and prop.const is not None
235235
].pop()
236236
one_of_pending.setdefault(
237-
f"{prefix}OneOf_{discriminator_value}",
237+
f"{render_literal_type(prefix)}OneOf_{discriminator_value}",
238238
(discriminator_value, []),
239239
)[1].append(oneof_t)
240240

@@ -270,12 +270,13 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
270270
oneof_t.properties.keys()
271271
).difference(common_members)
272272
encoder_name = TypeName(
273-
f"encode_{ensure_literal_type(type_name)}"
273+
f"encode_{render_literal_type(type_name)}"
274274
)
275275
encoder_names.add(encoder_name)
276+
_field_name = render_literal_type(encoder_name)
276277
typeddict_encoder.append(
277278
f"""\
278-
{encoder_name}(x) # type: ignore[arg-type]
279+
{_field_name}(x) # type: ignore[arg-type]
279280
""".strip()
280281
)
281282
if local_discriminators:
@@ -299,12 +300,14 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
299300
one_of.append(type_name)
300301
chunks.extend(contents)
301302
encoder_name = TypeName(
302-
f"encode_{ensure_literal_type(type_name)}"
303+
f"encode_{render_literal_type(type_name)}"
303304
)
304305
# TODO(dstewart): Figure out why uncommenting this breaks
305306
# generated code
306307
# encoder_names.add(encoder_name)
307-
typeddict_encoder.append(f"{encoder_name}(x)")
308+
typeddict_encoder.append(
309+
f"{render_literal_type(encoder_name)}(x)"
310+
)
308311
typeddict_encoder.append(
309312
f"""
310313
if x[{repr(discriminator_name)}]
@@ -317,19 +320,27 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
317320
union = OpenUnionTypeExpr(UnionTypeExpr(one_of))
318321
else:
319322
union = UnionTypeExpr(one_of)
320-
chunks.append(FileContents(f"{prefix} = {render_type_expr(union)}"))
323+
chunks.append(
324+
FileContents(
325+
f"{render_literal_type(prefix)} = {render_type_expr(union)}"
326+
)
327+
)
321328
chunks.append(FileContents(""))
322329

323330
if base_model == "TypedDict":
324-
encoder_name = TypeName(f"encode_{prefix}")
331+
encoder_name = TypeName(f"encode_{render_literal_type(prefix)}")
325332
encoder_names.add(encoder_name)
333+
_field_name = render_literal_type(encoder_name)
334+
_field_type = (
335+
f"Callable[[{repr(render_literal_type(prefix))}], Any]"
336+
)
326337
chunks.append(
327338
FileContents(
328339
"\n".join(
329340
[
330341
dedent(
331342
f"""\
332-
{encoder_name}: Callable[[{repr(prefix)}], Any] = (
343+
{_field_name}: {_field_type} = (
333344
lambda x:
334345
""".rstrip()
335346
)
@@ -349,7 +360,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
349360
for i, t in enumerate(type.anyOf):
350361
type_name, _, contents, _ = encode_type(
351362
t,
352-
TypeName(f"{prefix}AnyOf_{i}"),
363+
TypeName(f"{render_literal_type(prefix)}AnyOf_{i}"),
353364
base_model,
354365
in_module,
355366
permit_unknown_members=permit_unknown_members,
@@ -366,7 +377,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
366377
match type_name:
367378
case ListTypeExpr(inner_type_name):
368379
typeddict_encoder.append(
369-
f"encode_{ensure_literal_type(inner_type_name)}(x)"
380+
f"encode_{render_literal_type(inner_type_name)}(x)"
370381
)
371382
case DictTypeExpr(_):
372383
raise ValueError(
@@ -377,25 +388,26 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
377388
typeddict_encoder.append(repr(const))
378389
case other:
379390
typeddict_encoder.append(
380-
f"encode_{ensure_literal_type(other)}(x)"
391+
f"encode_{render_literal_type(other)}(x)"
381392
)
382393
if permit_unknown_members:
383394
union = OpenUnionTypeExpr(UnionTypeExpr(any_of))
384395
else:
385396
union = UnionTypeExpr(any_of)
386397
if is_literal(type):
387398
typeddict_encoder = ["x"]
388-
chunks.append(FileContents(f"{prefix} = {render_type_expr(union)}"))
399+
chunks.append(
400+
FileContents(f"{render_literal_type(prefix)} = {render_type_expr(union)}")
401+
)
389402
if base_model == "TypedDict":
390-
encoder_name = TypeName(f"encode_{prefix}")
403+
encoder_name = TypeName(f"encode_{render_literal_type(prefix)}")
391404
encoder_names.add(encoder_name)
405+
_field_name = render_literal_type(encoder_name)
406+
_field_type = f"Callable[[{repr(render_literal_type(prefix))}], Any]"
392407
chunks.append(
393408
FileContents(
394409
"\n".join(
395-
[
396-
f"{encoder_name}: Callable[[{repr(prefix)}], Any] = ("
397-
"lambda x: "
398-
]
410+
[f"{_field_name}: {_field_type} = (lambda x: "]
399411
+ typeddict_encoder
400412
+ [")"]
401413
)
@@ -491,7 +503,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
491503
match type_name:
492504
case ListTypeExpr(inner_type_name):
493505
typeddict_encoder.append(
494-
f"encode_{ensure_literal_type(inner_type_name)}(x)"
506+
f"encode_{render_literal_type(inner_type_name)}(x)"
495507
)
496508
case DictTypeExpr(_):
497509
raise ValueError(
@@ -500,11 +512,13 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
500512
case LiteralTypeExpr(const):
501513
typeddict_encoder.append(repr(const))
502514
case other:
503-
typeddict_encoder.append(f"encode_{ensure_literal_type(other)}(x)")
515+
typeddict_encoder.append(f"encode_{render_literal_type(other)}(x)")
504516
return (DictTypeExpr(type_name), module_info, type_chunks, encoder_names)
505517
assert type.type == "object", type.type
506518

507-
current_chunks: List[str] = [f"class {prefix}({base_model}):"]
519+
current_chunks: List[str] = [
520+
f"class {render_literal_type(prefix)}({base_model}):"
521+
]
508522
# For the encoder path, do we need "x" to be bound?
509523
# lambda x: ... vs lambda _: {}
510524
needs_binding = False
@@ -519,7 +533,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
519533
typeddict_encoder.append(f"{repr(name)}:")
520534
type_name, _, contents, _ = encode_type(
521535
prop,
522-
TypeName(prefix + name.title()),
536+
TypeName(prefix.value + name.title()),
523537
base_model,
524538
in_module,
525539
permit_unknown_members=permit_unknown_members,
@@ -531,17 +545,19 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
531545
typeddict_encoder.append("'not implemented'")
532546
elif isinstance(prop, RiverUnionType):
533547
encoder_name = TypeName(
534-
f"encode_{ensure_literal_type(type_name)}"
548+
f"encode_{render_literal_type(type_name)}"
535549
)
536550
encoder_names.add(encoder_name)
537-
typeddict_encoder.append(f"{encoder_name}(x[{repr(name)}])")
551+
typeddict_encoder.append(
552+
f"{render_literal_type(encoder_name)}(x[{repr(name)}])"
553+
)
538554
if name not in type.required:
539555
typeddict_encoder.append(
540556
f"if {repr(name)} in x and x[{repr(name)}] else None"
541557
)
542558
elif isinstance(prop, RiverIntersectionType):
543559
encoder_name = TypeName(
544-
f"encode_{ensure_literal_type(type_name)}"
560+
f"encode_{render_literal_type(type_name)}"
545561
)
546562
encoder_names.add(encoder_name)
547563
typeddict_encoder.append(f"{encoder_name}(x[{repr(name)}])")
@@ -552,11 +568,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
552568
safe_name = name
553569
if prop.type == "object" and not prop.patternProperties:
554570
encoder_name = TypeName(
555-
f"encode_{ensure_literal_type(type_name)}"
571+
f"encode_{render_literal_type(type_name)}"
556572
)
557573
encoder_names.add(encoder_name)
558574
typeddict_encoder.append(
559-
f"{encoder_name}(x[{repr(safe_name)}])"
575+
f"{render_literal_type(encoder_name)}(x[{repr(safe_name)}])"
560576
)
561577
if name not in prop.required:
562578
typeddict_encoder.append(
@@ -582,14 +598,14 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
582598
match type_name:
583599
case ListTypeExpr(inner_type_name):
584600
encoder_name = TypeName(
585-
f"encode_{ensure_literal_type(inner_type_name)}"
601+
f"encode_{render_literal_type(inner_type_name)}"
586602
)
587603
encoder_names.add(encoder_name)
588604
typeddict_encoder.append(
589605
dedent(
590606
f"""\
591607
[
592-
{encoder_name}(y)
608+
{render_literal_type(encoder_name)}(y)
593609
for y in x[{repr(name)}]
594610
]
595611
""".rstrip()
@@ -679,16 +695,18 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
679695

680696
if base_model == "TypedDict":
681697
binding = "x" if needs_binding else "_"
682-
encoder_name = TypeName(f"encode_{prefix}")
698+
encoder_name = TypeName(f"encode_{render_literal_type(prefix)}")
683699
encoder_names.add(encoder_name)
700+
_field_name = render_literal_type(encoder_name)
701+
_field_type = f"Callable[[{repr(render_literal_type(prefix))}], Any]"
684702
current_chunks.insert(
685703
0,
686704
FileContents(
687705
"\n".join(
688706
[
689707
dedent(
690708
f"""\
691-
{encoder_name}: Callable[[{repr(prefix)}], Any] = (
709+
{_field_name}: {_field_type} = (
692710
lambda {binding}:
693711
"""
694712
)
@@ -847,7 +865,7 @@ def __init__(self, client: river.Client[Any]):
847865
f"lambda xs: [encode_{init_type_name}(x) for x in xs]"
848866
)
849867
else:
850-
render_init_method = f"encode_{ensure_literal_type(init_type)}"
868+
render_init_method = f"encode_{render_literal_type(init_type)}"
851869
else:
852870
render_init_method = f"""\
853871
lambda x: TypeAdapter({render_type_expr(init_type)})
@@ -870,11 +888,11 @@ def __init__(self, client: river.Client[Any]):
870888
case ListTypeExpr(input_type_name):
871889
render_input_method = f"""\
872890
lambda xs: [
873-
encode_{ensure_literal_type(input_type_name)}(x) for x in xs
891+
encode_{render_literal_type(input_type_name)}(x) for x in xs
874892
]
875893
"""
876894
else:
877-
render_input_method = f"encode_{ensure_literal_type(input_type)}"
895+
render_input_method = f"encode_{render_literal_type(input_type)}"
878896
else:
879897
render_input_method = f"""\
880898
lambda x: TypeAdapter({render_type_expr(input_type)})
@@ -957,9 +975,9 @@ async def {name}(
957975
f"""\
958976
async def {name}(
959977
self,
960-
init: {init_type},
978+
init: {render_type_expr(init_type)},
961979
inputStream: AsyncIterable[{render_type_expr(input_type)}],
962-
) -> {output_type}:
980+
) -> {render_type_expr(output_type)}:
963981
return await self.client.send_upload(
964982
{repr(schema_name)},
965983
{repr(name)},
@@ -1069,8 +1087,11 @@ async def {name}(
10691087
existing = emitted_files.get(file_path, FileContents(FILE_HEADER))
10701088
emitted_files[file_path] = FileContents("\n".join([existing] + contents))
10711089

1090+
def render_names(xs: set[TypeName]) -> str:
1091+
return ", ".join(sorted(render_literal_type(x) for x in xs))
1092+
10721093
rendered_imports = [
1073-
f"from .{dotted_modules} import {', '.join(sorted(names))}"
1094+
f"from .{dotted_modules} import {render_names(names)}"
10741095
for dotted_modules, names in imports.items()
10751096
]
10761097

0 commit comments

Comments
 (0)