Skip to content

Commit e49dd1f

Browse files
houseroadIsotr0py
authored andcommitted
[MISC] Replace c10::optional with std::optional (vllm-project#11730)
Signed-off-by: Lu Fang <lufang@fb.com> Signed-off-by: Isotr0py <2037008807@qq.com>
1 parent a9f5e7d commit e49dd1f

24 files changed

+136
-136
lines changed

csrc/attention/paged_attention_v1.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void paged_attention_v1_launcher(
5353
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
5454
torch::Tensor& value_cache, int num_kv_heads, float scale,
5555
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
56-
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
56+
const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
5757
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
5858
const int blocksparse_vert_stride, const int blocksparse_block_size,
5959
const int blocksparse_head_sliding_step) {
@@ -176,7 +176,7 @@ void paged_attention_v1(
176176
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
177177
torch::Tensor& seq_lens, // [num_seqs]
178178
int64_t block_size, int64_t max_seq_len,
179-
const c10::optional<torch::Tensor>& alibi_slopes,
179+
const std::optional<torch::Tensor>& alibi_slopes,
180180
const std::string& kv_cache_dtype, double k_scale, double v_scale,
181181
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
182182
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,

csrc/attention/paged_attention_v2.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ void paged_attention_v2_launcher(
5454
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
5555
torch::Tensor& value_cache, int num_kv_heads, float scale,
5656
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
57-
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
57+
const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
5858
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
5959
const int blocksparse_vert_stride, const int blocksparse_block_size,
6060
const int blocksparse_head_sliding_step) {
@@ -187,7 +187,7 @@ void paged_attention_v2(
187187
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
188188
torch::Tensor& seq_lens, // [num_seqs]
189189
int64_t block_size, int64_t max_seq_len,
190-
const c10::optional<torch::Tensor>& alibi_slopes,
190+
const std::optional<torch::Tensor>& alibi_slopes,
191191
const std::string& kv_cache_dtype, double k_scale, double v_scale,
192192
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
193193
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,

csrc/cpu/attention.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ void paged_attention_v1_impl_launcher(
386386
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
387387
torch::Tensor& value_cache, int num_kv_heads, float scale,
388388
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
389-
const c10::optional<torch::Tensor>& alibi_slopes) {
389+
const std::optional<torch::Tensor>& alibi_slopes) {
390390
int num_seqs = query.size(0);
391391
int num_heads = query.size(1);
392392
int head_size = query.size(2);
@@ -459,7 +459,7 @@ void paged_attention_v1(
459459
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
460460
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
461461
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
462-
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
462+
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
463463
const std::string& kv_cache_dtype, double k_scale, double v_scale,
464464
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
465465
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
@@ -702,7 +702,7 @@ void paged_attention_v2_impl_launcher(
702702
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
703703
torch::Tensor& value_cache, int num_kv_heads, float scale,
704704
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
705-
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
705+
int max_seq_len, const std::optional<torch::Tensor>& alibi_slopes) {
706706
int num_seqs = query.size(0);
707707
int num_heads = query.size(1);
708708
int head_size = query.size(2);
@@ -781,7 +781,7 @@ void paged_attention_v2(
781781
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
782782
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
783783
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
784-
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
784+
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
785785
const std::string& kv_cache_dtype, double k_scale, double v_scale,
786786
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
787787
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,

csrc/cpu/quant.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
359359
const torch::Tensor& b, // [IC, OC], column-major
360360
const torch::Tensor& a_scales, // [1] or [M]
361361
const torch::Tensor& b_scales, // [1] or [OC]
362-
const c10::optional<torch::Tensor>& bias // [OC]
362+
const std::optional<torch::Tensor>& bias // [OC]
363363
) {
364364
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
365365
// Checks for conformality
@@ -442,8 +442,8 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
442442
const torch::Tensor& a_scales, // [1] or [M]
443443
const torch::Tensor& b_scales, // [1] or [OC]
444444
const torch::Tensor& azp_adj, // [OC]
445-
const c10::optional<torch::Tensor>& azp, // [1] or [M]
446-
const c10::optional<torch::Tensor>& bias // [OC]
445+
const std::optional<torch::Tensor>& azp, // [1] or [M]
446+
const std::optional<torch::Tensor>& bias // [OC]
447447
) {
448448
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp)
449449
// Checks for conformality
@@ -561,7 +561,7 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
561561
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
562562
const torch::Tensor& input, // [..., hidden_size]
563563
const torch::Tensor& scale,
564-
c10::optional<torch::Tensor> const& azp) {
564+
std::optional<torch::Tensor> const& azp) {
565565
CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
566566
TORCH_CHECK(input.is_contiguous());
567567
TORCH_CHECK(out.is_contiguous());
@@ -590,7 +590,7 @@ void dynamic_scaled_int8_quant(
590590
torch::Tensor& out, // [..., hidden_size]
591591
const torch::Tensor& input, // [..., hidden_size]
592592
torch::Tensor& scale, // [..., 1]
593-
c10::optional<torch::Tensor> const& azp) {
593+
std::optional<torch::Tensor> const& azp) {
594594
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
595595
TORCH_CHECK(input.is_contiguous());
596596
TORCH_CHECK(out.is_contiguous());

csrc/cpu/torch_bindings.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ std::string init_cpu_threads_env(const std::string& cpu_ids);
99
void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
1010
const torch::Tensor& b, const torch::Tensor& a_scales,
1111
const torch::Tensor& b_scales,
12-
const c10::optional<torch::Tensor>& bias);
12+
const std::optional<torch::Tensor>& bias);
1313

1414
void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
1515
const torch::Tensor& b, const torch::Tensor& a_scales,
1616
const torch::Tensor& b_scales,
1717
const torch::Tensor& azp_adj,
18-
const c10::optional<torch::Tensor>& azp,
19-
const c10::optional<torch::Tensor>& bias);
18+
const std::optional<torch::Tensor>& azp,
19+
const std::optional<torch::Tensor>& bias);
2020

2121
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
2222
// vLLM custom ops

csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ struct ScaledEpilogueBase {
6868
// This overload handles the case where there might not be a tensor, in which
6969
// case a nullptr is passed and a constant (0) is used.
7070
template <typename Descriptor, typename T>
71-
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
71+
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
7272
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
7373
using Arguments = typename Descriptor::Arguments;
7474
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
@@ -223,7 +223,7 @@ struct ScaledEpilogueBiasAzp
223223
static ArgumentType prepare_args(torch::Tensor const& a_scales,
224224
torch::Tensor const& b_scales,
225225
torch::Tensor const& azp_adj,
226-
c10::optional<torch::Tensor> const& bias) {
226+
std::optional<torch::Tensor> const& bias) {
227227
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
228228
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
229229
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -301,7 +301,7 @@ struct ScaledEpilogueBiasAzpToken
301301
torch::Tensor const& b_scales,
302302
torch::Tensor const& azp_adj,
303303
torch::Tensor const& azp,
304-
c10::optional<torch::Tensor> const& bias) {
304+
std::optional<torch::Tensor> const& bias) {
305305
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
306306
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
307307
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ struct ScaledEpilogueBase {
6767
// This overload handles the case where there might not be a tensor, in which
6868
// case a nullptr is passed and a constant (0) is used.
6969
template <typename Descriptor, typename T>
70-
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
70+
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
7171
using Arguments = typename Descriptor::Arguments;
7272
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
7373
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
@@ -223,7 +223,7 @@ struct ScaledEpilogueBiasAzp
223223
static ArgumentType prepare_args(torch::Tensor const& a_scales,
224224
torch::Tensor const& b_scales,
225225
torch::Tensor const& azp_adj,
226-
c10::optional<torch::Tensor> const& bias) {
226+
std::optional<torch::Tensor> const& bias) {
227227
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
228228
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
229229
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -299,7 +299,7 @@ struct ScaledEpilogueBiasAzpToken
299299
torch::Tensor const& b_scales,
300300
torch::Tensor const& azp_adj,
301301
torch::Tensor const& azp,
302-
c10::optional<torch::Tensor> const& bias) {
302+
std::optional<torch::Tensor> const& bias) {
303303
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
304304
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
305305
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

csrc/cutlass_extensions/torch_utils.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
9797

9898
template <typename Stride>
9999
static inline auto maybe_make_cute_layout(
100-
c10::optional<torch::Tensor> const& tensor,
100+
std::optional<torch::Tensor> const& tensor,
101101
std::string_view name = "tensor") {
102102
using Layout = decltype(make_cute_layout<Stride>(*tensor));
103103

csrc/mamba/causal_conv1d/causal_conv1d.cu

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ void set_conv_params_fwd(ConvParamsBase &params,
5353
const at::Tensor x,
5454
const at::Tensor weight,
5555
const at::Tensor out,
56-
const c10::optional<at::Tensor>& bias,
56+
const std::optional<at::Tensor>& bias,
5757
bool silu_activation,
5858
int64_t pad_slot_id,
59-
const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
60-
const c10::optional<at::Tensor>& cache_indices = std::nullopt,
61-
const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
59+
const std::optional<at::Tensor>& query_start_loc = std::nullopt,
60+
const std::optional<at::Tensor>& cache_indices = std::nullopt,
61+
const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
6262

6363
// Reset the parameters
6464
memset(&params, 0, sizeof(params));
@@ -93,11 +93,11 @@ void set_conv_params_fwd(ConvParamsBase &params,
9393

9494

9595
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
96-
const c10::optional<at::Tensor> &bias_,
97-
const c10::optional<at::Tensor> &conv_states,
98-
const c10::optional<at::Tensor> &query_start_loc,
99-
const c10::optional<at::Tensor> &cache_indices,
100-
const c10::optional<at::Tensor> &has_initial_state,
96+
const std::optional<at::Tensor> &bias_,
97+
const std::optional<at::Tensor> &conv_states,
98+
const std::optional<at::Tensor> &query_start_loc,
99+
const std::optional<at::Tensor> &cache_indices,
100+
const std::optional<at::Tensor> &has_initial_state,
101101
bool silu_activation,
102102
// used to identify padding entries if cache_indices provided
103103
// in case of padding, the kernel will return early
@@ -194,10 +194,10 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
194194
void causal_conv1d_update(const at::Tensor &x,
195195
const at::Tensor &conv_state,
196196
const at::Tensor &weight,
197-
const c10::optional<at::Tensor> &bias_,
197+
const std::optional<at::Tensor> &bias_,
198198
bool silu_activation,
199-
const c10::optional<at::Tensor> &cache_seqlens_,
200-
const c10::optional<at::Tensor> &conv_state_indices_,
199+
const std::optional<at::Tensor> &cache_seqlens_,
200+
const std::optional<at::Tensor> &conv_state_indices_,
201201
// used to identify padding entries if cache_indices provided
202202
// in case of padding, the kernel will return early
203203
int64_t pad_slot_id) {

csrc/mamba/mamba_ssm/selective_scan_fwd.cu

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -402,14 +402,14 @@ void set_ssm_params_fwd(SSMParamsBase &params,
402402
const torch::Tensor out,
403403
const torch::Tensor z,
404404
const torch::Tensor out_z,
405-
const c10::optional<at::Tensor>& D,
406-
const c10::optional<at::Tensor>& delta_bias,
405+
const std::optional<at::Tensor>& D,
406+
const std::optional<at::Tensor>& delta_bias,
407407
const torch::Tensor ssm_states,
408408
bool has_z,
409409
bool delta_softplus,
410-
const c10::optional<at::Tensor>& query_start_loc,
411-
const c10::optional<at::Tensor>& cache_indices,
412-
const c10::optional<at::Tensor>& has_initial_state,
410+
const std::optional<at::Tensor>& query_start_loc,
411+
const std::optional<at::Tensor>& cache_indices,
412+
const std::optional<at::Tensor>& has_initial_state,
413413
bool varlen,
414414
int64_t pad_slot_id) {
415415

@@ -504,13 +504,13 @@ void set_ssm_params_fwd(SSMParamsBase &params,
504504

505505
void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
506506
const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
507-
const c10::optional<torch::Tensor> &D_,
508-
const c10::optional<torch::Tensor> &z_,
509-
const c10::optional<torch::Tensor> &delta_bias_,
507+
const std::optional<torch::Tensor> &D_,
508+
const std::optional<torch::Tensor> &z_,
509+
const std::optional<torch::Tensor> &delta_bias_,
510510
bool delta_softplus,
511-
const c10::optional<torch::Tensor> &query_start_loc,
512-
const c10::optional<torch::Tensor> &cache_indices,
513-
const c10::optional<torch::Tensor> &has_initial_state,
511+
const std::optional<torch::Tensor> &query_start_loc,
512+
const std::optional<torch::Tensor> &cache_indices,
513+
const std::optional<torch::Tensor> &has_initial_state,
514514
const torch::Tensor &ssm_states,
515515
// used to identify padding entries if cache_indices provided
516516
// in case of padding, the kernel will return early

0 commit comments

Comments
 (0)