forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathLegacyTHFunctions.cpp
44 lines (36 loc) · 984 Bytes
/
LegacyTHFunctions.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <ATen/LegacyTHFunctions${Backend}.h>
// ${generated_comment}
#include <ATen/ATen.h>
#include <ATen/Utils.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/${Generator}.h>
#include <ATen/ExpandUtils.h>
#include <ATen/core/EnableNamedTensor.h>
${th_headers}
${extra_cuda_headers}
namespace at {
namespace native {
namespace legacy {
namespace ${namespace} {

// File-local helpers, visible only within this translation unit. They sit
// right before ${legacy_th_definitions}; presumably the generated definitions
// call them (the generated code is not visible here -- TODO confirm).
namespace {

// Returns the scalar type (dtype) of a single tensor.
ScalarType infer_scalar_type(const Tensor & t) {
  return t.scalar_type();
}

// Returns the scalar type (dtype) of the first tensor in `tl`.
// `tl` must be non-empty: an empty list fails the TORCH_CHECK below with
// "expected a non-empty list of Tensors" instead of indexing out of range.
ScalarType infer_scalar_type(const TensorList & tl) {
  TORCH_CHECK(!tl.empty(), "expected a non-empty list of Tensors");
  return tl[0].scalar_type();
}

// Builds TensorOptions with dtype `s`, the device type substituted into this
// template instantiation (${DeviceType}), and a strided layout.
TensorOptions options(ScalarType s) {
  return TensorOptions().dtype(s)
                        .device(DeviceType::${DeviceType})
                        .layout(kStrided);
}

// Returns the allocator expression substituted into this template
// instantiation (${allocator}).
Allocator* allocator() {
  return ${allocator};
}

} // anonymous namespace

${legacy_th_definitions}

// The comment below uses the same placeholder as the opening line, so each
// generated file names its own namespace rather than a hard-coded one.
} // namespace ${namespace}
} // namespace legacy
} // namespace native
} // namespace at