forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDefaultTensorOptions.h
45 lines (36 loc) · 1.04 KB
/
DefaultTensorOptions.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#pragma once
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Layout.h>
#include <c10/core/ScalarType.h>
#include <c10/util/typeid.h>
namespace c10 {
struct TensorOptions;
/// Fully-populated counterpart to TensorOptions: unlike TensorOptions, every
/// field here always holds a concrete value.
///
/// A default-constructed instance carries: dtype = float, device = CPU,
/// layout = strided, requires_grad = false.
struct DefaultTensorOptions {
  DefaultTensorOptions() = default;

  /// Element type, as a caffe2::TypeMeta.
  caffe2::TypeMeta dtype() const noexcept { return dtype_; }

  /// Target device.
  Device device() const noexcept { return device_; }

  /// Tensor memory layout.
  Layout layout() const noexcept { return layout_; }

  /// Whether requires_grad is set.
  bool requires_grad() const noexcept { return requires_grad_; }

  /// Merges `options` into this object.
  /// Only declared here because TensorOptions is merely forward-declared in
  /// this header; the definition lives in TensorOptions.h.
  /// NOTE(review): presumably overwrites only the fields `options` sets
  /// explicitly and returns *this — confirm against the definition.
  inline DefaultTensorOptions& merge(const TensorOptions& options);

 private:
  // Members are declared widest-first (see the size notes) — keep this order
  // to avoid extra padding. NOTE(review): merge() in TensorOptions.h likely
  // touches these members by name, so keep the names stable too.
  caffe2::TypeMeta dtype_{caffe2::TypeMeta::Make<float>()}; // 64-bit
  Device device_{at::kCPU}; // 32-bit
  Layout layout_{at::kStrided}; // 8-bit
  bool requires_grad_{false}; // 8-bit
};
/// Returns a reference to a single immutable, default-constructed
/// DefaultTensorOptions instance.
///
/// The instance is a function-local static inside an inline function, so
/// every caller in the program observes the same object. C++11 guarantees its
/// one-time initialization is thread-safe.
inline const DefaultTensorOptions& getDefaultTensorOptions() {
  static const DefaultTensorOptions kDefaultOptions{};
  return kDefaultOptions;
}
} // namespace c10