3535 TypeExpression ,
3636 TypeName ,
3737 UnionTypeExpr ,
38- ensure_literal_type ,
3938 extract_inner_type ,
39+ render_literal_type ,
4040 render_type_expr ,
4141)
4242
@@ -167,7 +167,7 @@ def encode_type(
167167 in_module : list [ModuleName ],
168168 permit_unknown_members : bool ,
169169) -> Tuple [TypeExpression , list [ModuleName ], list [FileContents ], set [TypeName ]]:
170- encoder_name : Optional [ str ] = None # defining this up here to placate mypy
170+ encoder_name : TypeName | None = None # defining this up here to placate mypy
171171 chunks : List [FileContents ] = []
172172 if isinstance (type , RiverNotType ):
173173 return (TypeName ("None" ), [], [], set ())
@@ -234,7 +234,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
234234 and prop .const is not None
235235 ].pop ()
236236 one_of_pending .setdefault (
237- f"{ prefix } OneOf_{ discriminator_value } " ,
237+ f"{ render_literal_type ( prefix ) } OneOf_{ discriminator_value } " ,
238238 (discriminator_value , []),
239239 )[1 ].append (oneof_t )
240240
@@ -270,12 +270,13 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
270270 oneof_t .properties .keys ()
271271 ).difference (common_members )
272272 encoder_name = TypeName (
273- f"encode_{ ensure_literal_type (type_name )} "
273+ f"encode_{ render_literal_type (type_name )} "
274274 )
275275 encoder_names .add (encoder_name )
276+ _field_name = render_literal_type (encoder_name )
276277 typeddict_encoder .append (
277278 f"""\
278- { encoder_name } (x) # type: ignore[arg-type]
279+ { _field_name } (x) # type: ignore[arg-type]
279280 """ .strip ()
280281 )
281282 if local_discriminators :
@@ -299,12 +300,14 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
299300 one_of .append (type_name )
300301 chunks .extend (contents )
301302 encoder_name = TypeName (
302- f"encode_{ ensure_literal_type (type_name )} "
303+ f"encode_{ render_literal_type (type_name )} "
303304 )
304305 # TODO(dstewart): Figure out why uncommenting this breaks
305306 # generated code
306307 # encoder_names.add(encoder_name)
307- typeddict_encoder .append (f"{ encoder_name } (x)" )
308+ typeddict_encoder .append (
309+ f"{ render_literal_type (encoder_name )} (x)"
310+ )
308311 typeddict_encoder .append (
309312 f"""
310313 if x[{ repr (discriminator_name )} ]
@@ -317,19 +320,27 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
317320 union = OpenUnionTypeExpr (UnionTypeExpr (one_of ))
318321 else :
319322 union = UnionTypeExpr (one_of )
320- chunks .append (FileContents (f"{ prefix } = { render_type_expr (union )} " ))
323+ chunks .append (
324+ FileContents (
325+ f"{ render_literal_type (prefix )} = { render_type_expr (union )} "
326+ )
327+ )
321328 chunks .append (FileContents ("" ))
322329
323330 if base_model == "TypedDict" :
324- encoder_name = TypeName (f"encode_{ prefix } " )
331+ encoder_name = TypeName (f"encode_{ render_literal_type ( prefix ) } " )
325332 encoder_names .add (encoder_name )
333+ _field_name = render_literal_type (encoder_name )
334+ _field_type = (
335+ f"Callable[[{ repr (render_literal_type (prefix ))} ], Any]"
336+ )
326337 chunks .append (
327338 FileContents (
328339 "\n " .join (
329340 [
330341 dedent (
331342 f"""\
332- { encoder_name } : Callable[[ { repr ( prefix ) } ], Any] = (
343+ { _field_name } : { _field_type } = (
333344 lambda x:
334345 """ .rstrip ()
335346 )
@@ -349,7 +360,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
349360 for i , t in enumerate (type .anyOf ):
350361 type_name , _ , contents , _ = encode_type (
351362 t ,
352- TypeName (f"{ prefix } AnyOf_{ i } " ),
363+ TypeName (f"{ render_literal_type ( prefix ) } AnyOf_{ i } " ),
353364 base_model ,
354365 in_module ,
355366 permit_unknown_members = permit_unknown_members ,
@@ -366,7 +377,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
366377 match type_name :
367378 case ListTypeExpr (inner_type_name ):
368379 typeddict_encoder .append (
369- f"encode_{ ensure_literal_type (inner_type_name )} (x)"
380+ f"encode_{ render_literal_type (inner_type_name )} (x)"
370381 )
371382 case DictTypeExpr (_):
372383 raise ValueError (
@@ -377,25 +388,26 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
377388 typeddict_encoder .append (repr (const ))
378389 case other :
379390 typeddict_encoder .append (
380- f"encode_{ ensure_literal_type (other )} (x)"
391+ f"encode_{ render_literal_type (other )} (x)"
381392 )
382393 if permit_unknown_members :
383394 union = OpenUnionTypeExpr (UnionTypeExpr (any_of ))
384395 else :
385396 union = UnionTypeExpr (any_of )
386397 if is_literal (type ):
387398 typeddict_encoder = ["x" ]
388- chunks .append (FileContents (f"{ prefix } = { render_type_expr (union )} " ))
399+ chunks .append (
400+ FileContents (f"{ render_literal_type (prefix )} = { render_type_expr (union )} " )
401+ )
389402 if base_model == "TypedDict" :
390- encoder_name = TypeName (f"encode_{ prefix } " )
403+ encoder_name = TypeName (f"encode_{ render_literal_type ( prefix ) } " )
391404 encoder_names .add (encoder_name )
405+ _field_name = render_literal_type (encoder_name )
406+ _field_type = f"Callable[[{ repr (render_literal_type (prefix ))} ], Any]"
392407 chunks .append (
393408 FileContents (
394409 "\n " .join (
395- [
396- f"{ encoder_name } : Callable[[{ repr (prefix )} ], Any] = ("
397- "lambda x: "
398- ]
410+ [f"{ _field_name } : { _field_type } = (lambda x: " ]
399411 + typeddict_encoder
400412 + [")" ]
401413 )
@@ -491,7 +503,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
491503 match type_name :
492504 case ListTypeExpr (inner_type_name ):
493505 typeddict_encoder .append (
494- f"encode_{ ensure_literal_type (inner_type_name )} (x)"
506+ f"encode_{ render_literal_type (inner_type_name )} (x)"
495507 )
496508 case DictTypeExpr (_):
497509 raise ValueError (
@@ -500,11 +512,13 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
500512 case LiteralTypeExpr (const ):
501513 typeddict_encoder .append (repr (const ))
502514 case other :
503- typeddict_encoder .append (f"encode_{ ensure_literal_type (other )} (x)" )
515+ typeddict_encoder .append (f"encode_{ render_literal_type (other )} (x)" )
504516 return (DictTypeExpr (type_name ), module_info , type_chunks , encoder_names )
505517 assert type .type == "object" , type .type
506518
507- current_chunks : List [str ] = [f"class { prefix } ({ base_model } ):" ]
519+ current_chunks : List [str ] = [
520+ f"class { render_literal_type (prefix )} ({ base_model } ):"
521+ ]
508522 # For the encoder path, do we need "x" to be bound?
509523 # lambda x: ... vs lambda _: {}
510524 needs_binding = False
@@ -519,7 +533,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
519533 typeddict_encoder .append (f"{ repr (name )} :" )
520534 type_name , _ , contents , _ = encode_type (
521535 prop ,
522- TypeName (prefix + name .title ()),
536+ TypeName (prefix . value + name .title ()),
523537 base_model ,
524538 in_module ,
525539 permit_unknown_members = permit_unknown_members ,
@@ -531,17 +545,19 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
531545 typeddict_encoder .append ("'not implemented'" )
532546 elif isinstance (prop , RiverUnionType ):
533547 encoder_name = TypeName (
534- f"encode_{ ensure_literal_type (type_name )} "
548+ f"encode_{ render_literal_type (type_name )} "
535549 )
536550 encoder_names .add (encoder_name )
537- typeddict_encoder .append (f"{ encoder_name } (x[{ repr (name )} ])" )
551+ typeddict_encoder .append (
552+ f"{ render_literal_type (encoder_name )} (x[{ repr (name )} ])"
553+ )
538554 if name not in type .required :
539555 typeddict_encoder .append (
540556 f"if { repr (name )} in x and x[{ repr (name )} ] else None"
541557 )
542558 elif isinstance (prop , RiverIntersectionType ):
543559 encoder_name = TypeName (
544- f"encode_{ ensure_literal_type (type_name )} "
560+ f"encode_{ render_literal_type (type_name )} "
545561 )
546562 encoder_names .add (encoder_name )
547563 typeddict_encoder .append (f"{ encoder_name } (x[{ repr (name )} ])" )
@@ -552,11 +568,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
552568 safe_name = name
553569 if prop .type == "object" and not prop .patternProperties :
554570 encoder_name = TypeName (
555- f"encode_{ ensure_literal_type (type_name )} "
571+ f"encode_{ render_literal_type (type_name )} "
556572 )
557573 encoder_names .add (encoder_name )
558574 typeddict_encoder .append (
559- f"{ encoder_name } (x[{ repr (safe_name )} ])"
575+ f"{ render_literal_type ( encoder_name ) } (x[{ repr (safe_name )} ])"
560576 )
561577 if name not in prop .required :
562578 typeddict_encoder .append (
@@ -582,14 +598,14 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
582598 match type_name :
583599 case ListTypeExpr (inner_type_name ):
584600 encoder_name = TypeName (
585- f"encode_{ ensure_literal_type (inner_type_name )} "
601+ f"encode_{ render_literal_type (inner_type_name )} "
586602 )
587603 encoder_names .add (encoder_name )
588604 typeddict_encoder .append (
589605 dedent (
590606 f"""\
591607 [
592- { encoder_name } (y)
608+ { render_literal_type ( encoder_name ) } (y)
593609 for y in x[{ repr (name )} ]
594610 ]
595611 """ .rstrip ()
@@ -679,16 +695,18 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
679695
680696 if base_model == "TypedDict" :
681697 binding = "x" if needs_binding else "_"
682- encoder_name = TypeName (f"encode_{ prefix } " )
698+ encoder_name = TypeName (f"encode_{ render_literal_type ( prefix ) } " )
683699 encoder_names .add (encoder_name )
700+ _field_name = render_literal_type (encoder_name )
701+ _field_type = f"Callable[[{ repr (render_literal_type (prefix ))} ], Any]"
684702 current_chunks .insert (
685703 0 ,
686704 FileContents (
687705 "\n " .join (
688706 [
689707 dedent (
690708 f"""\
691- { encoder_name } : Callable[[ { repr ( prefix ) } ], Any] = (
709+ { _field_name } : { _field_type } = (
692710 lambda { binding } :
693711 """
694712 )
@@ -847,7 +865,7 @@ def __init__(self, client: river.Client[Any]):
847865 f"lambda xs: [encode_{ init_type_name } (x) for x in xs]"
848866 )
849867 else :
850- render_init_method = f"encode_{ ensure_literal_type (init_type )} "
868+ render_init_method = f"encode_{ render_literal_type (init_type )} "
851869 else :
852870 render_init_method = f"""\
853871 lambda x: TypeAdapter({ render_type_expr (init_type )} )
@@ -870,11 +888,11 @@ def __init__(self, client: river.Client[Any]):
870888 case ListTypeExpr (input_type_name ):
871889 render_input_method = f"""\
872890 lambda xs: [
873- encode_{ ensure_literal_type (input_type_name )} (x) for x in xs
891+ encode_{ render_literal_type (input_type_name )} (x) for x in xs
874892 ]
875893 """
876894 else :
877- render_input_method = f"encode_{ ensure_literal_type (input_type )} "
895+ render_input_method = f"encode_{ render_literal_type (input_type )} "
878896 else :
879897 render_input_method = f"""\
880898 lambda x: TypeAdapter({ render_type_expr (input_type )} )
@@ -957,9 +975,9 @@ async def {name}(
957975 f"""\
958976 async def { name } (
959977 self,
960- init: { init_type } ,
978+ init: { render_type_expr ( init_type ) } ,
961979 inputStream: AsyncIterable[{ render_type_expr (input_type )} ],
962- ) -> { output_type } :
980+ ) -> { render_type_expr ( output_type ) } :
963981 return await self.client.send_upload(
964982 { repr (schema_name )} ,
965983 { repr (name )} ,
@@ -1069,8 +1087,11 @@ async def {name}(
10691087 existing = emitted_files .get (file_path , FileContents (FILE_HEADER ))
10701088 emitted_files [file_path ] = FileContents ("\n " .join ([existing ] + contents ))
10711089
1090+ def render_names (xs : set [TypeName ]) -> str :
1091+ return ", " .join (sorted (render_literal_type (x ) for x in xs ))
1092+
10721093 rendered_imports = [
1073- f"from .{ dotted_modules } import { ', ' . join ( sorted ( names ) )} "
1094+ f"from .{ dotted_modules } import { render_names ( names )} "
10741095 for dotted_modules , names in imports .items ()
10751096 ]
10761097
0 commit comments