Skip to content

Commit

Permalink
chore(registry): add registry write lock
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 21, 2022
1 parent 09d1dcb commit 30173c3
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import inspect
from collections import OrderedDict, defaultdict, deque
from operator import methodcaller
from threading import Lock
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -59,6 +60,7 @@


__GLOBAL_NAMESPACE: str = object() # type: ignore[assignment]
__REGISTRY_LOCK: Lock = Lock()


def register_pytree_node(
Expand Down Expand Up @@ -176,14 +178,18 @@ def register_pytree_node(
raise TypeError(f'The namespace must be a string, got {namespace}.')
if namespace == '':
raise ValueError('The namespace cannot be an empty string.')

registration_key: Union[Type, Tuple[str, Type]]
if namespace is __GLOBAL_NAMESPACE:
_C.register_node(cls, flatten_func, unflatten_func, '')
CustomTreeNode.register(cls) # pylint: disable=no-member
_nodetype_registry[cls] = PyTreeNodeRegistryEntry(flatten_func, unflatten_func)
registration_key = cls
namespace = ''
else:
registration_key = (namespace, cls)

with __REGISTRY_LOCK:
_C.register_node(cls, flatten_func, unflatten_func, namespace)
CustomTreeNode.register(cls) # pylint: disable=no-member
_nodetype_registry[(namespace, cls)] = PyTreeNodeRegistryEntry(flatten_func, unflatten_func)
_nodetype_registry[registration_key] = PyTreeNodeRegistryEntry(flatten_func, unflatten_func)
return cls


Expand Down

0 comments on commit 30173c3

Please sign in to comment.