11#include < libtorchaudio/rnnt/cpu/cpu_transducer.h>
2- #include < torch/script.h>
2+ #include < torch/csrc/stable/library.h>
3+ #include < torch/csrc/stable/ops.h>
4+ #include < torch/csrc/stable/tensor.h>
35
46namespace torchaudio {
57namespace rnnt {
68namespace cpu {
79
10+ using torch::headeronly::ScalarType;
11+ using torch::stable::Tensor;
12+
813// Entry point into RNNT Loss
9- std::tuple<torch:: Tensor, std::optional<torch:: Tensor> > compute (
10- torch:: Tensor& logits,
11- const torch:: Tensor& targets,
12- const torch:: Tensor& logit_lengths,
13- const torch:: Tensor& target_lengths,
14+ std::tuple<Tensor, Tensor> compute (
15+ const Tensor& logits,
16+ const Tensor& targets,
17+ const Tensor& logit_lengths,
18+ const Tensor& target_lengths,
1419 int64_t blank,
1520 double clamp,
1621 bool fused_log_softmax = true ) {
17- TORCH_CHECK (
18- logits.device ().type () == targets.device ().type (),
19- " logits and targets must be on the same device" );
20- TORCH_CHECK (
21- logits.device ().type () == logit_lengths.device ().type (),
22+ STD_TORCH_CHECK (logits.is_cpu (), " logits must be on CPU" );
23+
24+ STD_TORCH_CHECK (
25+ targets.is_cpu (), " logits and targets must be on the same device" );
26+ STD_TORCH_CHECK (
27+ logit_lengths.is_cpu (),
2228 " logits and logit_lengths must be on the same device" );
23- TORCH_CHECK (
24- logits. device (). type () == target_lengths.device (). type (),
29+ STD_TORCH_CHECK (
30+ target_lengths.is_cpu (),
2531 " logits and target_lengths must be on the same device" );
2632
27- TORCH_CHECK (
28- logits.dtype () == torch::kFloat32 || logits.dtype () == torch::kFloat16 ,
33+ STD_TORCH_CHECK (
34+ logits.scalar_type () == ScalarType::Float ||
35+ logits.scalar_type () == ScalarType::Half,
2936 " logits must be float32 or float16 (half) type" );
30- TORCH_CHECK (targets.dtype () == torch::kInt32 , " targets must be int32 type" );
31- TORCH_CHECK (
32- logit_lengths.dtype () == torch::kInt32 ,
37+
38+ STD_TORCH_CHECK (
39+ targets.scalar_type () == ScalarType::Int, " targets must be int32 type" );
40+
41+ STD_TORCH_CHECK (
42+ logit_lengths.scalar_type () == ScalarType::Int,
3343 " logit_lengths must be int32 type" );
34- TORCH_CHECK (
35- target_lengths.dtype () == torch:: kInt32 ,
44+ STD_TORCH_CHECK (
45+ target_lengths.scalar_type () == ScalarType::Int ,
3646 " target_lengths must be int32 type" );
3747
38- TORCH_CHECK (logits.is_contiguous (), " logits must be contiguous" );
39- TORCH_CHECK (targets.is_contiguous (), " targets must be contiguous" );
40- TORCH_CHECK (
48+ STD_TORCH_CHECK (logits.is_contiguous (), " logits must be contiguous" );
49+ STD_TORCH_CHECK (targets.is_contiguous (), " targets must be contiguous" );
50+ STD_TORCH_CHECK (
4151 logit_lengths.is_contiguous (), " logit_lengths must be contiguous" );
42- TORCH_CHECK (
52+ STD_TORCH_CHECK (
4353 target_lengths.is_contiguous (), " target_lengths must be contiguous" );
4454
45- TORCH_CHECK (
55+ STD_TORCH_CHECK (
4656 logits.dim () == 4 , " logits must be 4-D (batch, time, target, class)" );
47- TORCH_CHECK (
57+ STD_TORCH_CHECK (
4858 targets.dim () == 2 , " targets must be 2-D (batch, max target length)" );
49- TORCH_CHECK (logit_lengths.dim () == 1 , " logit_lengths must be 1-D" );
50- TORCH_CHECK (target_lengths.dim () == 1 , " target_lengths must be 1-D" );
59+ STD_TORCH_CHECK (logit_lengths.dim () == 1 , " logit_lengths must be 1-D" );
60+ STD_TORCH_CHECK (target_lengths.dim () == 1 , " target_lengths must be 1-D" );
5161
52- TORCH_CHECK (
62+ STD_TORCH_CHECK (
5363 logit_lengths.size (0 ) == logits.size (0 ),
5464 " batch dimension mismatch between logits and logit_lengths" );
55- TORCH_CHECK (
65+ STD_TORCH_CHECK (
5666 target_lengths.size (0 ) == logits.size (0 ),
5767 " batch dimension mismatch between logits and target_lengths" );
58- TORCH_CHECK (
68+ STD_TORCH_CHECK (
5969 targets.size (0 ) == logits.size (0 ),
6070 " batch dimension mismatch between logits and targets" );
6171
62- TORCH_CHECK (
72+ STD_TORCH_CHECK (
6373 blank >= 0 && blank < logits.size (-1 ),
6474 " blank must be within [0, logits.shape[-1])" );
6575
66- TORCH_CHECK (
67- logits.size (1 ) == at::max (logit_lengths).item ().toInt (),
68- " input length mismatch" );
69- TORCH_CHECK (
70- logits.size (2 ) == at::max (target_lengths).item ().toInt () + 1 ,
71- " output length mismatch" );
72- TORCH_CHECK (
73- targets.size (1 ) == at::max (target_lengths).item ().toInt (),
74- " target length mismatch" );
76+ auto max_ivalue = [](const Tensor& t) {
77+ return reinterpret_cast <int32_t *>(torch::stable::amax (t, {}).data_ptr ())[0 ];
78+ };
7579
80+ STD_TORCH_CHECK (
81+ logits.size (1 ) == max_ivalue (logit_lengths), " input length mismatch" );
82+ STD_TORCH_CHECK (
83+ logits.size (2 ) == max_ivalue (target_lengths) + 1 ,
84+ " output length mismatch" );
85+ STD_TORCH_CHECK (
86+ targets.size (1 ) + 1 == logits.size (2 ), " target length mismatch" );
87+ // TODO: Use static_cast and check bounds when down-casting from
88+ // double to float (clamp_) and from int64 to int (blank_)
7689 Options options;
7790 options.batchSize_ = logit_lengths.size (0 );
7891 options.nHypos_ = target_lengths.size (0 ) / logit_lengths.size (0 );
  // NOTE(review): a few source lines are elided from this excerpt at this
  // point (diff hunk boundary) — presumably the remaining Options fields are
  // set here; verify against the full file.
8295 options.blank_ = blank;
8396 options.clamp_ = clamp;
8497 options.fusedLogSmax_ = fused_log_softmax;
85-
86- TORCH_CHECK_EQ (logits.device ().type (), torch::DeviceType::CPU);
8798 options.device_ = CPU;
8899
89- torch::Tensor costs = torch::empty (
90- options.batchSize_ * options.nHypos_ ,
91- torch::TensorOptions ().device (logits.device ()).dtype (logits.dtype ()));
92- std::optional<torch::Tensor> gradients = torch::zeros_like (logits);
93-
94- torch::Tensor int_workspace = torch::empty (
95- IntWorkspace::ComputeSizeFromOptions (options),
96- torch::TensorOptions ()
97- .device (logits.device ())
98- .dtype (torch::ScalarType::Int));
100+ Tensor costs =
101+ torch::stable::new_empty (logits, {options.batchSize_ * options.nHypos_ });
102+ Tensor gradients = torch::stable::empty_like (logits);
103+ torch::stable::fill_ (gradients, 0.0 );
99104
100- torch::Tensor float_workspace = torch::empty (
101- DtypeWorkspace<float >::ComputeSizeFromOptions (options),
102- torch::TensorOptions ()
103- .device (logits.device ())
104- .dtype (torch::ScalarType::Float));
105+ Tensor int_workspace = torch::stable::new_empty (
106+ logits, {IntWorkspace::ComputeSizeFromOptions (options)}, ScalarType::Int);
107+ Tensor float_workspace = torch::stable::new_empty (
108+ logits,
109+ {DtypeWorkspace<float >::ComputeSizeFromOptions (options)},
110+ ScalarType::Float);
105111
106112 Workspace<float > workspace (
107113 /* options=*/ options,
108- /* dtype_data=*/ float_workspace. data_ptr <float >( ),
114+ /* dtype_data=*/ reinterpret_cast <float *>(float_workspace. data_ptr () ),
109115 /* dtype_size=*/ float_workspace.numel (),
110- /* int_data=*/ int_workspace. data_ptr <int >( ),
116+ /* int_data=*/ reinterpret_cast <int *>(int_workspace. data_ptr () ),
111117 /* int_size=*/ int_workspace.numel ());
112118
113119 switch (logits.scalar_type ()) {
114- case torch:: ScalarType::Float: {
120+ case ScalarType::Float: {
115121 Compute</* DTYPE=*/ float , /* CAST_DTYPE=*/ float >(
116122 /* workspace=*/ workspace,
117- /* logits=*/ logits. data_ptr <float >( ),
118- /* targets=*/ targets. data_ptr <int >( ),
119- /* srcLengths=*/ logit_lengths. data_ptr <int >( ),
120- /* tgtLengths=*/ target_lengths. data_ptr <int >( ),
121- /* costs=*/ costs. data_ptr <float >( ),
122- /* gradients=*/ gradients-> data_ptr <float >( ));
123+ /* logits=*/ reinterpret_cast <float *>(logits. data_ptr () ),
124+ /* targets=*/ reinterpret_cast <int *>(targets. data_ptr () ),
125+ /* srcLengths=*/ reinterpret_cast <int *>(logit_lengths. data_ptr () ),
126+ /* tgtLengths=*/ reinterpret_cast <int *>(target_lengths. data_ptr () ),
127+ /* costs=*/ reinterpret_cast <float *>(costs. data_ptr () ),
128+ /* gradients=*/ reinterpret_cast <float *>(gradients. data_ptr () ));
123129 break ;
124130 }
125- case torch:: ScalarType::Half: {
131+ case ScalarType::Half: {
126132 Compute</* DTYPE=*/ c10::Half, /* CAST_DTYPE=*/ float >(
127133 /* workspace=*/ workspace,
128- /* logits=*/ logits. data_ptr <c10::Half>( ),
129- /* targets=*/ targets. data_ptr <int >( ),
130- /* srcLengths=*/ logit_lengths. data_ptr <int >( ),
131- /* tgtLengths=*/ target_lengths. data_ptr <int >( ),
132- /* costs=*/ costs. data_ptr <c10::Half>( ),
133- /* gradients=*/ gradients-> data_ptr <c10::Half>( ));
134+ /* logits=*/ reinterpret_cast <c10::Half*>(logits. data_ptr () ),
135+ /* targets=*/ reinterpret_cast <int *>(targets. data_ptr () ),
136+ /* srcLengths=*/ reinterpret_cast <int *>(logit_lengths. data_ptr () ),
137+ /* tgtLengths=*/ reinterpret_cast <int *>(target_lengths. data_ptr () ),
138+ /* costs=*/ reinterpret_cast <c10::Half*>(costs. data_ptr () ),
139+ /* gradients=*/ reinterpret_cast <c10::Half*>(gradients. data_ptr () ));
134140 break ;
135141 }
136142 default : {
137- break ;
143+ STD_TORCH_CHECK ( false , " unreachable " ) ;
138144 }
139145 };
140146
141147 return std::make_tuple (costs, gradients);
142148}
143149
144- TORCH_LIBRARY_IMPL (torchaudio, CPU, m) {
145- m.impl (" rnnt_loss" , &compute);
150+ void boxed_rnnt_loss (
151+ StableIValue* stack,
152+ uint64_t num_args,
153+ uint64_t num_outputs) {
154+ STD_TORCH_CHECK (num_args == 7 , " num_args must be 7" );
155+ STD_TORCH_CHECK (num_outputs == 2 , " num_outputs must be 2" );
156+ std::tuple<Tensor, Tensor> res = compute (
157+ /* logits*/ torch::stable::detail::to<Tensor>(stack[0 ]),
158+ /* targets*/ torch::stable::detail::to<Tensor>(stack[1 ]),
159+ /* logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2 ]),
160+ /* target_lengths*/ torch::stable::detail::to<Tensor>(stack[3 ]),
161+ /* blank*/ float (torch::stable::detail::to<int64_t >(stack[4 ])),
162+ /* clamp*/ torch::stable::detail::to<double >(stack[5 ]),
163+ /* fused_log_softmax*/ torch::stable::detail::to<bool >(stack[6 ]));
164+ stack[0 ] = torch::stable::detail::from (std::get<0 >(res));
165+ stack[1 ] = torch::stable::detail::from (std::get<1 >(res));
166+ }
167+
168+ STABLE_TORCH_LIBRARY_IMPL (torchaudio, CPU, m) {
169+ m.impl (" rnnt_loss_forward" , &boxed_rnnt_loss);
146170}
147171
148172} // namespace cpu
// NOTE(review): trailing page artifact removed; the closing braces for the
// rnnt and torchaudio namespaces presumably follow in the full file — verify.