@@ -4305,7 +4305,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
43054305 with self .binder .frame_context (can_skip = True , fall_through = 4 ):
43064306 typ = s .types [i ]
43074307 if typ :
4308- t = self .check_except_handler_test (typ )
4308+ t = self .check_except_handler_test (typ , s . is_star )
43094309 var = s .vars [i ]
43104310 if var :
43114311 # To support local variables, we make this a definition line,
@@ -4325,7 +4325,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
43254325 if s .else_body :
43264326 self .accept (s .else_body )
43274327
4328- def check_except_handler_test (self , n : Expression ) -> Type :
4328+ def check_except_handler_test (self , n : Expression , is_star : bool ) -> Type :
43294329 """Type check an exception handler test clause."""
43304330 typ = self .expr_checker .accept (n )
43314331
@@ -4341,22 +4341,47 @@ def check_except_handler_test(self, n: Expression) -> Type:
43414341 item = ttype .items [0 ]
43424342 if not item .is_type_obj ():
43434343 self .fail (message_registry .INVALID_EXCEPTION_TYPE , n )
4344- return AnyType ( TypeOfAny . from_error )
4345- exc_type = item .ret_type
4344+ return self . default_exception_type ( is_star )
4345+ exc_type = erase_typevars ( item .ret_type )
43464346 elif isinstance (ttype , TypeType ):
43474347 exc_type = ttype .item
43484348 else :
43494349 self .fail (message_registry .INVALID_EXCEPTION_TYPE , n )
4350- return AnyType ( TypeOfAny . from_error )
4350+ return self . default_exception_type ( is_star )
43514351
43524352 if not is_subtype (exc_type , self .named_type ("builtins.BaseException" )):
43534353 self .fail (message_registry .INVALID_EXCEPTION_TYPE , n )
4354- return AnyType ( TypeOfAny . from_error )
4354+ return self . default_exception_type ( is_star )
43554355
43564356 all_types .append (exc_type )
43574357
4358+ if is_star :
4359+ new_all_types : list [Type ] = []
4360+ for typ in all_types :
4361+ if is_proper_subtype (typ , self .named_type ("builtins.BaseExceptionGroup" )):
4362+ self .fail (message_registry .INVALID_EXCEPTION_GROUP , n )
4363+ new_all_types .append (AnyType (TypeOfAny .from_error ))
4364+ else :
4365+ new_all_types .append (typ )
4366+ return self .wrap_exception_group (new_all_types )
43584367 return make_simplified_union (all_types )
43594368
4369+ def default_exception_type (self , is_star : bool ) -> Type :
4370+ """Exception type to return in case of a previous type error."""
4371+ any_type = AnyType (TypeOfAny .from_error )
4372+ if is_star :
4373+ return self .named_generic_type ("builtins.ExceptionGroup" , [any_type ])
4374+ return any_type
4375+
4376+ def wrap_exception_group (self , types : Sequence [Type ]) -> Type :
4377+ """Transform except* variable type into an appropriate exception group."""
4378+ arg = make_simplified_union (types )
4379+ if is_subtype (arg , self .named_type ("builtins.Exception" )):
4380+ base = "builtins.ExceptionGroup"
4381+ else :
4382+ base = "builtins.BaseExceptionGroup"
4383+ return self .named_generic_type (base , [arg ])
4384+
43604385 def get_types_from_except_handler (self , typ : Type , n : Expression ) -> list [Type ]:
43614386 """Helper for check_except_handler_test to retrieve handler types."""
43624387 typ = get_proper_type (typ )
0 commit comments