66import sys
77import warnings
88from importlib .metadata import entry_points
9- from typing import TYPE_CHECKING , Any
9+ from typing import TYPE_CHECKING , Any , Callable
1010
1111from xarray .backends .common import BACKEND_ENTRYPOINTS , BackendEntrypoint
1212
1313if TYPE_CHECKING :
1414 import os
15+ from importlib .metadata import EntryPoint , EntryPoints
1516 from io import BufferedIOBase
1617
1718 from xarray .backends .common import AbstractDataStore
1819
1920STANDARD_BACKENDS_ORDER = ["netcdf4" , "h5netcdf" , "scipy" ]
2021
2122
22- def remove_duplicates (entrypoints ) :
23+ def remove_duplicates (entrypoints : EntryPoints ) -> list [ EntryPoint ] :
2324 # sort and group entrypoints by name
2425 entrypoints = sorted (entrypoints , key = lambda ep : ep .name )
2526 entrypoints_grouped = itertools .groupby (entrypoints , key = lambda ep : ep .name )
@@ -42,7 +43,7 @@ def remove_duplicates(entrypoints):
4243 return unique_entrypoints
4344
4445
45- def detect_parameters (open_dataset ) :
46+ def detect_parameters (open_dataset : Callable ) -> tuple [ str , ...] :
4647 signature = inspect .signature (open_dataset )
4748 parameters = signature .parameters
4849 parameters_list = []
@@ -60,7 +61,9 @@ def detect_parameters(open_dataset):
6061 return tuple (parameters_list )
6162
6263
63- def backends_dict_from_pkg (entrypoints ):
64+ def backends_dict_from_pkg (
65+ entrypoints : list [EntryPoint ],
66+ ) -> dict [str , BackendEntrypoint ]:
6467 backend_entrypoints = {}
6568 for entrypoint in entrypoints :
6669 name = entrypoint .name
@@ -72,14 +75,16 @@ def backends_dict_from_pkg(entrypoints):
7275 return backend_entrypoints
7376
7477
75- def set_missing_parameters (backend_entrypoints ):
76- for name , backend in backend_entrypoints .items ():
78+ def set_missing_parameters (backend_entrypoints : dict [ str , BackendEntrypoint ] ):
79+ for _ , backend in backend_entrypoints .items ():
7780 if backend .open_dataset_parameters is None :
7881 open_dataset = backend .open_dataset
7982 backend .open_dataset_parameters = detect_parameters (open_dataset )
8083
8184
82- def sort_backends (backend_entrypoints ):
85+ def sort_backends (
86+ backend_entrypoints : dict [str , BackendEntrypoint ]
87+ ) -> dict [str , BackendEntrypoint ]:
8388 ordered_backends_entrypoints = {}
8489 for be_name in STANDARD_BACKENDS_ORDER :
8590 if backend_entrypoints .get (be_name , None ) is not None :
@@ -90,7 +95,7 @@ def sort_backends(backend_entrypoints):
9095 return ordered_backends_entrypoints
9196
9297
93- def build_engines (entrypoints ) -> dict [str , BackendEntrypoint ]:
98+ def build_engines (entrypoints : EntryPoints ) -> dict [str , BackendEntrypoint ]:
9499 backend_entrypoints = {}
95100 for backend_name , backend in BACKEND_ENTRYPOINTS .items ():
96101 if backend .available :
@@ -126,6 +131,13 @@ def list_engines() -> dict[str, BackendEntrypoint]:
126131 return build_engines (entrypoints )
127132
128133
134+ def refresh_engines () -> None :
135+ """Refreshes the backend engines based on installed packages."""
136+ list_engines .cache_clear ()
137+ for backend_entrypoint in BACKEND_ENTRYPOINTS .values ():
138+ backend_entrypoint ._set_availability ()
139+
140+
129141def guess_engine (
130142 store_spec : str | os .PathLike [Any ] | BufferedIOBase | AbstractDataStore ,
131143):
0 commit comments