88 TYPE_CHECKING ,
99 Any ,
1010 TypeVar ,
11+ cast ,
12+ overload ,
1113)
1214
1315import numpy as np
1416
1517from pandas ._libs .hashtable import object_hash
1618from pandas ._typing import (
1719 DtypeObj ,
20+ npt ,
1821 type_t ,
1922)
2023from pandas .errors import AbstractMethodError
2932 from pandas .core .arrays import ExtensionArray
3033
3134 # To parameterize on same ExtensionDtype
32- E = TypeVar ("E " , bound = "ExtensionDtype" )
35+ ExtensionDtypeT = TypeVar ("ExtensionDtypeT " , bound = "ExtensionDtype" )
3336
3437
3538class ExtensionDtype :
@@ -206,7 +209,9 @@ def construct_array_type(cls) -> type_t[ExtensionArray]:
206209 raise AbstractMethodError (cls )
207210
208211 @classmethod
209- def construct_from_string (cls , string : str ):
212+ def construct_from_string (
213+ cls : type_t [ExtensionDtypeT ], string : str
214+ ) -> ExtensionDtypeT :
210215 r"""
211216 Construct this type from a string.
212217
@@ -368,7 +373,7 @@ def _can_hold_na(self) -> bool:
368373 return True
369374
370375
371- def register_extension_dtype (cls : type [ E ]) -> type [ E ]:
376+ def register_extension_dtype (cls : type_t [ ExtensionDtypeT ]) -> type_t [ ExtensionDtypeT ]:
372377 """
373378 Register an ExtensionType with pandas as class decorator.
374379
@@ -409,9 +414,9 @@ class Registry:
409414 """
410415
411416 def __init__ (self ):
412- self .dtypes : list [type [ExtensionDtype ]] = []
417+ self .dtypes : list [type_t [ExtensionDtype ]] = []
413418
414- def register (self , dtype : type [ExtensionDtype ]) -> None :
419+ def register (self , dtype : type_t [ExtensionDtype ]) -> None :
415420 """
416421 Parameters
417422 ----------
@@ -422,22 +427,46 @@ def register(self, dtype: type[ExtensionDtype]) -> None:
422427
423428 self .dtypes .append (dtype )
424429
425- def find (self , dtype : type [ExtensionDtype ] | str ) -> type [ExtensionDtype ] | None :
430+ @overload
431+ def find (self , dtype : type_t [ExtensionDtypeT ]) -> type_t [ExtensionDtypeT ]:
432+ ...
433+
434+ @overload
435+ def find (self , dtype : ExtensionDtypeT ) -> ExtensionDtypeT :
436+ ...
437+
438+ @overload
439+ def find (self , dtype : str ) -> ExtensionDtype | None :
440+ ...
441+
442+ @overload
443+ def find (
444+ self , dtype : npt .DTypeLike
445+ ) -> type_t [ExtensionDtype ] | ExtensionDtype | None :
446+ ...
447+
448+ def find (
449+ self , dtype : type_t [ExtensionDtype ] | ExtensionDtype | npt .DTypeLike
450+ ) -> type_t [ExtensionDtype ] | ExtensionDtype | None :
426451 """
427452 Parameters
428453 ----------
429- dtype : Type[ ExtensionDtype] or str
454+ dtype : ExtensionDtype class or instance or str or numpy dtype or python type
430455
431456 Returns
432457 -------
433458 return the first matching dtype, otherwise return None
434459 """
435460 if not isinstance (dtype , str ):
436- dtype_type = dtype
461+ dtype_type : type_t
437462 if not isinstance (dtype , type ):
438463 dtype_type = type (dtype )
464+ else :
465+ dtype_type = dtype
439466 if issubclass (dtype_type , ExtensionDtype ):
440- return dtype
467+ # cast needed here as mypy doesn't know we have figured
468+ # out it is an ExtensionDtype or type_t[ExtensionDtype]
469+ return cast ("ExtensionDtype | type_t[ExtensionDtype]" , dtype )
441470
442471 return None
443472
0 commit comments