Skip to content

Commit c9bd6d7

Browse files
wesm authored and pitrou committed
ARROW-6301: [C++][Python] Prevent ExtensionType-related race condition in Python process teardown by exposing shared_ptr to global "ExtensionTypeRegistry"
A user observed a race condition in process teardown where the extension type registry was destroyed before the PyExtensionType could be unregistered. This is the simplest way I could think of to keep the registry alive at the end of the process.

Closes apache#5174 from wesm/ARROW-6301 and squashes the following commits:

3852ddb <Wes McKinney> Add missing 'override' keywords
5524868 <Wes McKinney> Externalize extension type registry so that Python can keep it alive until PyExtensionType is unregistered
84e7bdd <Wes McKinney> tmp

Authored-by: Wes McKinney <wesm+git@apache.org>
Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent 7f3ff24 commit c9bd6d7

File tree

4 files changed

+97
-25
lines changed

4 files changed

+97
-25
lines changed

cpp/src/arrow/extension_type.cc

Lines changed: 63 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -65,39 +65,77 @@ void ExtensionArray::SetData(const std::shared_ptr<ArrayData>& data) {
6565
storage_ = MakeArray(storage_data);
6666
}
6767

68-
std::unordered_map<std::string, std::shared_ptr<ExtensionType>> g_extension_registry;
69-
std::mutex g_extension_registry_guard;
68+
class ExtensionTypeRegistryImpl : public ExtensionTypeRegistry {
69+
public:
70+
ExtensionTypeRegistryImpl() {}
71+
72+
Status RegisterType(std::shared_ptr<ExtensionType> type) override {
73+
std::lock_guard<std::mutex> lock(lock_);
74+
std::string type_name = type->extension_name();
75+
auto it = name_to_type_.find(type_name);
76+
if (it != name_to_type_.end()) {
77+
return Status::KeyError("A type extension with name ", type_name,
78+
" already defined");
79+
}
80+
name_to_type_[type_name] = std::move(type);
81+
return Status::OK();
82+
}
7083

71-
Status RegisterExtensionType(std::shared_ptr<ExtensionType> type) {
72-
std::lock_guard<std::mutex> lock_(g_extension_registry_guard);
73-
std::string type_name = type->extension_name();
74-
auto it = g_extension_registry.find(type_name);
75-
if (it != g_extension_registry.end()) {
76-
return Status::KeyError("A type extension with name ", type_name, " already defined");
84+
Status UnregisterType(const std::string& type_name) override {
85+
std::lock_guard<std::mutex> lock(lock_);
86+
auto it = name_to_type_.find(type_name);
87+
if (it == name_to_type_.end()) {
88+
return Status::KeyError("No type extension with name ", type_name, " found");
89+
}
90+
name_to_type_.erase(it);
91+
return Status::OK();
92+
}
93+
94+
std::shared_ptr<ExtensionType> GetType(const std::string& type_name) override {
95+
std::lock_guard<std::mutex> lock(lock_);
96+
auto it = name_to_type_.find(type_name);
97+
if (it == name_to_type_.end()) {
98+
return nullptr;
99+
} else {
100+
return it->second;
101+
}
102+
return nullptr;
77103
}
78-
g_extension_registry[type_name] = std::move(type);
79-
return Status::OK();
104+
105+
private:
106+
std::mutex lock_;
107+
std::unordered_map<std::string, std::shared_ptr<ExtensionType>> name_to_type_;
108+
};
109+
110+
static std::shared_ptr<ExtensionTypeRegistry> g_registry;
111+
static std::once_flag registry_initialized;
112+
113+
namespace internal {
114+
115+
static void CreateGlobalRegistry() {
116+
g_registry = std::make_shared<ExtensionTypeRegistryImpl>();
117+
}
118+
119+
} // namespace internal
120+
121+
std::shared_ptr<ExtensionTypeRegistry> ExtensionTypeRegistry::GetGlobalRegistry() {
122+
std::call_once(registry_initialized, internal::CreateGlobalRegistry);
123+
return g_registry;
124+
}
125+
126+
Status RegisterExtensionType(std::shared_ptr<ExtensionType> type) {
127+
auto registry = ExtensionTypeRegistry::GetGlobalRegistry();
128+
return registry->RegisterType(type);
80129
}
81130

82131
Status UnregisterExtensionType(const std::string& type_name) {
83-
std::lock_guard<std::mutex> lock_(g_extension_registry_guard);
84-
auto it = g_extension_registry.find(type_name);
85-
if (it == g_extension_registry.end()) {
86-
return Status::KeyError("No type extension with name ", type_name, " found");
87-
}
88-
g_extension_registry.erase(it);
89-
return Status::OK();
132+
auto registry = ExtensionTypeRegistry::GetGlobalRegistry();
133+
return registry->UnregisterType(type_name);
90134
}
91135

92136
std::shared_ptr<ExtensionType> GetExtensionType(const std::string& type_name) {
93-
std::lock_guard<std::mutex> lock_(g_extension_registry_guard);
94-
auto it = g_extension_registry.find(type_name);
95-
if (it == g_extension_registry.end()) {
96-
return nullptr;
97-
} else {
98-
return it->second;
99-
}
100-
return nullptr;
137+
auto registry = ExtensionTypeRegistry::GetGlobalRegistry();
138+
return registry->GetType(type_name);
101139
}
102140

103141
} // namespace arrow

cpp/src/arrow/extension_type.h

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,20 @@ class ARROW_EXPORT ExtensionArray : public Array {
108108
std::shared_ptr<Array> storage_;
109109
};
110110

111+
class ARROW_EXPORT ExtensionTypeRegistry {
112+
public:
113+
/// \brief Provide access to the global registry to allow code to control for
114+
/// race conditions in registry teardown when some types need to be
115+
/// unregistered and destroyed first
116+
static std::shared_ptr<ExtensionTypeRegistry> GetGlobalRegistry();
117+
118+
virtual ~ExtensionTypeRegistry() = default;
119+
120+
virtual Status RegisterType(std::shared_ptr<ExtensionType> type) = 0;
121+
virtual Status UnregisterType(const std::string& type_name) = 0;
122+
virtual std::shared_ptr<ExtensionType> GetType(const std::string& type_name) = 0;
123+
};
124+
111125
/// \brief Register an extension type globally. The name returned by the type's
112126
/// extension_name() method should be unique. This method is thread-safe
113127
/// \param[in] type an instance of the extension type

python/pyarrow/includes/libarrow.pxd

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1415,6 +1415,10 @@ cdef extern from 'arrow/python/inference.h' namespace 'arrow::py':
14151415

14161416

14171417
cdef extern from 'arrow/extension_type.h' namespace 'arrow':
1418+
cdef cppclass CExtensionTypeRegistry" arrow::ExtensionTypeRegistry":
1419+
@staticmethod
1420+
shared_ptr[CExtensionTypeRegistry] GetGlobalRegistry()
1421+
14181422
cdef cppclass CExtensionType" arrow::ExtensionType"(CDataType):
14191423
c_string extension_name()
14201424
shared_ptr[CDataType] storage_type()

python/pyarrow/types.pxi

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1961,6 +1961,21 @@ def is_float_value(object obj):
19611961
return IsPyFloat(obj)
19621962

19631963

1964+
cdef class _ExtensionRegistryNanny:
1965+
# Keep the registry alive until we have unregistered PyExtensionType
1966+
cdef:
1967+
shared_ptr[CExtensionTypeRegistry] registry
1968+
1969+
def __cinit__(self):
1970+
self.registry = CExtensionTypeRegistry.GetGlobalRegistry()
1971+
1972+
def release_registry(self):
1973+
self.registry.reset()
1974+
1975+
1976+
_registry_nanny = _ExtensionRegistryNanny()
1977+
1978+
19641979
def _register_py_extension_type():
19651980
cdef:
19661981
DataType storage_type
@@ -1980,6 +1995,7 @@ def _unregister_py_extension_type():
19801995
# teardown stage, it will invoke CPython APIs such as Py_DECREF
19811996
# with a destroyed interpreter.
19821997
check_status(UnregisterPyExtensionType())
1998+
_registry_nanny.release_registry()
19831999

19842000

19852001
_register_py_extension_type()

0 commit comments

Comments
 (0)