@@ -5943,9 +5943,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59435943 ast:: Expr :: Set ( set) => self . infer_set_expression ( set, tcx) ,
59445944 ast:: Expr :: Dict ( dict) => self . infer_dict_expression ( dict, tcx) ,
59455945 ast:: Expr :: Generator ( generator) => self . infer_generator_expression ( generator) ,
5946- ast:: Expr :: ListComp ( listcomp) => self . infer_list_comprehension_expression ( listcomp) ,
5947- ast:: Expr :: DictComp ( dictcomp) => self . infer_dict_comprehension_expression ( dictcomp) ,
5948- ast:: Expr :: SetComp ( setcomp) => self . infer_set_comprehension_expression ( setcomp) ,
5946+ ast:: Expr :: ListComp ( listcomp) => {
5947+ self . infer_list_comprehension_expression ( listcomp, tcx)
5948+ }
5949+ ast:: Expr :: DictComp ( dictcomp) => {
5950+ self . infer_dict_comprehension_expression ( dictcomp, tcx)
5951+ }
5952+ ast:: Expr :: SetComp ( setcomp) => self . infer_set_comprehension_expression ( setcomp, tcx) ,
59495953 ast:: Expr :: Name ( name) => self . infer_name_expression ( name) ,
59505954 ast:: Expr :: Attribute ( attribute) => self . infer_attribute_expression ( attribute) ,
59515955 ast:: Expr :: UnaryOp ( unary_op) => self . infer_unary_expression ( unary_op) ,
@@ -6450,52 +6454,115 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
64506454 )
64516455 }
64526456
6453- fn infer_list_comprehension_expression ( & mut self , listcomp : & ast:: ExprListComp ) -> Type < ' db > {
6457+ /// Return a specialization of the collection class (list, dict, set) based on the type context and the inferred
6458+ /// element / key-value types from the comprehension expression.
6459+ fn infer_comprehension_specialization (
6460+ & self ,
6461+ collection_class : KnownClass ,
6462+ inferred_element_types : & [ Type < ' db > ] ,
6463+ tcx : TypeContext < ' db > ,
6464+ ) -> Type < ' db > {
6465+ // Remove any union elements of that are unrelated to the collection type.
6466+ let tcx = tcx. map ( |annotation| {
6467+ annotation. filter_disjoint_elements (
6468+ self . db ( ) ,
6469+ collection_class. to_instance ( self . db ( ) ) ,
6470+ InferableTypeVars :: None ,
6471+ )
6472+ } ) ;
6473+
6474+ if let Some ( annotated_element_types) = tcx
6475+ . known_specialization ( self . db ( ) , collection_class)
6476+ . map ( |specialization| specialization. types ( self . db ( ) ) )
6477+ && annotated_element_types
6478+ . iter ( )
6479+ . zip ( inferred_element_types. iter ( ) )
6480+ . all ( |( annotated, inferred) | inferred. is_assignable_to ( self . db ( ) , * annotated) )
6481+ {
6482+ collection_class
6483+ . to_specialized_instance ( self . db ( ) , annotated_element_types. iter ( ) . copied ( ) )
6484+ } else {
6485+ collection_class. to_specialized_instance (
6486+ self . db ( ) ,
6487+ inferred_element_types
6488+ . iter ( )
6489+ . map ( |ty| ty. promote_literals ( self . db ( ) , TypeContext :: default ( ) ) ) ,
6490+ )
6491+ }
6492+ }
6493+
6494+ fn infer_list_comprehension_expression (
6495+ & mut self ,
6496+ listcomp : & ast:: ExprListComp ,
6497+ tcx : TypeContext < ' db > ,
6498+ ) -> Type < ' db > {
64546499 let ast:: ExprListComp {
64556500 range : _,
64566501 node_index : _,
6457- elt : _ ,
6502+ elt,
64586503 generators,
64596504 } = listcomp;
64606505
64616506 self . infer_first_comprehension_iter ( generators) ;
64626507
6463- KnownClass :: List
6464- . to_specialized_instance ( self . db ( ) , [ todo_type ! ( "list comprehension element type" ) ] )
6508+ let scope_id = self
6509+ . index
6510+ . node_scope ( NodeWithScopeRef :: ListComprehension ( listcomp) ) ;
6511+ let scope = scope_id. to_scope_id ( self . db ( ) , self . file ( ) ) ;
6512+ let inference = infer_scope_types ( self . db ( ) , scope) ;
6513+ let element_type = inference. expression_type ( elt. as_ref ( ) ) ;
6514+
6515+ self . infer_comprehension_specialization ( KnownClass :: List , & [ element_type] , tcx)
64656516 }
64666517
6467- fn infer_dict_comprehension_expression ( & mut self , dictcomp : & ast:: ExprDictComp ) -> Type < ' db > {
6518+ fn infer_dict_comprehension_expression (
6519+ & mut self ,
6520+ dictcomp : & ast:: ExprDictComp ,
6521+ tcx : TypeContext < ' db > ,
6522+ ) -> Type < ' db > {
64686523 let ast:: ExprDictComp {
64696524 range : _,
64706525 node_index : _,
6471- key : _ ,
6472- value : _ ,
6526+ key,
6527+ value,
64736528 generators,
64746529 } = dictcomp;
64756530
64766531 self . infer_first_comprehension_iter ( generators) ;
64776532
6478- KnownClass :: Dict . to_specialized_instance (
6479- self . db ( ) ,
6480- [
6481- todo_type ! ( "dict comprehension key type" ) ,
6482- todo_type ! ( "dict comprehension value type" ) ,
6483- ] ,
6484- )
6533+ let scope_id = self
6534+ . index
6535+ . node_scope ( NodeWithScopeRef :: DictComprehension ( dictcomp) ) ;
6536+ let scope = scope_id. to_scope_id ( self . db ( ) , self . file ( ) ) ;
6537+ let inference = infer_scope_types ( self . db ( ) , scope) ;
6538+ let key_type = inference. expression_type ( key. as_ref ( ) ) ;
6539+ let value_type = inference. expression_type ( value. as_ref ( ) ) ;
6540+
6541+ self . infer_comprehension_specialization ( KnownClass :: Dict , & [ key_type, value_type] , tcx)
64856542 }
64866543
6487- fn infer_set_comprehension_expression ( & mut self , setcomp : & ast:: ExprSetComp ) -> Type < ' db > {
6544+ fn infer_set_comprehension_expression (
6545+ & mut self ,
6546+ setcomp : & ast:: ExprSetComp ,
6547+ tcx : TypeContext < ' db > ,
6548+ ) -> Type < ' db > {
64886549 let ast:: ExprSetComp {
64896550 range : _,
64906551 node_index : _,
6491- elt : _ ,
6552+ elt,
64926553 generators,
64936554 } = setcomp;
64946555
64956556 self . infer_first_comprehension_iter ( generators) ;
64966557
6497- KnownClass :: Set
6498- . to_specialized_instance ( self . db ( ) , [ todo_type ! ( "set comprehension element type" ) ] )
6558+ let scope_id = self
6559+ . index
6560+ . node_scope ( NodeWithScopeRef :: SetComprehension ( setcomp) ) ;
6561+ let scope = scope_id. to_scope_id ( self . db ( ) , self . file ( ) ) ;
6562+ let inference = infer_scope_types ( self . db ( ) , scope) ;
6563+ let element_type = inference. expression_type ( elt. as_ref ( ) ) ;
6564+
6565+ self . infer_comprehension_specialization ( KnownClass :: Set , & [ element_type] , tcx)
64996566 }
65006567
65016568 fn infer_generator_expression_scope ( & mut self , generator : & ast:: ExprGenerator ) {
0 commit comments