3838 ensure_literal_type ,
3939 extract_inner_type ,
4040 render_type_expr ,
41+ render_literal_type ,
4142)
4243
4344_NON_ALNUM_RE = re .compile (r"[^a-zA-Z0-9_]+" )
@@ -167,7 +168,7 @@ def encode_type(
167168 in_module : list [ModuleName ],
168169 permit_unknown_members : bool ,
169170) -> Tuple [TypeExpression , list [ModuleName ], list [FileContents ], set [TypeName ]]:
170- encoder_name : Optional [ str ] = None # defining this up here to placate mypy
171+ encoder_name : TypeName | None = None # defining this up here to placate mypy
171172 chunks : List [FileContents ] = []
172173 if isinstance (type , RiverNotType ):
173174 return (TypeName ("None" ), [], [], set ())
@@ -234,7 +235,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
234235 and prop .const is not None
235236 ].pop ()
236237 one_of_pending .setdefault (
237- f"{ prefix } OneOf_{ discriminator_value } " ,
238+ f"{ render_literal_type ( prefix ) } OneOf_{ discriminator_value } " ,
238239 (discriminator_value , []),
239240 )[1 ].append (oneof_t )
240241
@@ -270,12 +271,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
270271 oneof_t .properties .keys ()
271272 ).difference (common_members )
272273 encoder_name = TypeName (
273- f"encode_{ ensure_literal_type (type_name )} "
274+ f"encode_{ render_literal_type (type_name )} "
274275 )
275276 encoder_names .add (encoder_name )
276277 typeddict_encoder .append (
277278 f"""\
278- { encoder_name } (x) # type: ignore[arg-type]
279+ { render_literal_type ( encoder_name ) } (x) # type: ignore[arg-type]
279280 """ .strip ()
280281 )
281282 if local_discriminators :
@@ -299,12 +300,14 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
299300 one_of .append (type_name )
300301 chunks .extend (contents )
301302 encoder_name = TypeName (
302- f"encode_{ ensure_literal_type (type_name )} "
303+ f"encode_{ render_literal_type (type_name )} "
303304 )
304305 # TODO(dstewart): Figure out why uncommenting this breaks
305306 # generated code
306307 # encoder_names.add(encoder_name)
307- typeddict_encoder .append (f"{ encoder_name } (x)" )
308+ typeddict_encoder .append (
309+ f"{ render_literal_type (encoder_name )} (x)"
310+ )
308311 typeddict_encoder .append (
309312 f"""
310313 if x[{ repr (discriminator_name )} ]
@@ -317,19 +320,23 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
317320 union = OpenUnionTypeExpr (UnionTypeExpr (one_of ))
318321 else :
319322 union = UnionTypeExpr (one_of )
320- chunks .append (FileContents (f"{ prefix } = { render_type_expr (union )} " ))
323+ chunks .append (
324+ FileContents (
325+ f"{ render_literal_type (prefix )} = { render_type_expr (union )} "
326+ )
327+ )
321328 chunks .append (FileContents ("" ))
322329
323330 if base_model == "TypedDict" :
324- encoder_name = TypeName (f"encode_{ prefix } " )
331+ encoder_name = TypeName (f"encode_{ render_literal_type ( prefix ) } " )
325332 encoder_names .add (encoder_name )
326333 chunks .append (
327334 FileContents (
328335 "\n " .join (
329336 [
330337 dedent (
331338 f"""\
332- { encoder_name } : Callable[[{ repr (prefix )} ], Any] = (
339+ { render_literal_type ( encoder_name ) } : Callable[[{ repr (render_literal_type ( prefix ) )} ], Any] = (
333340 lambda x:
334341 """ .rstrip ()
335342 )
@@ -349,7 +356,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
349356 for i , t in enumerate (type .anyOf ):
350357 type_name , _ , contents , _ = encode_type (
351358 t ,
352- TypeName (f"{ prefix } AnyOf_{ i } " ),
359+ TypeName (f"{ render_literal_type ( prefix ) } AnyOf_{ i } " ),
353360 base_model ,
354361 in_module ,
355362 permit_unknown_members = permit_unknown_members ,
@@ -366,7 +373,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
366373 match type_name :
367374 case ListTypeExpr (inner_type_name ):
368375 typeddict_encoder .append (
369- f"encode_{ ensure_literal_type (inner_type_name )} (x)"
376+ f"encode_{ render_literal_type (inner_type_name )} (x)"
370377 )
371378 case DictTypeExpr (_):
372379 raise ValueError (
@@ -377,23 +384,25 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
377384 typeddict_encoder .append (repr (const ))
378385 case other :
379386 typeddict_encoder .append (
380- f"encode_{ ensure_literal_type (other )} (x)"
387+ f"encode_{ render_literal_type (other )} (x)"
381388 )
382389 if permit_unknown_members :
383390 union = OpenUnionTypeExpr (UnionTypeExpr (any_of ))
384391 else :
385392 union = UnionTypeExpr (any_of )
386393 if is_literal (type ):
387394 typeddict_encoder = ["x" ]
388- chunks .append (FileContents (f"{ prefix } = { render_type_expr (union )} " ))
395+ chunks .append (
396+ FileContents (f"{ render_literal_type (prefix )} = { render_type_expr (union )} " )
397+ )
389398 if base_model == "TypedDict" :
390- encoder_name = TypeName (f"encode_{ prefix } " )
399+ encoder_name = TypeName (f"encode_{ render_literal_type ( prefix ) } " )
391400 encoder_names .add (encoder_name )
392401 chunks .append (
393402 FileContents (
394403 "\n " .join (
395404 [
396- f"{ encoder_name } : Callable[[{ repr (prefix )} ], Any] = ("
405+ f"{ render_literal_type ( encoder_name ) } : Callable[[{ repr (render_literal_type ( prefix ) )} ], Any] = ("
397406 "lambda x: "
398407 ]
399408 + typeddict_encoder
@@ -491,7 +500,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
491500 match type_name :
492501 case ListTypeExpr (inner_type_name ):
493502 typeddict_encoder .append (
494- f"encode_{ ensure_literal_type (inner_type_name )} (x)"
503+ f"encode_{ render_literal_type (inner_type_name )} (x)"
495504 )
496505 case DictTypeExpr (_):
497506 raise ValueError (
@@ -500,11 +509,13 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
500509 case LiteralTypeExpr (const ):
501510 typeddict_encoder .append (repr (const ))
502511 case other :
503- typeddict_encoder .append (f"encode_{ ensure_literal_type (other )} (x)" )
512+ typeddict_encoder .append (f"encode_{ render_literal_type (other )} (x)" )
504513 return (DictTypeExpr (type_name ), module_info , type_chunks , encoder_names )
505514 assert type .type == "object" , type .type
506515
507- current_chunks : List [str ] = [f"class { prefix } ({ base_model } ):" ]
516+ current_chunks : List [str ] = [
517+ f"class { render_literal_type (prefix )} ({ base_model } ):"
518+ ]
508519 # For the encoder path, do we need "x" to be bound?
509520 # lambda x: ... vs lambda _: {}
510521 needs_binding = False
@@ -519,7 +530,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
519530 typeddict_encoder .append (f"{ repr (name )} :" )
520531 type_name , _ , contents , _ = encode_type (
521532 prop ,
522- TypeName (prefix + name .title ()),
533+ TypeName (prefix . value + name .title ()),
523534 base_model ,
524535 in_module ,
525536 permit_unknown_members = permit_unknown_members ,
@@ -531,17 +542,19 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
531542 typeddict_encoder .append ("'not implemented'" )
532543 elif isinstance (prop , RiverUnionType ):
533544 encoder_name = TypeName (
534- f"encode_{ ensure_literal_type (type_name )} "
545+ f"encode_{ render_literal_type (type_name )} "
535546 )
536547 encoder_names .add (encoder_name )
537- typeddict_encoder .append (f"{ encoder_name } (x[{ repr (name )} ])" )
548+ typeddict_encoder .append (
549+ f"{ render_literal_type (encoder_name )} (x[{ repr (name )} ])"
550+ )
538551 if name not in type .required :
539552 typeddict_encoder .append (
540553 f"if { repr (name )} in x and x[{ repr (name )} ] else None"
541554 )
542555 elif isinstance (prop , RiverIntersectionType ):
543556 encoder_name = TypeName (
544- f"encode_{ ensure_literal_type (type_name )} "
557+ f"encode_{ render_literal_type (type_name )} "
545558 )
546559 encoder_names .add (encoder_name )
547560 typeddict_encoder .append (f"{ encoder_name } (x[{ repr (name )} ])" )
@@ -552,11 +565,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
552565 safe_name = name
553566 if prop .type == "object" and not prop .patternProperties :
554567 encoder_name = TypeName (
555- f"encode_{ ensure_literal_type (type_name )} "
568+ f"encode_{ render_literal_type (type_name )} "
556569 )
557570 encoder_names .add (encoder_name )
558571 typeddict_encoder .append (
559- f"{ encoder_name } (x[{ repr (safe_name )} ])"
572+ f"{ render_literal_type ( encoder_name ) } (x[{ repr (safe_name )} ])"
560573 )
561574 if name not in prop .required :
562575 typeddict_encoder .append (
@@ -582,14 +595,14 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
582595 match type_name :
583596 case ListTypeExpr (inner_type_name ):
584597 encoder_name = TypeName (
585- f"encode_{ ensure_literal_type (inner_type_name )} "
598+ f"encode_{ render_literal_type (inner_type_name )} "
586599 )
587600 encoder_names .add (encoder_name )
588601 typeddict_encoder .append (
589602 dedent (
590603 f"""\
591604 [
592- { encoder_name } (y)
605+ { render_literal_type ( encoder_name ) } (y)
593606 for y in x[{ repr (name )} ]
594607 ]
595608 """ .rstrip ()
@@ -679,7 +692,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
679692
680693 if base_model == "TypedDict" :
681694 binding = "x" if needs_binding else "_"
682- encoder_name = TypeName (f"encode_{ prefix } " )
695+ encoder_name = TypeName (f"encode_{ render_literal_type ( prefix ) } " )
683696 encoder_names .add (encoder_name )
684697 current_chunks .insert (
685698 0 ,
@@ -688,7 +701,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
688701 [
689702 dedent (
690703 f"""\
691- { encoder_name } : Callable[[{ repr (prefix )} ], Any] = (
704+ { render_literal_type ( encoder_name ) } : Callable[[{ repr (render_literal_type ( prefix ) )} ], Any] = (
692705 lambda { binding } :
693706 """
694707 )
@@ -847,7 +860,7 @@ def __init__(self, client: river.Client[Any]):
847860 f"lambda xs: [encode_{ init_type_name } (x) for x in xs]"
848861 )
849862 else :
850- render_init_method = f"encode_{ ensure_literal_type (init_type )} "
863+ render_init_method = f"encode_{ render_literal_type (init_type )} "
851864 else :
852865 render_init_method = f"""\
853866 lambda x: TypeAdapter({ render_type_expr (init_type )} )
@@ -870,11 +883,11 @@ def __init__(self, client: river.Client[Any]):
870883 case ListTypeExpr (input_type_name ):
871884 render_input_method = f"""\
872885 lambda xs: [
873- encode_{ ensure_literal_type (input_type_name )} (x) for x in xs
886+ encode_{ render_literal_type (input_type_name )} (x) for x in xs
874887 ]
875888 """
876889 else :
877- render_input_method = f"encode_{ ensure_literal_type (input_type )} "
890+ render_input_method = f"encode_{ render_literal_type (input_type )} "
878891 else :
879892 render_input_method = f"""\
880893 lambda x: TypeAdapter({ render_type_expr (input_type )} )
@@ -1070,7 +1083,7 @@ async def {name}(
10701083 emitted_files [file_path ] = FileContents ("\n " .join ([existing ] + contents ))
10711084
10721085 rendered_imports = [
1073- f"from .{ dotted_modules } import { ', ' .join (sorted (names ))} "
1086+ f"from .{ dotted_modules } import { ', ' .join (sorted (render_literal_type ( x ) for x in names ))} "
10741087 for dotted_modules , names in imports .items ()
10751088 ]
10761089
0 commit comments