Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit ee64f6f

Browse files
Ryan SepassiCopybara-Service
Ryan Sepassi
authored and
Copybara-Service
committed
Add a generic create_registry method
PiperOrigin-RevId: 230422498
1 parent 57b3d59 commit ee64f6f

File tree

2 files changed

+92
-1
lines changed

2 files changed

+92
-1
lines changed

tensor2tensor/utils/registry.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ class MyModel(T2TModel):
4444
from __future__ import division
4545
from __future__ import print_function
4646

47+
import collections
48+
4749
from tensor2tensor.utils import misc_utils
4850
import tensorflow as tf
4951
from tensorflow.python.util import tf_inspect as inspect
@@ -57,6 +59,68 @@ class MyModel(T2TModel):
5759
_PRUNING_STRATEGY = {}
5860
_RANGED_HPARAMS = {}
5961

62+
# Key: registry name, Value: Registry
63+
_GENERIC_REGISTRIES = {}
64+
Registry = collections.namedtuple(
65+
"_Registry", ["register", "get", "list", "registry"])
66+
67+
68+
def registry(registry_name):
69+
"""Returns `Registry` created by `create_registry`."""
70+
if registry_name not in _GENERIC_REGISTRIES:
71+
raise KeyError("No registry named %s. Available:\n%s" % (
72+
registry_name, sorted(_GENERIC_REGISTRIES)))
73+
return _GENERIC_REGISTRIES[registry_name]
74+
75+
76+
def create_registry(registry_name):
77+
"""Create a generic object registry.
78+
79+
Args:
80+
registry_name: str, name of the object registry.
81+
82+
Returns:
83+
`Registry` that contains functions for register (decorator), get, and list.
84+
85+
Raises:
86+
KeyError: if `registry_name` is a pre-existing registry.
87+
"""
88+
if registry_name in _GENERIC_REGISTRIES:
89+
raise KeyError(
90+
"Registry %s already exists." % registry_name)
91+
92+
registry_ = {}
93+
94+
def register(name):
95+
"""Returns decorator to register an object."""
96+
97+
def register_dec(obj):
98+
if name in registry_:
99+
raise KeyError(
100+
"Registry %s already contains key %s." % (registry_name, name))
101+
registry_[name] = obj
102+
return obj
103+
104+
return register_dec
105+
106+
def get(name):
107+
if name not in registry_:
108+
raise KeyError(
109+
"Registry %s contains no object named %s" % (registry_name, name))
110+
return registry_[name]
111+
112+
def list_registry():
113+
return sorted(registry_)
114+
115+
registry_obj = Registry(
116+
register=register,
117+
get=get,
118+
list=list_registry,
119+
registry=registry_,
120+
)
121+
_GENERIC_REGISTRIES[registry_name] = registry_obj
122+
return registry_obj
123+
60124

61125
def _reset():
62126
for ctr in [_MODELS, _HPARAMS, _RANGED_HPARAMS, _ATTACK_PARAMS]:

tensor2tensor/utils/registry_test.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,35 @@ def rhp_bad2(a, b): # pylint: disable=unused-argument
194194
pass
195195

196196

197+
class CreateRegistry(tf.test.TestCase):
198+
"""Test class for `create_registry`."""
199+
200+
def testCreateRegistry(self):
201+
my_registry = registry.create_registry("test_reg1")
202+
self.assertIs(my_registry, registry.registry("test_reg1"))
203+
204+
# Use as decorator on a fn
205+
@my_registry.register("foo")
206+
def some_fn(num):
207+
return num + 2
208+
209+
# Register a regular object
210+
pod_obj = 4
211+
my_registry.register("bar")(pod_obj)
212+
213+
# Register a class
214+
@my_registry.register("foobar")
215+
class A(object):
216+
pass
217+
218+
self.assertEqual(9, my_registry.get("foo")(7))
219+
self.assertEqual(["bar", "foo", "foobar"], my_registry.list())
220+
foobar = my_registry.get("foobar")
221+
self.assertTrue(isinstance(foobar(), A))
222+
223+
197224
class RegistryTest(tf.test.TestCase):
198-
""" Test class for common functions."""
225+
"""Test class for common functions."""
199226

200227
def testRegistryHelp(self):
201228
help_str = registry.help_string()

0 commit comments

Comments
 (0)