Skip to content

Commit

Permalink
Apply some std::move and param value fixups to aten (pytorch#92901)
Browse files Browse the repository at this point in the history
I noticed a few perf issues in the latest ATen and decided to fixup a few other miscellaneous ones I noticed recently.
Pull Request resolved: pytorch#92901
Approved by: https://github.com/ezyang
  • Loading branch information
Skylion007 authored and pytorchmergebot committed Jan 25, 2023
1 parent b073c09 commit f2f42e5
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 26 deletions.
3 changes: 2 additions & 1 deletion aten/src/ATen/LegacyBatchedTensorImpl.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <bitset>
#include <utility>

#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
Expand Down Expand Up @@ -120,7 +121,7 @@ inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) {
if (!isBatchedTensor(tensor)) {
return nullptr;
}
return unsafeGetBatchedImpl(tensor);
return unsafeGetBatchedImpl(std::move(tensor));
}

// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
Expand Down
4 changes: 3 additions & 1 deletion aten/src/ATen/TensorIndexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

#include <ATen/core/List.h>

#include <utility>

namespace at {
namespace indexing {

Expand Down Expand Up @@ -230,7 +232,7 @@ static inline Tensor applySlice(
return self;
}
}
return self.slice_symint(dim, start, stop, step);
return self.slice_symint(dim, start, stop, std::move(step));
}

static inline Tensor applySelect(
Expand Down
14 changes: 8 additions & 6 deletions aten/src/ATen/ThreadLocalPythonObjects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,29 @@
#include <ATen/ThreadLocalPythonObjects.h>
#include <c10/util/Exception.h>

#include <utility>

namespace at {
namespace impl {

static thread_local ThreadLocalPythonObjects py_objects;


void ThreadLocalPythonObjects::set(std::string key, std::shared_ptr<SafePyObject> value) {
py_objects.obj_dict_[key] = value;
void ThreadLocalPythonObjects::set(const std::string& key, std::shared_ptr<SafePyObject> value) {
py_objects.obj_dict_[key] = std::move(value);
}

const std::shared_ptr<SafePyObject>& ThreadLocalPythonObjects::get(std::string key) {
const std::shared_ptr<SafePyObject>& ThreadLocalPythonObjects::get(const std::string& key) {
TORCH_CHECK(py_objects.obj_dict_.count(key));
return py_objects.obj_dict_[key];
}

bool ThreadLocalPythonObjects::contains(std::string key) {
bool ThreadLocalPythonObjects::contains(const std::string& key) {
return py_objects.obj_dict_.count(key);
}

void ThreadLocalPythonObjects::set_state(const ThreadLocalPythonObjects& state) {
py_objects = state;
void ThreadLocalPythonObjects::set_state(ThreadLocalPythonObjects state) {
py_objects = std::move(state);
}

const ThreadLocalPythonObjects& ThreadLocalPythonObjects::get_state() {
Expand Down
8 changes: 4 additions & 4 deletions aten/src/ATen/ThreadLocalPythonObjects.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ namespace at {
namespace impl {

struct TORCH_API ThreadLocalPythonObjects {
static void set(std::string key, std::shared_ptr<SafePyObject> value);
static const std::shared_ptr<SafePyObject>& get(std::string key);
static bool contains(std::string key);
static void set(const std::string& key, std::shared_ptr<SafePyObject> value);
static const std::shared_ptr<SafePyObject>& get(const std::string& key);
static bool contains(const std::string& key);

static const ThreadLocalPythonObjects& get_state();
static void set_state(const ThreadLocalPythonObjects& state);
static void set_state(ThreadLocalPythonObjects state);

private:
std::unordered_map<std::string, std::shared_ptr<c10::SafePyObject>> obj_dict_;
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
"Skipping setting following error on the Future since "
"it is already marked completed (this is not necessarily "
"an error):\n",
tryRetrieveErrorMessageInternal(eptr));
tryRetrieveErrorMessageInternal(std::move(eptr)));
if (eptr_) {
msg += c10::str(
", \nOriginal exception:\n",
Expand Down Expand Up @@ -1199,7 +1199,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
// Tries to retrieve the error message from std::exception_ptr.
std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const {
try {
std::rethrow_exception(eptr);
std::rethrow_exception(std::move(eptr));
} catch (const std::exception& e) {
return e.what();
} catch (...) {
Expand Down
11 changes: 6 additions & 5 deletions aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <ostream>
#include <sstream>
#include <type_traits>
#include <utility>

namespace torch {
namespace jit {
Expand Down Expand Up @@ -239,7 +240,7 @@ struct TORCH_API OptionalType : public UnionType {

std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
std::stringstream ss;
ss << "Optional[" << getElementType()->annotation_str(printer) << "]";
ss << "Optional[" << getElementType()->annotation_str(std::move(printer)) << "]";
return ss.str();
}
};
Expand Down Expand Up @@ -906,7 +907,7 @@ struct TORCH_API ListType

std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
std::stringstream ss;
ss << "List[" << getElementType()->annotation_str(printer) << "]";
ss << "List[" << getElementType()->annotation_str(std::move(printer)) << "]";
return ss.str();
}
};
Expand Down Expand Up @@ -1001,7 +1002,7 @@ struct TORCH_API DictType : public SharedType {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
std::stringstream ss;
ss << "Dict[" << getKeyType()->annotation_str(printer) << ", "
<< getValueType()->annotation_str(printer) << "]";
<< getValueType()->annotation_str(std::move(printer)) << "]";
return ss.str();
}

Expand Down Expand Up @@ -1046,7 +1047,7 @@ struct TORCH_API FutureType

std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
std::stringstream ss;
ss << "Future[" << getElementType()->annotation_str(printer) << "]";
ss << "Future[" << getElementType()->annotation_str(std::move(printer)) << "]";
return ss.str();
}
};
Expand Down Expand Up @@ -1078,7 +1079,7 @@ struct TORCH_API RRefType

std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
std::stringstream ss;
ss << "RRef[" << getElementType()->annotation_str(printer) << "]";
ss << "RRef[" << getElementType()->annotation_str(std::move(printer)) << "]";
return ss.str();
}
};
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/core/jit_type_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <functional>
#include <memory>
#include <string>
#include <utility>

#include <ATen/core/qualified_name.h>
#include <ATen/core/type_ptr.h>
Expand Down Expand Up @@ -451,7 +452,7 @@ struct TORCH_API Type {
return *renamed;
}
}
return annotation_str_impl(printer);
return annotation_str_impl(std::move(printer));
}
std::string annotation_str() const {
// Overload instead of define a default value for `printer` to help
Expand Down
10 changes: 5 additions & 5 deletions c10/util/ThreadLocalDebugInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ std::shared_ptr<ThreadLocalDebugInfo> ThreadLocalDebugInfo::current() {

/* static */
void ThreadLocalDebugInfo::_forceCurrentDebugInfo(
const std::shared_ptr<ThreadLocalDebugInfo>& info) {
debug_info = info;
std::shared_ptr<ThreadLocalDebugInfo> info) {
debug_info = std::move(info);
}

/* static */
Expand All @@ -39,7 +39,7 @@ void ThreadLocalDebugInfo::_push(
debug_info = std::make_shared<ThreadLocalDebugInfo>();
debug_info->parent_info_ = prev_info;
debug_info->kind_ = kind;
debug_info->info_ = info;
debug_info->info_ = std::move(info);
}

/* static */
Expand Down Expand Up @@ -86,8 +86,8 @@ DebugInfoGuard::DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info) {
if (!info) {
return;
}
prev_info_ = debug_info;
debug_info = info;
prev_info_ = std::move(debug_info);
debug_info = std::move(info);
active_ = true;
}

Expand Down
2 changes: 1 addition & 1 deletion c10/util/ThreadLocalDebugInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class C10_API ThreadLocalDebugInfo {

// Internal, use DebugInfoGuard/ThreadLocalStateGuard
static void _forceCurrentDebugInfo(
const std::shared_ptr<ThreadLocalDebugInfo>& info);
std::shared_ptr<ThreadLocalDebugInfo> info);

// Push debug info struct of a given kind
static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
Expand Down

0 comments on commit f2f42e5

Please sign in to comment.