Skip to content

Commit cfaed4e

Browse files
clee2000Chao1Han
authored andcommitted
Fix periodic debug tests failing due to FakeProcessGroup things (pytorch#165479)
These happen when building with CMAKE_BUILD_TYPE=RelWithAssert This should fix two types of failures that started with pytorch#163665 Disclaimer that I used a lot of AI since I don't how pybind works or what refcounts and pointers are, so idk if this is a good solution, or even a solution at all (fwiw the tests pass now) The first one type is Truncated: ``` default_pg, _ = _new_process_group_helper( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2096, in _new_process_group_helper backend_class = creator_fn(dist_backend_opts, backend_options) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/distributed/fake_pg.py", line 25, in _create_fake_pg return FakeProcessGroup._create_internal( RuntimeError: new_refcount != 1 INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/c10/util/intrusive_ptr.h":319, please report a bug to PyTorch. intrusive_ptr: Cannot increase refcount after it reached zero. Exception raised from retain_ at /var/lib/jenkins/workspace/c10/util/intrusive_ptr.h:319 (most recent call first): C++ CapturedTraceback: #4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0 #5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0 #6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) from ??:0 #7 c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, char const*) from ??:0 #8 void pybind11::class_<c10d::FakeProcessGroup, (anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup> >::init_instance<(anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup>, 0>(pybind11::detail::instance*, void const*) from init.cpp:0 #9 pybind11::detail::type_caster_generic::cast(void const*, pybind11::return_value_policy, pybind11::handle, pybind11::detail::type_info const*, void* (*)(void const*), void* (*)(void const*), void const*) from :0 #10 pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)pytorch#127}, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> >, int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v>(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)pytorch#127}&&, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> > (*)(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) from init.cpp:0 ``` and I fix it here by getting rid of `DontIncreaseRefcount` and using make_intrusive to do the ref count handling instead. However, I also had to move the constructor to be public, which I think is not good, based on the reasoning of the original PR The other one type is ``` Traceback (most recent call last): File "/var/lib/jenkins/workspace/test/test_testing.py", line 2415, in test_no_warning_on_import self.assertEqual(out, "") File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 4233, in assertEqual raise error_metas.pop()[0].to_error( # type: ignore[index] AssertionError: String comparison failed: "/opt/conda/envs/py_3.10/lib/python3.10/s[352 chars]):\n" != '' - /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/__init__.py:29: FutureWarning: pybind11-bound class 'torch._C._distributed_c10d.FakeProcessGroup' is using an old-style placement-new '__init__' which has been deprecated. See the upgrade guide in pybind11's docs. This message is only visible when compiled in debug mode. - if is_available() and not torch._C._c10d_init(): To execute this test, run the following from the base repo dir: python test/test_testing.py TestImports.test_no_warning_on_import ``` which I fix by getting rid of the `__init__` which I think is ok since it'll just error if you try to make one? Pull Request resolved: pytorch#165479 Approved by: https://github.com/ezyang
1 parent e4bbaa2 commit cfaed4e

File tree

4 files changed

+4
-22
lines changed

4 files changed

+4
-22
lines changed

test/distributed/test_fake_pg.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -273,12 +273,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
273273
kwargs = {}
274274
return func(*args, **kwargs)
275275

276-
with self.assertRaisesRegex(
277-
RuntimeError,
278-
r"FakeProcessGroup cannot be constructed directly\. "
279-
r"Use torch\.distributed\.init_process_group\(backend='fake'\) instead to ensure "
280-
r"proper dispatch system integration\.",
281-
):
276+
with self.assertRaisesRegex(TypeError, r"No constructor defined"):
282277
fake_pg = FakeProcessGroup(rank=0, world_size=3)
283278

284279
with SimpleTensorMode():

torch/_C/_distributed_c10d.pyi

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,6 @@ class ProcessGroup:
607607
def group_desc(self) -> str: ...
608608

609609
class FakeProcessGroup(Backend):
610-
def __init__(self, rank: int, world_size: int) -> None: ...
611610
@staticmethod
612611
def _create_internal(rank: int, world_size: int) -> FakeProcessGroup: ...
613612

torch/csrc/distributed/c10d/FakeProcessGroup.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ class FakeProcessGroup : public Backend {
3333
int rank,
3434
int size,
3535
c10::intrusive_ptr<Options> options = c10::make_intrusive<Options>()) {
36-
return c10::intrusive_ptr<FakeProcessGroup>(
37-
new FakeProcessGroup(rank, size, std::move(options)),
38-
c10::raw::DontIncreaseRefcount{});
36+
return c10::make_intrusive<FakeProcessGroup>(
37+
rank, size, std::move(options));
3938
}
4039

4140
const std::string getBackendName() const override {
@@ -238,12 +237,12 @@ class FakeProcessGroup : public Backend {
238237
return c10::make_intrusive<FakeWork>();
239238
}
240239

241-
private:
242240
// Private constructor used by official APIs
243241
FakeProcessGroup(int rank, int size, c10::intrusive_ptr<Options> options)
244242
: Backend(rank, size), options_(std::move(options)) {}
245243
c10::intrusive_ptr<Options> options_;
246244

245+
private:
247246
void checkCollectiveError() {
248247
TORCH_CHECK(
249248
!options_ || !options_->error_on_collective,

torch/csrc/distributed/c10d/init.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3862,17 +3862,6 @@ such as `dist.all_reduce(tensor, async_op=True)`.
38623862
py::arg("world_size"),
38633863
py::arg("options") =
38643864
c10::make_intrusive<::c10d::FakeProcessGroup::Options>())
3865-
.def(
3866-
"__init__",
3867-
[](const py::object&,
3868-
const py::args& args,
3869-
const py::kwargs& kwargs) {
3870-
TORCH_CHECK(
3871-
false,
3872-
"FakeProcessGroup cannot be constructed directly. "
3873-
"Use torch.distributed.init_process_group(backend='fake') instead to ensure "
3874-
"proper dispatch system integration.");
3875-
})
38763865
.def_property_readonly(
38773866
"options", &::c10d::FakeProcessGroup::getBackendOptions);
38783867
auto fakeWork =

0 commit comments

Comments
 (0)