From 18eccfbe4246d1ec0342ff320d6f2d82198507e3 Mon Sep 17 00:00:00 2001 From: Meghan Lele Date: Wed, 2 Dec 2020 12:28:09 -0800 Subject: [PATCH] [JIT] Fix clang-tidy warnings in jit/python (#47985) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47985 Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D25258644 Pulled By: SplitInfinity fbshipit-source-id: dfc15dc62c148f79f4e99fd058a6bf2d071ccbb5 --- torch/csrc/jit/python/init.cpp | 30 ++++++------ torch/csrc/jit/python/pybind_utils.h | 45 +++++++++--------- torch/csrc/jit/python/python_arg_flatten.cpp | 2 +- torch/csrc/jit/python/python_ir.cpp | 31 +++++++------ .../csrc/jit/python/python_sugared_value.cpp | 8 ++-- torch/csrc/jit/python/python_sugared_value.h | 11 +++-- torch/csrc/jit/python/python_tracer.cpp | 8 ++-- torch/csrc/jit/python/python_tree_views.cpp | 46 ++++++++++--------- torch/csrc/jit/python/script_init.cpp | 22 ++++----- 9 files changed, 109 insertions(+), 94 deletions(-) diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 8a82f844084266..254ad141e589f8 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -348,7 +348,7 @@ void initJITBindings(PyObject* module) { py::arg("prefix") = "top") .def( "_jit_pass_remove_inplace_ops", - [](std::shared_ptr g) { return RemoveInplaceOps(g); }) + [](const std::shared_ptr& g) { return RemoveInplaceOps(g); }) .def("_jit_pass_constant_pooling", ConstantPooling) .def( "_jit_pass_create_functional_graphs", @@ -378,7 +378,9 @@ void initJITBindings(PyObject* module) { .def("_jit_pass_lint", LintGraph) .def( "_jit_pass_complete_shape_analysis", - [](std::shared_ptr graph, py::tuple inputs, bool with_grad) { + [](const std::shared_ptr& graph, + const py::tuple& inputs, + bool with_grad) { ArgumentSpecCreator arg_spec_creator(*graph); Stack stack; stack.reserve(inputs.size()); // captures? @@ -400,7 +402,7 @@ void initJITBindings(PyObject* module) { }) .def( "_jit_interpret_graph", - [](std::shared_ptr& graph, py::tuple inputs) { + [](std::shared_ptr& graph, const py::tuple& inputs) { Stack stack; stack.reserve(inputs.size()); // captures? for (auto& obj : inputs) { @@ -441,7 +443,9 @@ void initJITBindings(PyObject* module) { .def("_jit_pass_erase_shape_information", EraseShapeInformation) .def( "_jit_pass_create_autodiff_subgraphs", - [](std::shared_ptr graph) { CreateAutodiffSubgraphs(graph); }) + [](const std::shared_ptr& graph) { + CreateAutodiffSubgraphs(graph); + }) #if defined(BUILDING_TESTS) && !defined(__HIP_PLATFORM_HCC__) .def( "_jit_run_cpp_tests", @@ -469,7 +473,7 @@ void initJITBindings(PyObject* module) { }) .def( "_jit_unflatten", - [](autograd::variable_list vars, python::IODescriptor& desc) { + [](const autograd::variable_list& vars, python::IODescriptor& desc) { return py::reinterpret_steal( python::unflatten(vars, desc)); }) @@ -494,8 +498,8 @@ void initJITBindings(PyObject* module) { }) .def( "_jit_check_alias_annotation", - [](std::shared_ptr g, - py::tuple args, + [](const std::shared_ptr& g, + const py::tuple& args, const std::string& unqualified_op_name) { auto stack = toTraceableStack(args); checkAliasAnnotation(g, std::move(stack), unqualified_op_name); @@ -553,7 +557,7 @@ void initJITBindings(PyObject* module) { .def( "_jit_try_infer_type", [](py::object obj) -> TypePtr { - auto match = tryToInferType(obj); + auto match = tryToInferType(std::move(obj)); if (match.success()) { return match.type(); } @@ -647,7 +651,7 @@ void initJITBindings(PyObject* module) { [](std::shared_ptr& g) { return FuseTensorExprs(g); }) .def( "_jit_fuser_get_fused_kernel_code", - [](Graph& g, std::vector inps) { + [](Graph& g, const std::vector& inps) { return debugGetFusedKernelCode(g, inps); }) .def( @@ -945,7 +949,7 @@ void initJITBindings(PyObject* module) { py::class_(m, "PyTorchFileReader") .def(py::init()) .def(py::init([](const py::object& buffer) { - auto adapter = std::make_unique(std::move(buffer)); + auto adapter = std::make_unique(buffer); return std::make_unique(std::move(adapter)); })) .def( @@ -1135,7 +1139,7 @@ void initJITBindings(PyObject* module) { m.def("_jit_get_custom_class_schemas", customClassSchemasForBCCheck); m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) { auto symbol = Symbol::fromQualString(qualified_name); - auto operations = getAllOperatorsFor(symbol); + const auto& operations = getAllOperatorsFor(symbol); return fmap(operations, [](const std::shared_ptr& op) { return op->schema(); }); @@ -1281,8 +1285,8 @@ void initJITBindings(PyObject* module) { }); }); - m.def("_jit_assert_is_instance", [](py::object obj, TypePtr type) { - toIValue(obj, type); + m.def("_jit_assert_is_instance", [](py::object obj, const TypePtr& type) { + toIValue(std::move(obj), type); }); initPythonCustomClassBindings(module); diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index cb6ac0181db930..99b439aa185f93 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -280,8 +280,8 @@ inline TypedIValue toDictKeyIValue(py::handle key) { } inline c10::optional unifyOrInitializeType( - TypePtr accum, - TypePtr unify) { + const TypePtr& accum, + const TypePtr& unify) { if (!accum) { return unify; } @@ -499,7 +499,7 @@ inline InferredType tryToInferContainerType(py::handle input) { } } -inline bool isTraceableType(TypePtr type) { +inline bool isTraceableType(const TypePtr& type) { if (type->isSubtypeOf(TensorType::get())) { return true; } @@ -512,7 +512,9 @@ inline bool isTraceableType(TypePtr type) { return std::all_of( tuple_type->elements().begin(), tuple_type->elements().end(), - [](TypePtr element_type) { return isTraceableType(element_type); }); + [](const TypePtr& element_type) { + return isTraceableType(element_type); + }); } if (auto dict_type = type->cast()) { @@ -545,13 +547,13 @@ inline Stack toTraceableStack(const py::tuple& inputs) { inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) { auto elems = c10::impl::GenericList(elem_type); for (auto elem : obj) { - elems.push_back(toIValue(std::move(elem), elem_type)); + elems.push_back(toIValue(elem, elem_type)); } return IValue(std::move(elems)); } inline IValue createGenericDict( - py::dict obj, + const py::dict& obj, const TypePtr& key_type, const TypePtr& value_type) { c10::impl::GenericDict elems(key_type, value_type); @@ -747,7 +749,7 @@ inline IValue toIValue( "a TorchScript compatible type, did you forget to", "turn it into a user defined TorchScript class?")); } - res = toIValue(std::move(obj), classType); + res = toIValue(obj, classType); } // check if the classType conform with the interface or not std::stringstream why_not; @@ -1074,9 +1076,9 @@ inline Stack createStackForSchema( push(stack, std::move(*self)); } // First push all positional args. - for (size_t i = 0; i < args.size(); ++i) { + for (const auto& arg : args) { // Use the type information from the schema to convert the PyObject. - push(stack, argumentToIValue(schema, stack.size(), args[i])); + push(stack, argumentToIValue(schema, stack.size(), arg)); } // Now for every remaining non-positional argument in the schema, look for it @@ -1153,15 +1155,16 @@ inline Stack evilDeprecatedBadCreateStackDoNotUse( // tracing graph. inline py::object runAndInsertCall( Function& callee, - tuple_slice args, - py::kwargs kwargs, + const tuple_slice& args, + const py::kwargs& kwargs, c10::optional self, // Lambda that tells this function how to insert `callee` into the graph if // we're tracing. - std::function callInserter) { - auto stack = createStackForSchema( - callee.getSchema(), std::move(args), std::move(kwargs), std::move(self)); - auto tracing_state = tracer::getTracingState(); + const std::function& + callInserter) { + auto stack = + createStackForSchema(callee.getSchema(), args, kwargs, std::move(self)); + const auto& tracing_state = tracer::getTracingState(); if (!tracing_state) { pybind11::gil_scoped_release no_gil_guard; // If we're not tracing, just run the callee as normal. @@ -1211,8 +1214,8 @@ inline py::object runAndInsertCall( inline py::object invokeScriptFunctionFromPython( Function& callee, - tuple_slice args, - py::kwargs kwargs) { + const tuple_slice& args, + const py::kwargs& kwargs) { return runAndInsertCall( callee, args, @@ -1225,8 +1228,8 @@ inline py::object invokeScriptFunctionFromPython( inline py::object invokeScriptMethodFromPython( Method& callee, - tuple_slice args, - py::kwargs kwargs) { + const tuple_slice& args, + const py::kwargs& kwargs) { auto self = callee.owner()._ivalue(); return runAndInsertCall( callee.function(), @@ -1241,14 +1244,14 @@ inline py::object invokeScriptMethodFromPython( inline py::object invokeOperatorFromPython( const std::vector>& operations, py::args args, - py::kwargs kwargs) { + const py::kwargs& kwargs) { Stack stack; if (operations.size() == 1) { const Operator& op = *operations.at(0); // Create a stack full of the arguments and keyword arguments. stack = createStackForSchema( - op.schema(), std::move(args), std::move(kwargs), c10::nullopt); + op.schema(), std::move(args), kwargs, c10::nullopt); pybind11::gil_scoped_release no_gil_guard; op.getOperation()(&stack); diff --git a/torch/csrc/jit/python/python_arg_flatten.cpp b/torch/csrc/jit/python/python_arg_flatten.cpp index b854ae14387a3f..adb77eaffba658 100644 --- a/torch/csrc/jit/python/python_arg_flatten.cpp +++ b/torch/csrc/jit/python/python_arg_flatten.cpp @@ -100,7 +100,7 @@ py::object cast_dict(std::vector objs) { py::dict sequence = {}; for (size_t i = 0; i < num_objs; ++i) { py::tuple obj = py::reinterpret_borrow(objs[i]); - sequence[obj[0]] = std::move(obj[1]); + sequence[obj[0]] = obj[1]; } return std::move(sequence); } diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index 20d0e6272a19a4..f5cdea1e7eb44d 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace torch { namespace jit { @@ -216,12 +217,12 @@ void initPythonIRBindings(PyObject* module_) { .def( "dump_alias_db", [](std::shared_ptr g) { - AliasDb db(g); + AliasDb db(std::move(g)); db.dump(); }) .def( "_export_onnx", - [](const std::shared_ptr g, + [](const std::shared_ptr& g, const std::map& initializers, int64_t onnx_opset_version, const std::unordered_map< @@ -282,7 +283,7 @@ void initPythonIRBindings(PyObject* module_) { py::arg("onnx_file_path") = std::string()) .def( "_pretty_print_onnx", - [](const std::shared_ptr g, + [](const std::shared_ptr& g, const std::map& initializers, int64_t onnx_opset_version, bool defer_weight_export, @@ -389,7 +390,7 @@ void initPythonIRBindings(PyObject* module_) { .GS(prependNode) .def( "insertConstant", - [](Graph& g, IValue ival) { return g.insertConstant(ival); }) + [](Graph& g, const IValue& ival) { return g.insertConstant(ival); }) .GS(lint) .GS(insertNode); #undef GS @@ -587,7 +588,7 @@ void initPythonIRBindings(PyObject* module_) { // Tensor (t_) -- manually written to unwrap the variable into a tensor. .def( "t_", - [](Node& n, const char* name, torch::autograd::Variable v) { + [](Node& n, const char* name, const torch::autograd::Variable& v) { AT_ASSERT(!v.requires_grad()); return n.t_(Symbol::attr(name), v); }) @@ -599,7 +600,7 @@ void initPythonIRBindings(PyObject* module_) { "ts_", [](Node& n, const char* name, - std::vector vs) { + const std::vector& vs) { std::vector tensors; tensors.reserve(vs.size()); for (auto& variable : vs) { @@ -621,7 +622,7 @@ void initPythonIRBindings(PyObject* module_) { }) .def( "z_", - [](Node& n, const char* name, at::Tensor v) { + [](Node& n, const char* name, const at::Tensor& v) { return n.t_( Symbol::attr(name), autograd::Variable(v.view({})).set_requires_grad(false)); @@ -729,7 +730,7 @@ void initPythonIRBindings(PyObject* module_) { }) .def( "isSubtypeOf", - [](std::shared_ptr& self, std::shared_ptr other) { + [](std::shared_ptr& self, std::shared_ptr& other) { if (!other) { return false; } @@ -767,8 +768,9 @@ void initPythonIRBindings(PyObject* module_) { .def_static("get", &NoneType::get); py::class_>(m, "TupleType") - .def( - py::init([](std::vector a) { return TupleType::create(a); })) + .def(py::init([](std::vector a) { + return TupleType::create(std::move(a)); + })) .def("elements", [](TupleType& self) { std::vector types; for (const auto& type : self.elements()) { @@ -785,21 +787,22 @@ void initPythonIRBindings(PyObject* module_) { .def("getElementType", &ListType::getElementType); py::class_>(m, "DictType") .def(py::init([](TypePtr key, TypePtr value) { - return DictType::create(key, value); + return DictType::create(std::move(key), std::move(value)); })) .def("getKeyType", &DictType::getKeyType) .def("getValueType", &DictType::getValueType); py::class_>( m, "OptionalType") - .def(py::init([](TypePtr a) { return OptionalType::create(a); })) + .def(py::init( + [](TypePtr a) { return OptionalType::create(std::move(a)); })) .def_static("ofTensor", &OptionalType::ofTensor) .def("getElementType", &OptionalType::getElementType); py::class_>(m, "RRefType") - .def(py::init([](TypePtr a) { return RRefType::create(a); })) + .def(py::init([](TypePtr a) { return RRefType::create(std::move(a)); })) .def("getElementType", &RRefType::getElementType); py::class_>(m, "FutureType") - .def(py::init([](TypePtr a) { return FutureType::create(a); })) + .def(py::init([](TypePtr a) { return FutureType::create(std::move(a)); })) .def("getElementType", &FutureType::getElementType); py::class_>(m, "ClassType") diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 401cec6f695dd8..933d3bb1a8677e 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -92,12 +92,12 @@ FunctionSchema PythonValue::getSchema( auto types_it = arg_types.begin(); for (; types_it != arg_types.end(); ++types_it, ++names_it) { - args.push_back(Argument( + args.emplace_back( /*name=*/*names_it, /*type=*/std::move(*types_it), /*N=*/c10::nullopt, /*default_value=*/c10::nullopt, - /*kwarg_only=*/false)); + /*kwarg_only=*/false); } rets.push_back(Argument("0", std::move(ret_type), {}, {}, false)); } @@ -293,7 +293,7 @@ SugaredValuePtr ModuleValue::getitem( void checkInterface( const SourceRange& loc, Function& m, - std::shared_ptr self, + const std::shared_ptr& self, const std::string& field) { if (self->asValue(loc, m)->type()->cast()) { throw ErrorReport(loc) @@ -307,7 +307,7 @@ void recurseThroughNestedModules( Function& m, std::vector& keys, std::vector& values, - std::shared_ptr self, + std::shared_ptr& self, const std::string& prefix, const std::string& field) { auto prefix_value = diff --git a/torch/csrc/jit/python/python_sugared_value.h b/torch/csrc/jit/python/python_sugared_value.h index 12a5d87b063ed8..b5d8f4490b3e06 100644 --- a/torch/csrc/jit/python/python_sugared_value.h +++ b/torch/csrc/jit/python/python_sugared_value.h @@ -7,6 +7,7 @@ #include #include #include +#include #include namespace torch { @@ -110,8 +111,8 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { }; struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue { - explicit ModuleDictMethod(SugaredValuePtr iterable, const std::string& name) - : iterable_(iterable), name_(name){}; + explicit ModuleDictMethod(SugaredValuePtr iterable, std::string name) + : iterable_(std::move(iterable)), name_(std::move(name)){}; std::string kind() const override { return name_; @@ -217,7 +218,7 @@ void recurseThroughNestedModules( Function& m, std::vector& keys, std::vector& values, - std::shared_ptr self, + std::shared_ptr& self, const std::string& prefix, const std::string& field); @@ -249,7 +250,7 @@ struct VISIBILITY_HIDDEN SugaredDict : public SugaredValue { Function& m, const std::string& field) override; - SugaredValuePtr iter(const SourceRange& loc, Function& m) { + SugaredValuePtr iter(const SourceRange& loc, Function& m) override { return keys_; }; @@ -316,7 +317,7 @@ struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue { // Python Slice class. struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue { - explicit PythonSliceClass() {} + explicit PythonSliceClass() = default; std::string kind() const override { return "Python slice class"; diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 43a11f75a76851..550ba12a46c081 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -176,7 +176,9 @@ void initPythonTracerBindings(PyObject* module) { }) .def( "set_graph", - [](TracingState& s, std::shared_ptr g) { s.graph = g; }) + [](TracingState& s, std::shared_ptr g) { + s.graph = std::move(g); + }) .def("graph", [](TracingState& s) { return s.graph; }); m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); }); @@ -191,7 +193,7 @@ void initPythonTracerBindings(PyObject* module) { py::arg("self") = nullptr); m.def("_get_tracing_state", []() { return getTracingState(); }); m.def("_set_tracing_state", [](std::shared_ptr state) { - return setTracingState(state); + return setTracingState(std::move(state)); }); m.def("_get_value_trace", [](const Variable& var) { return getValueTrace(var); @@ -199,7 +201,7 @@ void initPythonTracerBindings(PyObject* module) { m.def("_set_value_trace", [](const Variable& var, Value* value) { return setValueTrace(var, value); }); - m.def("_tracer_set_get_unique_name_fn", [](py::function func) { + m.def("_tracer_set_get_unique_name_fn", [](const py::function& func) { const auto& tracing_state = getTracingState(); AT_ASSERT(tracing_state); tracing_state->lookup_var_name_fn = diff --git a/torch/csrc/jit/python/python_tree_views.cpp b/torch/csrc/jit/python/python_tree_views.cpp index 49b8f7d4f4afeb..1355352c827890 100644 --- a/torch/csrc/jit/python/python_tree_views.cpp +++ b/torch/csrc/jit/python/python_tree_views.cpp @@ -25,7 +25,7 @@ c10::optional maybeConvertToString(const py::object& obj) { struct SourceRangeFactory { SourceRangeFactory( std::string text, - py::object filename, + const py::object& filename, size_t file_lineno, size_t leading_whitespace_chars) : source_(std::make_shared( @@ -200,7 +200,7 @@ void initTreeViewBindings(PyObject* module) { r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type)); })); - py::class_(m, "Delete").def(py::init([](Expr expr) { + py::class_(m, "Delete").def(py::init([](const Expr& expr) { return Delete::create(expr); })); @@ -227,12 +227,13 @@ void initTreeViewBindings(PyObject* module) { wrap_maybe(li.range(), type)); })); py::class_(m, "AugAssign") - .def(py::init([](const Expr& lhs, std::string kind_str, const Expr& rhs) { - const auto& r = lhs.range(); - auto kind = - AugAssignKind(Compound::create(stringToKind(kind_str), r, {})); - return AugAssign::create(r, lhs, kind, rhs); - })); + .def(py::init( + [](const Expr& lhs, const std::string& kind_str, const Expr& rhs) { + const auto& r = lhs.range(); + auto kind = + AugAssignKind(Compound::create(stringToKind(kind_str), r, {})); + return AugAssign::create(r, lhs, kind, rhs); + })); py::class_(m, "Return") .def(py::init([](const SourceRange& range, Expr* value) { return Return::create( @@ -282,7 +283,7 @@ void initTreeViewBindings(PyObject* module) { wrap_list(range, std::move(targets)), wrap_list(range, std::move(body))); })); - py::class_(m, "For").def(py::init([](const SourceRange range, + py::class_(m, "For").def(py::init([](const SourceRange& range, std::vector& targets, std::vector& itrs, std::vector body) { @@ -301,25 +302,26 @@ void initTreeViewBindings(PyObject* module) { [](const Ident& name) { return Var::create(name.range(), name); })) .def_property_readonly("name", [](const Var& var) { return var.name(); }); py::class_(m, "BinOp") - .def(py::init([](std::string kind, const Expr& lhs, const Expr& rhs) { - return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs); - })); + .def(py::init( + [](const std::string& kind, const Expr& lhs, const Expr& rhs) { + return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs); + })); // NB: we take range here, because unary ops precede their exprs, so we need // to include them py::class_(m, "UnaryOp") - .def(py::init( - [](const SourceRange& range, std::string kind, const Expr& expr) { - auto resolved_kind = stringToKind(kind); - resolved_kind = - resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind; - return UnaryOp::create(range, resolved_kind, expr); - })); + .def(py::init([](const SourceRange& range, + const std::string& kind, + const Expr& expr) { + auto resolved_kind = stringToKind(kind); + resolved_kind = resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind; + return UnaryOp::create(range, resolved_kind, expr); + })); py::class_(m, "Const") - .def(py::init([](const SourceRange& range, std::string value) { + .def(py::init([](const SourceRange& range, const std::string& value) { return Const::create(range, value); })); py::class_(m, "StringLiteral") - .def(py::init([](const SourceRange& range, std::string value) { + .def(py::init([](const SourceRange& range, const std::string& value) { return StringLiteral::create(range, value); })); py::class_(m, "Apply") @@ -383,7 +385,7 @@ void initTreeViewBindings(PyObject* module) { wrap_maybe(range, step)); })); py::class_(m, "Starred") - .def(py::init([](const SourceRange& range, Expr expr) { + .def(py::init([](const SourceRange& range, const Expr& expr) { return Starred::create(range, expr); })); py::class_, TreeView>(m, "EmptyTypeAnnotation") diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 7c571384e481be..feab73df6d1b92 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -339,7 +339,7 @@ static StrongFunctionPtr script_compile_overloaded_function( const c10::QualifiedName& name, const Decl& overload_decl, const Def& implementation_def, - ResolutionCallback rcb, + const ResolutionCallback& rcb, const FunctionDefaults& implementation_defaults, const py::object& signature) { if (signature.is(py::none())) { @@ -356,7 +356,7 @@ static StrongFunctionPtr script_compile_overloaded_function( /*properties=*/{}, /*propResolvers=*/{}, {new_def}, - {pythonResolver(std::move(rcb))}, + {pythonResolver(rcb)}, nullptr, true); TORCH_INTERNAL_ASSERT(defined_functions.size() == 1); @@ -377,14 +377,14 @@ static StrongFunctionPtr script_compile_function( const c10::QualifiedName& name, const Def& def, const FunctionDefaults& defaults, - ResolutionCallback rcb) { + const ResolutionCallback& rcb) { auto cu = get_python_cu(); auto defined_functions = cu->define( QualifiedName(name.prefix()), /*properties=*/{}, /*propResolvers=*/{}, {def}, - {pythonResolver(std::move(rcb))}, + {pythonResolver(rcb)}, nullptr, true); TORCH_INTERNAL_ASSERT(defined_functions.size() == 1); @@ -1243,19 +1243,19 @@ void initJitScriptBindings(PyObject* module) { "_jit_script_compile", [](const std::string& qualname, const Def& def, - ResolutionCallback rcb, + const ResolutionCallback& rcb, const FunctionDefaults& defaults) { C10_LOG_API_USAGE_ONCE("torch.script.compile"); const auto name = c10::QualifiedName(qualname); TORCH_INTERNAL_ASSERT(name.name() == def.name().name()); - return script_compile_function(name, def, defaults, std::move(rcb)); + return script_compile_function(name, def, defaults, rcb); }); m.def( "_jit_script_compile_overload", [](const std::string& qualname, const Decl& overload_decl, const Def& implementation_def, - ResolutionCallback rcb, + const ResolutionCallback& rcb, const FunctionDefaults& implementation_defaults, const py::object& signature) { const auto name = c10::QualifiedName(qualname); @@ -1263,7 +1263,7 @@ void initJitScriptBindings(PyObject* module) { name, overload_decl, implementation_def, - std::move(rcb), + rcb, implementation_defaults, signature); }); @@ -1368,12 +1368,12 @@ void initJitScriptBindings(PyObject* module) { "_jit_script_interface_compile", [](const std::string& qualifiedName, const ClassDef& classDef, - ResolutionCallback rcb, + const ResolutionCallback& rcb, bool is_module) { get_python_cu()->define_interface( c10::QualifiedName(qualifiedName), classDef, - pythonResolver(std::move(rcb)), + pythonResolver(rcb), is_module); }); @@ -1475,7 +1475,7 @@ void initJitScriptBindings(PyObject* module) { // TODO this should go in the global Python CU auto cu = std::make_shared(); c10::QualifiedName name(qualname); - auto fn = cu->create_function(std::move(name), graph); + auto fn = cu->create_function(std::move(name), std::move(graph)); return StrongFunctionPtr(std::move(cu), fn); }); m.def("_ivalue_tags_match", ivalue_tags_match);