Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 23f40e7

Browse files
authored
RegistryMixin - tooling for easier registry/plugin patterns across NM repos (#365)
* RegistryMixin - tooling for easier registry/plugin patterns across NM repos * add 'registry' to method names * add registered_names * testing * add class level setting for requires_subclass * docstring code example typo * review suggestion
1 parent 77db6a0 commit 23f40e7

File tree

2 files changed

+317
-0
lines changed

2 files changed

+317
-0
lines changed

src/sparsezoo/utils/registry.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Universal registry to support registration and loading of child classes and plugins
17+
of neuralmagic utilities
18+
"""
19+
20+
import importlib
21+
from collections import defaultdict
22+
from typing import Any, Dict, List, Optional, Type
23+
24+
25+
__all__ = [
26+
"RegistryMixin",
27+
"register",
28+
"get_from_registry",
29+
"registered_names",
30+
]
31+
32+
33+
_REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict)
34+
35+
36+
class RegistryMixin:
37+
"""
38+
Universal registry to support registration and loading of child classes and plugins
39+
of neuralmagic utilities.
40+
41+
Classes that require a registry or plugins may add the `RegistryMixin` and use
42+
`register` and `load` as the main entrypoints for adding new implementations and
43+
loading requested values from its registry.
44+
45+
If a class should only have its child classes in its registry, the class should
46+
set the static attribute `registry_requires_subclass` to True
47+
48+
example
49+
```python
50+
class Dataset(RegistryMixin):
51+
pass
52+
53+
54+
# register with default name
55+
@Dataset.register()
56+
class ImageNetDataset(Dataset):
57+
pass
58+
59+
# load as "ImageNetDataset"
60+
imagenet = Dataset.load("ImageNetDataset")
61+
62+
# register with custom name
63+
@Dataset.register(name="cifar-dataset")
64+
class Cifar(Dataset):
65+
pass
66+
67+
# load as "cifar-dataset"
68+
cifar = Dataset.load_from_registry("cifar-dataset")
69+
70+
# load from custom file that implements a dataset
71+
mnist = Dataset.load_from_registry("/path/to/mnnist_dataset.py:MnistDataset")
72+
```
73+
"""
74+
75+
# set to True in child class to add check that registered/retrieved values
76+
# implement the class it is registered to
77+
registry_requires_subclass: bool = False
78+
79+
@classmethod
80+
def register(cls, name: Optional[str] = None):
81+
"""
82+
Decorator for registering a value (ie class or function) wrapped by this
83+
decorator to the base class (class that .register is called from)
84+
85+
:param name: name to register the wrapped value as, defaults to value.__name__
86+
:return: register decorator
87+
"""
88+
89+
def decorator(value: Any):
90+
cls.register_value(value, name=name)
91+
return value
92+
93+
return decorator
94+
95+
@classmethod
96+
def register_value(cls, value: Any, name: Optional[str] = None):
97+
"""
98+
Registers the given value to the class `.register_value` is called from
99+
:param value: value to register
100+
:param name: name to register the wrapped value as, defaults to value.__name__
101+
"""
102+
register(
103+
parent_class=cls,
104+
value=value,
105+
name=name,
106+
require_subclass=cls.registry_requires_subclass,
107+
)
108+
109+
@classmethod
110+
def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
111+
"""
112+
:param name: name of registered class to load
113+
:param constructor_kwargs: arguments to pass to the constructor retrieved
114+
from the registry
115+
:return: loaded object registered to this class under the given name,
116+
constructed with the given kwargs. Raises error if the name is
117+
not found in the registry
118+
"""
119+
constructor = cls.get_value_from_registry(name=name)
120+
return constructor(**constructor_kwargs)
121+
122+
@classmethod
123+
def get_value_from_registry(cls, name: str):
124+
"""
125+
:param name: name to retrieve from the registry
126+
:return: value from retrieved the registry for the given name, raises
127+
error if not found
128+
"""
129+
return get_from_registry(
130+
parent_class=cls,
131+
name=name,
132+
require_subclass=cls.registry_requires_subclass,
133+
)
134+
135+
@classmethod
136+
def registered_names(cls) -> List[str]:
137+
"""
138+
:return: list of all names registered to this class
139+
"""
140+
return registered_names(cls)
141+
142+
143+
def register(
144+
parent_class: Type,
145+
value: Any,
146+
name: Optional[str] = None,
147+
require_subclass: bool = False,
148+
):
149+
"""
150+
:param parent_class: class to register the name under
151+
:param value: value to register
152+
:param name: name to register the wrapped value as, defaults to value.__name__
153+
:param require_subclass: require that value is a subclass of the class this
154+
method is called from
155+
"""
156+
if name is None:
157+
# default name
158+
name = value.__name__
159+
160+
if require_subclass:
161+
_validate_subclass(parent_class, value)
162+
163+
if name in _REGISTRY[parent_class]:
164+
# name already exists - raise error if two different values are attempting
165+
# to share the same name
166+
registered_value = _REGISTRY[parent_class][name]
167+
if registered_value is not value:
168+
raise RuntimeError(
169+
f"Attempting to register name {name} as {value} "
170+
f"however {name} has already been registered as {registered_value}"
171+
)
172+
else:
173+
_REGISTRY[parent_class][name] = value
174+
175+
176+
def get_from_registry(
177+
parent_class: Type, name: str, require_subclass: bool = False
178+
) -> Any:
179+
"""
180+
:param parent_class: class that the name is registered under
181+
:param name: name to retrieve from the registry of the class
182+
:param require_subclass: require that value is a subclass of the class this
183+
method is called from
184+
:return: value from retrieved the registry for the given name, raises
185+
error if not found
186+
"""
187+
188+
if ":" in name:
189+
# user specifying specific module to load and value to import
190+
module_path, value_name = name.split(":")
191+
retrieved_value = _import_and_get_value_from_module(module_path, value_name)
192+
else:
193+
# look up name in registry
194+
retrieved_value = _REGISTRY[parent_class].get(name)
195+
if retrieved_value is None:
196+
raise ValueError(
197+
f"Unable to find {name} registered under type {parent_class}. "
198+
f"Registered values for {parent_class}: "
199+
f"{registered_names(parent_class)}"
200+
)
201+
202+
if require_subclass:
203+
_validate_subclass(parent_class, retrieved_value)
204+
205+
return retrieved_value
206+
207+
208+
def registered_names(parent_class: Type) -> List[str]:
209+
"""
210+
:param parent_class: class to look up the registry of
211+
:return: all names registered to the given class
212+
"""
213+
return list(_REGISTRY[parent_class].keys())
214+
215+
216+
def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any:
217+
# import the given module path and try to get the value_name if it is included
218+
# in the module
219+
220+
# load module
221+
spec = importlib.util.spec_from_file_location(
222+
f"plugin_module_for_{value_name}", module_path
223+
)
224+
module = importlib.util.module_from_spec(spec)
225+
spec.loader.exec_module(module)
226+
227+
# get value from module
228+
value = getattr(module, value_name, None)
229+
230+
if not value:
231+
raise RuntimeError(
232+
f"Unable to find attribute {value_name} in module {module_path}"
233+
)
234+
return value
235+
236+
237+
def _validate_subclass(parent_class: Type, child_class: Type):
238+
if not issubclass(child_class, parent_class):
239+
raise ValueError(
240+
f"class {child_class} is not a subclass of the class it is "
241+
f"registered for: {parent_class}."
242+
)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from sparsezoo.utils.registry import RegistryMixin
18+
19+
20+
def test_registery_flow_single():
21+
class Foo(RegistryMixin):
22+
pass
23+
24+
@Foo.register()
25+
class Foo1(Foo):
26+
pass
27+
28+
@Foo.register(name="name_2")
29+
class Foo2(Foo):
30+
pass
31+
32+
assert {"Foo1", "name_2"} == set(Foo.registered_names())
33+
34+
with pytest.raises(ValueError):
35+
Foo.get_value_from_registry("Foo2")
36+
37+
assert Foo.get_value_from_registry("Foo1") is Foo1
38+
assert isinstance(Foo.load_from_registry("name_2"), Foo2)
39+
40+
41+
def test_registry_flow_multiple():
42+
class Foo(RegistryMixin):
43+
pass
44+
45+
class Bar(RegistryMixin):
46+
pass
47+
48+
@Foo.register()
49+
class Foo1(Foo):
50+
pass
51+
52+
@Bar.register()
53+
class Bar1(Bar):
54+
pass
55+
56+
assert ["Foo1"] == Foo.registered_names()
57+
assert ["Bar1"] == Bar.registered_names()
58+
59+
assert Foo.get_value_from_registry("Foo1") is Foo1
60+
assert Bar.get_value_from_registry("Bar1") is Bar1
61+
62+
63+
def test_registry_requires_subclass():
64+
class Foo(RegistryMixin):
65+
registry_requires_subclass = True
66+
67+
@Foo.register()
68+
class Foo1(Foo):
69+
pass
70+
71+
with pytest.raises(ValueError):
72+
73+
@Foo.register()
74+
class NotFoo:
75+
pass

0 commit comments

Comments
 (0)