Skip to content

Commit a92d2be

Browse files
Found the rest of them!
1 parent 926f821 commit a92d2be

File tree

1 file changed

+45
-32
lines changed

1 file changed

+45
-32
lines changed

src/replit_river/codegen/client.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
ensure_literal_type,
3939
extract_inner_type,
4040
render_type_expr,
41+
render_literal_type,
4142
)
4243

4344
_NON_ALNUM_RE = re.compile(r"[^a-zA-Z0-9_]+")
@@ -167,7 +168,7 @@ def encode_type(
167168
in_module: list[ModuleName],
168169
permit_unknown_members: bool,
169170
) -> Tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]:
170-
encoder_name: Optional[str] = None # defining this up here to placate mypy
171+
encoder_name: TypeName | None = None # defining this up here to placate mypy
171172
chunks: List[FileContents] = []
172173
if isinstance(type, RiverNotType):
173174
return (TypeName("None"), [], [], set())
@@ -234,7 +235,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
234235
and prop.const is not None
235236
].pop()
236237
one_of_pending.setdefault(
237-
f"{prefix}OneOf_{discriminator_value}",
238+
f"{render_literal_type(prefix)}OneOf_{discriminator_value}",
238239
(discriminator_value, []),
239240
)[1].append(oneof_t)
240241

@@ -270,12 +271,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
270271
oneof_t.properties.keys()
271272
).difference(common_members)
272273
encoder_name = TypeName(
273-
f"encode_{ensure_literal_type(type_name)}"
274+
f"encode_{render_literal_type(type_name)}"
274275
)
275276
encoder_names.add(encoder_name)
276277
typeddict_encoder.append(
277278
f"""\
278-
{encoder_name}(x) # type: ignore[arg-type]
279+
{render_literal_type(encoder_name)}(x) # type: ignore[arg-type]
279280
""".strip()
280281
)
281282
if local_discriminators:
@@ -299,12 +300,14 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
299300
one_of.append(type_name)
300301
chunks.extend(contents)
301302
encoder_name = TypeName(
302-
f"encode_{ensure_literal_type(type_name)}"
303+
f"encode_{render_literal_type(type_name)}"
303304
)
304305
# TODO(dstewart): Figure out why uncommenting this breaks
305306
# generated code
306307
# encoder_names.add(encoder_name)
307-
typeddict_encoder.append(f"{encoder_name}(x)")
308+
typeddict_encoder.append(
309+
f"{render_literal_type(encoder_name)}(x)"
310+
)
308311
typeddict_encoder.append(
309312
f"""
310313
if x[{repr(discriminator_name)}]
@@ -317,19 +320,23 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
317320
union = OpenUnionTypeExpr(UnionTypeExpr(one_of))
318321
else:
319322
union = UnionTypeExpr(one_of)
320-
chunks.append(FileContents(f"{prefix} = {render_type_expr(union)}"))
323+
chunks.append(
324+
FileContents(
325+
f"{render_literal_type(prefix)} = {render_type_expr(union)}"
326+
)
327+
)
321328
chunks.append(FileContents(""))
322329

323330
if base_model == "TypedDict":
324-
encoder_name = TypeName(f"encode_{prefix}")
331+
encoder_name = TypeName(f"encode_{render_literal_type(prefix)}")
325332
encoder_names.add(encoder_name)
326333
chunks.append(
327334
FileContents(
328335
"\n".join(
329336
[
330337
dedent(
331338
f"""\
332-
{encoder_name}: Callable[[{repr(prefix)}], Any] = (
339+
{render_literal_type(encoder_name)}: Callable[[{repr(render_literal_type(prefix))}], Any] = (
333340
lambda x:
334341
""".rstrip()
335342
)
@@ -349,7 +356,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
349356
for i, t in enumerate(type.anyOf):
350357
type_name, _, contents, _ = encode_type(
351358
t,
352-
TypeName(f"{prefix}AnyOf_{i}"),
359+
TypeName(f"{render_literal_type(prefix)}AnyOf_{i}"),
353360
base_model,
354361
in_module,
355362
permit_unknown_members=permit_unknown_members,
@@ -366,7 +373,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
366373
match type_name:
367374
case ListTypeExpr(inner_type_name):
368375
typeddict_encoder.append(
369-
f"encode_{ensure_literal_type(inner_type_name)}(x)"
376+
f"encode_{render_literal_type(inner_type_name)}(x)"
370377
)
371378
case DictTypeExpr(_):
372379
raise ValueError(
@@ -377,23 +384,25 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
377384
typeddict_encoder.append(repr(const))
378385
case other:
379386
typeddict_encoder.append(
380-
f"encode_{ensure_literal_type(other)}(x)"
387+
f"encode_{render_literal_type(other)}(x)"
381388
)
382389
if permit_unknown_members:
383390
union = OpenUnionTypeExpr(UnionTypeExpr(any_of))
384391
else:
385392
union = UnionTypeExpr(any_of)
386393
if is_literal(type):
387394
typeddict_encoder = ["x"]
388-
chunks.append(FileContents(f"{prefix} = {render_type_expr(union)}"))
395+
chunks.append(
396+
FileContents(f"{render_literal_type(prefix)} = {render_type_expr(union)}")
397+
)
389398
if base_model == "TypedDict":
390-
encoder_name = TypeName(f"encode_{prefix}")
399+
encoder_name = TypeName(f"encode_{render_literal_type(prefix)}")
391400
encoder_names.add(encoder_name)
392401
chunks.append(
393402
FileContents(
394403
"\n".join(
395404
[
396-
f"{encoder_name}: Callable[[{repr(prefix)}], Any] = ("
405+
f"{render_literal_type(encoder_name)}: Callable[[{repr(render_literal_type(prefix))}], Any] = ("
397406
"lambda x: "
398407
]
399408
+ typeddict_encoder
@@ -491,7 +500,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
491500
match type_name:
492501
case ListTypeExpr(inner_type_name):
493502
typeddict_encoder.append(
494-
f"encode_{ensure_literal_type(inner_type_name)}(x)"
503+
f"encode_{render_literal_type(inner_type_name)}(x)"
495504
)
496505
case DictTypeExpr(_):
497506
raise ValueError(
@@ -500,11 +509,13 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
500509
case LiteralTypeExpr(const):
501510
typeddict_encoder.append(repr(const))
502511
case other:
503-
typeddict_encoder.append(f"encode_{ensure_literal_type(other)}(x)")
512+
typeddict_encoder.append(f"encode_{render_literal_type(other)}(x)")
504513
return (DictTypeExpr(type_name), module_info, type_chunks, encoder_names)
505514
assert type.type == "object", type.type
506515

507-
current_chunks: List[str] = [f"class {prefix}({base_model}):"]
516+
current_chunks: List[str] = [
517+
f"class {render_literal_type(prefix)}({base_model}):"
518+
]
508519
# For the encoder path, do we need "x" to be bound?
509520
# lambda x: ... vs lambda _: {}
510521
needs_binding = False
@@ -519,7 +530,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
519530
typeddict_encoder.append(f"{repr(name)}:")
520531
type_name, _, contents, _ = encode_type(
521532
prop,
522-
TypeName(prefix + name.title()),
533+
TypeName(prefix.value + name.title()),
523534
base_model,
524535
in_module,
525536
permit_unknown_members=permit_unknown_members,
@@ -531,17 +542,19 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
531542
typeddict_encoder.append("'not implemented'")
532543
elif isinstance(prop, RiverUnionType):
533544
encoder_name = TypeName(
534-
f"encode_{ensure_literal_type(type_name)}"
545+
f"encode_{render_literal_type(type_name)}"
535546
)
536547
encoder_names.add(encoder_name)
537-
typeddict_encoder.append(f"{encoder_name}(x[{repr(name)}])")
548+
typeddict_encoder.append(
549+
f"{render_literal_type(encoder_name)}(x[{repr(name)}])"
550+
)
538551
if name not in type.required:
539552
typeddict_encoder.append(
540553
f"if {repr(name)} in x and x[{repr(name)}] else None"
541554
)
542555
elif isinstance(prop, RiverIntersectionType):
543556
encoder_name = TypeName(
544-
f"encode_{ensure_literal_type(type_name)}"
557+
f"encode_{render_literal_type(type_name)}"
545558
)
546559
encoder_names.add(encoder_name)
547560
typeddict_encoder.append(f"{encoder_name}(x[{repr(name)}])")
@@ -552,11 +565,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
552565
safe_name = name
553566
if prop.type == "object" and not prop.patternProperties:
554567
encoder_name = TypeName(
555-
f"encode_{ensure_literal_type(type_name)}"
568+
f"encode_{render_literal_type(type_name)}"
556569
)
557570
encoder_names.add(encoder_name)
558571
typeddict_encoder.append(
559-
f"{encoder_name}(x[{repr(safe_name)}])"
572+
f"{render_literal_type(encoder_name)}(x[{repr(safe_name)}])"
560573
)
561574
if name not in prop.required:
562575
typeddict_encoder.append(
@@ -582,14 +595,14 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
582595
match type_name:
583596
case ListTypeExpr(inner_type_name):
584597
encoder_name = TypeName(
585-
f"encode_{ensure_literal_type(inner_type_name)}"
598+
f"encode_{render_literal_type(inner_type_name)}"
586599
)
587600
encoder_names.add(encoder_name)
588601
typeddict_encoder.append(
589602
dedent(
590603
f"""\
591604
[
592-
{encoder_name}(y)
605+
{render_literal_type(encoder_name)}(y)
593606
for y in x[{repr(name)}]
594607
]
595608
""".rstrip()
@@ -679,7 +692,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
679692

680693
if base_model == "TypedDict":
681694
binding = "x" if needs_binding else "_"
682-
encoder_name = TypeName(f"encode_{prefix}")
695+
encoder_name = TypeName(f"encode_{render_literal_type(prefix)}")
683696
encoder_names.add(encoder_name)
684697
current_chunks.insert(
685698
0,
@@ -688,7 +701,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
688701
[
689702
dedent(
690703
f"""\
691-
{encoder_name}: Callable[[{repr(prefix)}], Any] = (
704+
{render_literal_type(encoder_name)}: Callable[[{repr(render_literal_type(prefix))}], Any] = (
692705
lambda {binding}:
693706
"""
694707
)
@@ -847,7 +860,7 @@ def __init__(self, client: river.Client[Any]):
847860
f"lambda xs: [encode_{init_type_name}(x) for x in xs]"
848861
)
849862
else:
850-
render_init_method = f"encode_{ensure_literal_type(init_type)}"
863+
render_init_method = f"encode_{render_literal_type(init_type)}"
851864
else:
852865
render_init_method = f"""\
853866
lambda x: TypeAdapter({render_type_expr(init_type)})
@@ -870,11 +883,11 @@ def __init__(self, client: river.Client[Any]):
870883
case ListTypeExpr(input_type_name):
871884
render_input_method = f"""\
872885
lambda xs: [
873-
encode_{ensure_literal_type(input_type_name)}(x) for x in xs
886+
encode_{render_literal_type(input_type_name)}(x) for x in xs
874887
]
875888
"""
876889
else:
877-
render_input_method = f"encode_{ensure_literal_type(input_type)}"
890+
render_input_method = f"encode_{render_literal_type(input_type)}"
878891
else:
879892
render_input_method = f"""\
880893
lambda x: TypeAdapter({render_type_expr(input_type)})
@@ -1070,7 +1083,7 @@ async def {name}(
10701083
emitted_files[file_path] = FileContents("\n".join([existing] + contents))
10711084

10721085
rendered_imports = [
1073-
f"from .{dotted_modules} import {', '.join(sorted(names))}"
1086+
f"from .{dotted_modules} import {', '.join(sorted(render_literal_type(x) for x in names))}"
10741087
for dotted_modules, names in imports.items()
10751088
]
10761089

0 commit comments

Comments
 (0)