Skip to content

Commit 75302d3

Browse files
yeqcharlotte and LucasWilkinson
authored and committed
c10::optional -> std::optional (#58)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
1 parent dc9d410 commit 75302d3

File tree

3 files changed

+30
-30
lines changed

3 files changed

+30
-30
lines changed

csrc/common/pytorch_shim.h

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
* that the following types are not support by PyTorch libary bindings:
1616
* - `int`
1717
* - `float`
18-
* - `c10::optional<T> &`
19-
* - `c10::optional<const at::Tensor> &`
18+
* - `std::optional<T> &`
19+
* - `std::optional<const at::Tensor> &`
2020
* So we convert them to (respectively):
2121
* - `int64_t`
2222
* - `double`
23-
* - `const c10::optional<T>&`
24-
* - `const c10::optional<at::Tensor>&`
23+
* - `const std::optional<T>&`
24+
* - `const std::optional<at::Tensor>&`
2525
*/
2626

2727
template<typename T>
@@ -38,36 +38,36 @@ template<typename T>
3838
T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg)
3939
{ return pytorch_library_compatible_type<T>::convert_from_type(arg); }
4040

41-
// Map `c10::optional<T> &` -> `const c10::optional<T>&`
41+
// Map `std::optional<T> &` -> `const std::optional<T>&`
4242
// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate
4343
// the optional container)
4444
template<typename T>
45-
struct pytorch_library_compatible_type<c10::optional<T> &> {
46-
using type = const c10::optional<T>&;
47-
static c10::optional<T>& convert_from_type(const c10::optional<T> &arg) {
48-
return const_cast<c10::optional<T>&>(arg);
45+
struct pytorch_library_compatible_type<std::optional<T> &> {
46+
using type = const std::optional<T>&;
47+
static std::optional<T>& convert_from_type(const std::optional<T> &arg) {
48+
return const_cast<std::optional<T>&>(arg);
4949
}
5050
};
5151

52-
// Map `c10::optional<T>` ->
53-
// `c10::optional<pytorch_library_compatible_type_t<T>>`
54-
// (NOTE: tested for `c10::optional<int>` -> `c10::optional<int64_t>`)
52+
// Map `std::optional<T>` ->
53+
// `std::optional<pytorch_library_compatible_type_t<T>>`
54+
// (NOTE: tested for `std::optional<int>` -> `std::optional<int64_t>`)
5555
template<typename T>
56-
struct pytorch_library_compatible_type<c10::optional<T>> {
57-
using type = c10::optional<pytorch_library_compatible_type_t<T>>;
58-
static c10::optional<pytorch_library_compatible_type_t<T>> convert_from_type(c10::optional<T> arg) {
56+
struct pytorch_library_compatible_type<std::optional<T>> {
57+
using type = std::optional<pytorch_library_compatible_type_t<T>>;
58+
static std::optional<pytorch_library_compatible_type_t<T>> convert_from_type(std::optional<T> arg) {
5959
return arg;
6060
}
6161
};
6262

63-
// Map `c10::optional<const at::Tensor>&` -> `const c10::optional<at::Tensor>&`
63+
// Map `std::optional<const at::Tensor>&` -> `const std::optional<at::Tensor>&`
6464
template<>
65-
struct pytorch_library_compatible_type<c10::optional<const at::Tensor> &> {
66-
using type = const c10::optional<at::Tensor>&;
67-
static c10::optional<const at::Tensor>& convert_from_type(
68-
const c10::optional<at::Tensor> &arg) {
69-
return const_cast<c10::optional<const at::Tensor>&>(
70-
reinterpret_cast<const c10::optional<const at::Tensor>&>(arg));
65+
struct pytorch_library_compatible_type<std::optional<const at::Tensor> &> {
66+
using type = const std::optional<at::Tensor>&;
67+
static std::optional<const at::Tensor>& convert_from_type(
68+
const std::optional<at::Tensor> &arg) {
69+
return const_cast<std::optional<const at::Tensor>&>(
70+
reinterpret_cast<const std::optional<const at::Tensor>&>(arg));
7171
}
7272
};
7373

csrc/flash_attn/flash_api_sparse.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,11 +318,11 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
318318
const at::Tensor &block_offset,
319319
const at::Tensor &column_count,
320320
const at::Tensor &column_index,
321-
const c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
321+
const std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
322322
const at::Tensor &cu_seqlens_q, // b+1
323323
const at::Tensor &cu_seqlens_k, // b+1
324-
const c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
325-
const c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
324+
const std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
325+
const std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
326326
int64_t max_seqlen_q,
327327
const int64_t max_seqlen_k,
328328
const double p_dropout,
@@ -331,7 +331,7 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
331331
bool is_causal,
332332
const double softcap,
333333
const bool return_softmax,
334-
c10::optional<at::Generator> gen_) {
334+
std::optional<at::Generator> gen_) {
335335

336336
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
337337
bool is_sm8x = cc_major == 8 && cc_minor >= 0;

csrc/flash_attn/flash_api_torch_lib.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
8686
const at::Tensor &block_offset,
8787
const at::Tensor &column_count,
8888
const at::Tensor &column_index,
89-
const c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
89+
const std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
9090
const at::Tensor &cu_seqlens_q, // b+1
9191
const at::Tensor &cu_seqlens_k, // b+1
92-
const c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
93-
const c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
92+
const std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
93+
const std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
9494
int64_t max_seqlen_q,
9595
const int64_t max_seqlen_k,
9696
const double p_dropout,
@@ -99,7 +99,7 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
9999
bool is_causal,
100100
const double softcap,
101101
const bool return_softmax,
102-
c10::optional<at::Generator> gen_);
102+
std::optional<at::Generator> gen_);
103103

104104
/**
105105
* Torch Library Registration

0 commit comments

Comments (0)