15
15
* that the following types are not support by PyTorch libary bindings:
16
16
* - `int`
17
17
* - `float`
18
- * - `c10 ::optional<T> &`
19
- * - `c10 ::optional<const at::Tensor> &`
18
+ * - `std ::optional<T> &`
19
+ * - `std ::optional<const at::Tensor> &`
20
20
* So we convert them to (respectively):
21
21
* - `int64_t`
22
22
* - `double`
23
- * - `const c10 ::optional<T>&`
24
- * - `const c10 ::optional<at::Tensor>&`
23
+ * - `const std ::optional<T>&`
24
+ * - `const std ::optional<at::Tensor>&`
25
25
*/
26
26
27
27
template <typename T>
@@ -38,36 +38,36 @@ template<typename T>
38
38
T convert_from_pytorch_compatible_type (pytorch_library_compatible_type_t <T> arg)
39
39
{ return pytorch_library_compatible_type<T>::convert_from_type (arg); }
40
40
41
- // Map `c10 ::optional<T> &` -> `const c10 ::optional<T>&`
41
+ // Map `std ::optional<T> &` -> `const std ::optional<T>&`
42
42
// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate
43
43
// the optional container)
44
44
template <typename T>
45
- struct pytorch_library_compatible_type <c10 ::optional<T> &> {
46
- using type = const c10 ::optional<T>&;
47
- static c10 ::optional<T>& convert_from_type (const c10 ::optional<T> &arg) {
48
- return const_cast <c10 ::optional<T>&>(arg);
45
+ struct pytorch_library_compatible_type <std ::optional<T> &> {
46
+ using type = const std ::optional<T>&;
47
+ static std ::optional<T>& convert_from_type (const std ::optional<T> &arg) {
48
+ return const_cast <std ::optional<T>&>(arg);
49
49
}
50
50
};
51
51
52
- // Map `c10 ::optional<T>` ->
53
- // `c10 ::optional<pytorch_library_compatible_type_t<T>>`
54
- // (NOTE: tested for `c10 ::optional<int>` -> `c10 ::optional<int64_t>`)
52
+ // Map `std ::optional<T>` ->
53
+ // `std ::optional<pytorch_library_compatible_type_t<T>>`
54
+ // (NOTE: tested for `std ::optional<int>` -> `std ::optional<int64_t>`)
55
55
template <typename T>
56
- struct pytorch_library_compatible_type <c10 ::optional<T>> {
57
- using type = c10 ::optional<pytorch_library_compatible_type_t <T>>;
58
- static c10 ::optional<pytorch_library_compatible_type_t <T>> convert_from_type (c10 ::optional<T> arg) {
56
+ struct pytorch_library_compatible_type <std ::optional<T>> {
57
+ using type = std ::optional<pytorch_library_compatible_type_t <T>>;
58
+ static std ::optional<pytorch_library_compatible_type_t <T>> convert_from_type (std ::optional<T> arg) {
59
59
return arg;
60
60
}
61
61
};
62
62
63
- // Map `c10 ::optional<const at::Tensor>&` -> `const c10 ::optional<at::Tensor>&`
63
+ // Map `std ::optional<const at::Tensor>&` -> `const std ::optional<at::Tensor>&`
64
64
template <>
65
- struct pytorch_library_compatible_type <c10 ::optional<const at::Tensor> &> {
66
- using type = const c10 ::optional<at::Tensor>&;
67
- static c10 ::optional<const at::Tensor>& convert_from_type (
68
- const c10 ::optional<at::Tensor> &arg) {
69
- return const_cast <c10 ::optional<const at::Tensor>&>(
70
- reinterpret_cast <const c10 ::optional<const at::Tensor>&>(arg));
65
+ struct pytorch_library_compatible_type <std ::optional<const at::Tensor> &> {
66
+ using type = const std ::optional<at::Tensor>&;
67
+ static std ::optional<const at::Tensor>& convert_from_type (
68
+ const std ::optional<at::Tensor> &arg) {
69
+ return const_cast <std ::optional<const at::Tensor>&>(
70
+ reinterpret_cast <const std ::optional<const at::Tensor>&>(arg));
71
71
}
72
72
};
73
73
0 commit comments