Skip to content

Commit

Permalink
Fix pybind11 warnings in python_rpc_handler.cpp (pytorch#27284)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#27284

The warnings related to usage of the deprecated != operator. Instead
of checking the member field on every function call, we can check it
once, on construction of PythonRpcHandler.

Test Plan: Imported from OSS

Differential Revision: D17808213

Pulled By: pietern

fbshipit-source-id: 022c8f77f266942c49c55b1729e62dbb06262d77
  • Loading branch information
pietern authored and facebook-github-bot committed Oct 8, 2019
1 parent 0d22f3b commit c742918
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions torch/csrc/distributed/rpc/python_rpc_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,27 @@ namespace torch {
namespace distributed {
namespace rpc {

namespace {

py::object getFunction(const py::object& module, const char* name) {
py::object fn = module.attr(name);
TORCH_CHECK(
py::isinstance<py::function>(fn),
"attribute ",
name,
" is not a function");
return fn;
}

} // namespace

PythonRpcHandler::PythonRpcHandler() {
AutoGIL ag;
py::object module =
py::module::import("torch.distributed.internal_rpc_utils");
runUDFFunction_ = module.attr("run_python_udf_internal");
loadResultFunction_ = module.attr("load_python_udf_result_internal");
serializeFunction_ = module.attr("serialize");
runUDFFunction_ = getFunction(module, "run_python_udf_internal");
loadResultFunction_ = getFunction(module, "load_python_udf_result_internal");
serializeFunction_ = getFunction(module, "serialize");
}

PythonRpcHandler& PythonRpcHandler::getInstance() {
Expand All @@ -24,7 +38,6 @@ std::vector<char> PythonRpcHandler::generatePythonUDFResult(
std::vector<torch::Tensor>& responseTensorTable) {
AutoGIL ag;
auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
TORCH_CHECK(runUDFFunction_ != nullptr, "runUDFFunction_ is nullptr");
py::tuple pres =
serializeFunction_(runUDFFunction_(pargs, requestTensorTable));
const auto& presStr = pres[0].cast<std::string>();
Expand All @@ -38,7 +51,6 @@ py::object PythonRpcHandler::loadPythonUDFResult(
const std::vector<torch::Tensor>& tensorTable) {
AutoGIL ag;
auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
TORCH_CHECK(loadResultFunction_ != nullptr, "loadResultFunction_ is nullptr");
return loadResultFunction_(pargs, tensorTable);
}

Expand Down

0 comments on commit c742918

Please sign in to comment.