Skip to content

Commit f2f42e5

Browse files
Skylion007pytorchmergebot
authored andcommitted
Apply some std::move and param value fixups to aten (pytorch#92901)
I noticed a few perf issues in the latest ATen and decided to fixup a few other miscellaneous ones I noticed recently. Pull Request resolved: pytorch#92901 Approved by: https://github.com/ezyang
1 parent b073c09 commit f2f42e5

9 files changed

+33
-26
lines changed

aten/src/ATen/LegacyBatchedTensorImpl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <bitset>
4+
#include <utility>
45

56
#include <ATen/ArrayRef.h>
67
#include <ATen/SmallVector.h>
@@ -120,7 +121,7 @@ inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) {
120121
if (!isBatchedTensor(tensor)) {
121122
return nullptr;
122123
}
123-
return unsafeGetBatchedImpl(tensor);
124+
return unsafeGetBatchedImpl(std::move(tensor));
124125
}
125126

126127
// Returns a bitset. If bit i is set, then that means dim i is a batchdim.

aten/src/ATen/TensorIndexing.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
#include <ATen/core/List.h>
2222

23+
#include <utility>
24+
2325
namespace at {
2426
namespace indexing {
2527

@@ -230,7 +232,7 @@ static inline Tensor applySlice(
230232
return self;
231233
}
232234
}
233-
return self.slice_symint(dim, start, stop, step);
235+
return self.slice_symint(dim, start, stop, std::move(step));
234236
}
235237

236238
static inline Tensor applySelect(

aten/src/ATen/ThreadLocalPythonObjects.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,29 @@
22
#include <ATen/ThreadLocalPythonObjects.h>
33
#include <c10/util/Exception.h>
44

5+
#include <utility>
6+
57
namespace at {
68
namespace impl {
79

810
static thread_local ThreadLocalPythonObjects py_objects;
911

1012

11-
void ThreadLocalPythonObjects::set(std::string key, std::shared_ptr<SafePyObject> value) {
12-
py_objects.obj_dict_[key] = value;
13+
void ThreadLocalPythonObjects::set(const std::string& key, std::shared_ptr<SafePyObject> value) {
14+
py_objects.obj_dict_[key] = std::move(value);
1315
}
1416

15-
const std::shared_ptr<SafePyObject>& ThreadLocalPythonObjects::get(std::string key) {
17+
const std::shared_ptr<SafePyObject>& ThreadLocalPythonObjects::get(const std::string& key) {
1618
TORCH_CHECK(py_objects.obj_dict_.count(key));
1719
return py_objects.obj_dict_[key];
1820
}
1921

20-
bool ThreadLocalPythonObjects::contains(std::string key) {
22+
bool ThreadLocalPythonObjects::contains(const std::string& key) {
2123
return py_objects.obj_dict_.count(key);
2224
}
2325

24-
void ThreadLocalPythonObjects::set_state(const ThreadLocalPythonObjects& state) {
25-
py_objects = state;
26+
void ThreadLocalPythonObjects::set_state(ThreadLocalPythonObjects state) {
27+
py_objects = std::move(state);
2628
}
2729

2830
const ThreadLocalPythonObjects& ThreadLocalPythonObjects::get_state() {

aten/src/ATen/ThreadLocalPythonObjects.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ namespace at {
88
namespace impl {
99

1010
struct TORCH_API ThreadLocalPythonObjects {
11-
static void set(std::string key, std::shared_ptr<SafePyObject> value);
12-
static const std::shared_ptr<SafePyObject>& get(std::string key);
13-
static bool contains(std::string key);
11+
static void set(const std::string& key, std::shared_ptr<SafePyObject> value);
12+
static const std::shared_ptr<SafePyObject>& get(const std::string& key);
13+
static bool contains(const std::string& key);
1414

1515
static const ThreadLocalPythonObjects& get_state();
16-
static void set_state(const ThreadLocalPythonObjects& state);
16+
static void set_state(ThreadLocalPythonObjects state);
1717

1818
private:
1919
std::unordered_map<std::string, std::shared_ptr<c10::SafePyObject>> obj_dict_;

aten/src/ATen/core/ivalue_inl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
944944
"Skipping setting following error on the Future since "
945945
"it is already marked completed (this is not necessarily "
946946
"an error):\n",
947-
tryRetrieveErrorMessageInternal(eptr));
947+
tryRetrieveErrorMessageInternal(std::move(eptr)));
948948
if (eptr_) {
949949
msg += c10::str(
950950
", \nOriginal exception:\n",
@@ -1199,7 +1199,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
11991199
// Tries to retrieve the error message from std::exception_ptr.
12001200
std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const {
12011201
try {
1202-
std::rethrow_exception(eptr);
1202+
std::rethrow_exception(std::move(eptr));
12031203
} catch (const std::exception& e) {
12041204
return e.what();
12051205
} catch (...) {

aten/src/ATen/core/jit_type.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <ostream>
1717
#include <sstream>
1818
#include <type_traits>
19+
#include <utility>
1920

2021
namespace torch {
2122
namespace jit {
@@ -239,7 +240,7 @@ struct TORCH_API OptionalType : public UnionType {
239240

240241
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
241242
std::stringstream ss;
242-
ss << "Optional[" << getElementType()->annotation_str(printer) << "]";
243+
ss << "Optional[" << getElementType()->annotation_str(std::move(printer)) << "]";
243244
return ss.str();
244245
}
245246
};
@@ -906,7 +907,7 @@ struct TORCH_API ListType
906907

907908
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
908909
std::stringstream ss;
909-
ss << "List[" << getElementType()->annotation_str(printer) << "]";
910+
ss << "List[" << getElementType()->annotation_str(std::move(printer)) << "]";
910911
return ss.str();
911912
}
912913
};
@@ -1001,7 +1002,7 @@ struct TORCH_API DictType : public SharedType {
10011002
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
10021003
std::stringstream ss;
10031004
ss << "Dict[" << getKeyType()->annotation_str(printer) << ", "
1004-
<< getValueType()->annotation_str(printer) << "]";
1005+
<< getValueType()->annotation_str(std::move(printer)) << "]";
10051006
return ss.str();
10061007
}
10071008

@@ -1046,7 +1047,7 @@ struct TORCH_API FutureType
10461047

10471048
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
10481049
std::stringstream ss;
1049-
ss << "Future[" << getElementType()->annotation_str(printer) << "]";
1050+
ss << "Future[" << getElementType()->annotation_str(std::move(printer)) << "]";
10501051
return ss.str();
10511052
}
10521053
};
@@ -1078,7 +1079,7 @@ struct TORCH_API RRefType
10781079

10791080
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
10801081
std::stringstream ss;
1081-
ss << "RRef[" << getElementType()->annotation_str(printer) << "]";
1082+
ss << "RRef[" << getElementType()->annotation_str(std::move(printer)) << "]";
10821083
return ss.str();
10831084
}
10841085
};

aten/src/ATen/core/jit_type_base.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <functional>
44
#include <memory>
55
#include <string>
6+
#include <utility>
67

78
#include <ATen/core/qualified_name.h>
89
#include <ATen/core/type_ptr.h>
@@ -451,7 +452,7 @@ struct TORCH_API Type {
451452
return *renamed;
452453
}
453454
}
454-
return annotation_str_impl(printer);
455+
return annotation_str_impl(std::move(printer));
455456
}
456457
std::string annotation_str() const {
457458
// Overload instead of define a default value for `printer` to help

c10/util/ThreadLocalDebugInfo.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ std::shared_ptr<ThreadLocalDebugInfo> ThreadLocalDebugInfo::current() {
2727

2828
/* static */
2929
void ThreadLocalDebugInfo::_forceCurrentDebugInfo(
30-
const std::shared_ptr<ThreadLocalDebugInfo>& info) {
31-
debug_info = info;
30+
std::shared_ptr<ThreadLocalDebugInfo> info) {
31+
debug_info = std::move(info);
3232
}
3333

3434
/* static */
@@ -39,7 +39,7 @@ void ThreadLocalDebugInfo::_push(
3939
debug_info = std::make_shared<ThreadLocalDebugInfo>();
4040
debug_info->parent_info_ = prev_info;
4141
debug_info->kind_ = kind;
42-
debug_info->info_ = info;
42+
debug_info->info_ = std::move(info);
4343
}
4444

4545
/* static */
@@ -86,8 +86,8 @@ DebugInfoGuard::DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info) {
8686
if (!info) {
8787
return;
8888
}
89-
prev_info_ = debug_info;
90-
debug_info = info;
89+
prev_info_ = std::move(debug_info);
90+
debug_info = std::move(info);
9191
active_ = true;
9292
}
9393

c10/util/ThreadLocalDebugInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class C10_API ThreadLocalDebugInfo {
4141

4242
// Internal, use DebugInfoGuard/ThreadLocalStateGuard
4343
static void _forceCurrentDebugInfo(
44-
const std::shared_ptr<ThreadLocalDebugInfo>& info);
44+
std::shared_ptr<ThreadLocalDebugInfo> info);
4545

4646
// Push debug info struct of a given kind
4747
static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);

0 commit comments

Comments
 (0)