Skip to content

Commit 3b0e7a6

Browse files
[STABLE ABI] Port rnnt. (#4073)
authored (author name not captured in this extraction)
1 parent: 69bbe73 · commit: 3b0e7a6

File tree

11 files changed

+198
-513
lines changed

11 files changed

+198
-513
lines changed

src/libtorchaudio/CMakeLists.txt

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,13 @@ if(BUILD_RNNT)
2222
list(
2323
APPEND
2424
sources
25-
rnnt/cpu/compute_alphas.cpp
26-
rnnt/cpu/compute_betas.cpp
2725
rnnt/cpu/compute.cpp
28-
rnnt/compute_alphas.cpp
29-
rnnt/compute_betas.cpp
3026
rnnt/compute.cpp
3127
)
3228
if (USE_CUDA)
3329
list(
3430
APPEND
3531
sources
36-
rnnt/gpu/compute_alphas.cu
37-
rnnt/gpu/compute_betas.cu
3832
rnnt/gpu/compute.cu
3933
)
4034
endif()

src/libtorchaudio/rnnt/compute.cpp

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,12 @@
1-
#include <libtorchaudio/rnnt/compute.h>
1+
#include <torch/csrc/stable/library.h>
22

3-
std::tuple<torch::Tensor, std::optional<torch::Tensor>> rnnt_loss(
4-
torch::Tensor& logits,
5-
const torch::Tensor& targets,
6-
const torch::Tensor& logit_lengths,
7-
const torch::Tensor& target_lengths,
8-
int64_t blank,
9-
double clamp,
10-
bool fused_log_softmax = true) {
11-
static auto op = torch::Dispatcher::singleton()
12-
.findSchemaOrThrow("torchaudio::rnnt_loss", "")
13-
.typed<decltype(rnnt_loss)>();
14-
return op.call(
15-
logits,
16-
targets,
17-
logit_lengths,
18-
target_lengths,
19-
blank,
20-
clamp,
21-
fused_log_softmax);
22-
}
23-
24-
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
3+
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
254
m.def(
26-
"rnnt_loss(Tensor logits,"
5+
"rnnt_loss_forward(Tensor logits,"
276
"Tensor targets,"
287
"Tensor logit_lengths,"
298
"Tensor target_lengths,"
309
"int blank,"
3110
"float clamp,"
32-
"bool fused_log_softmax) -> (Tensor, Tensor?)");
33-
m.def("torchaudio::rnnt_loss_forward", &rnnt_loss);
11+
"bool fused_log_softmax) -> (Tensor, Tensor)");
3412
}

src/libtorchaudio/rnnt/compute.h

Lines changed: 0 additions & 12 deletions
This file was deleted.

src/libtorchaudio/rnnt/compute_alphas.cpp

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/libtorchaudio/rnnt/compute_betas.cpp

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/libtorchaudio/rnnt/cpu/compute.cpp

Lines changed: 101 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,91 @@
11
#include <libtorchaudio/rnnt/cpu/cpu_transducer.h>
2-
#include <torch/script.h>
2+
#include <torch/csrc/stable/library.h>
3+
#include <torch/csrc/stable/ops.h>
4+
#include <torch/csrc/stable/tensor.h>
35

46
namespace torchaudio {
57
namespace rnnt {
68
namespace cpu {
79

10+
using torch::headeronly::ScalarType;
11+
using torch::stable::Tensor;
12+
813
// Entry point into RNNT Loss
9-
std::tuple<torch::Tensor, std::optional<torch::Tensor>> compute(
10-
torch::Tensor& logits,
11-
const torch::Tensor& targets,
12-
const torch::Tensor& logit_lengths,
13-
const torch::Tensor& target_lengths,
14+
std::tuple<Tensor, Tensor> compute(
15+
const Tensor& logits,
16+
const Tensor& targets,
17+
const Tensor& logit_lengths,
18+
const Tensor& target_lengths,
1419
int64_t blank,
1520
double clamp,
1621
bool fused_log_softmax = true) {
17-
TORCH_CHECK(
18-
logits.device().type() == targets.device().type(),
19-
"logits and targets must be on the same device");
20-
TORCH_CHECK(
21-
logits.device().type() == logit_lengths.device().type(),
22+
STD_TORCH_CHECK(logits.is_cpu(), "logits must be on CPU");
23+
24+
STD_TORCH_CHECK(
25+
targets.is_cpu(), "logits and targets must be on the same device");
26+
STD_TORCH_CHECK(
27+
logit_lengths.is_cpu(),
2228
"logits and logit_lengths must be on the same device");
23-
TORCH_CHECK(
24-
logits.device().type() == target_lengths.device().type(),
29+
STD_TORCH_CHECK(
30+
target_lengths.is_cpu(),
2531
"logits and target_lengths must be on the same device");
2632

27-
TORCH_CHECK(
28-
logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
33+
STD_TORCH_CHECK(
34+
logits.scalar_type() == ScalarType::Float ||
35+
logits.scalar_type() == ScalarType::Half,
2936
"logits must be float32 or float16 (half) type");
30-
TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
31-
TORCH_CHECK(
32-
logit_lengths.dtype() == torch::kInt32,
37+
38+
STD_TORCH_CHECK(
39+
targets.scalar_type() == ScalarType::Int, "targets must be int32 type");
40+
41+
STD_TORCH_CHECK(
42+
logit_lengths.scalar_type() == ScalarType::Int,
3343
"logit_lengths must be int32 type");
34-
TORCH_CHECK(
35-
target_lengths.dtype() == torch::kInt32,
44+
STD_TORCH_CHECK(
45+
target_lengths.scalar_type() == ScalarType::Int,
3646
"target_lengths must be int32 type");
3747

38-
TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
39-
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
40-
TORCH_CHECK(
48+
STD_TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
49+
STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
50+
STD_TORCH_CHECK(
4151
logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
42-
TORCH_CHECK(
52+
STD_TORCH_CHECK(
4353
target_lengths.is_contiguous(), "target_lengths must be contiguous");
4454

45-
TORCH_CHECK(
55+
STD_TORCH_CHECK(
4656
logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
47-
TORCH_CHECK(
57+
STD_TORCH_CHECK(
4858
targets.dim() == 2, "targets must be 2-D (batch, max target length)");
49-
TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
50-
TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
59+
STD_TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
60+
STD_TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
5161

52-
TORCH_CHECK(
62+
STD_TORCH_CHECK(
5363
logit_lengths.size(0) == logits.size(0),
5464
"batch dimension mismatch between logits and logit_lengths");
55-
TORCH_CHECK(
65+
STD_TORCH_CHECK(
5666
target_lengths.size(0) == logits.size(0),
5767
"batch dimension mismatch between logits and target_lengths");
58-
TORCH_CHECK(
68+
STD_TORCH_CHECK(
5969
targets.size(0) == logits.size(0),
6070
"batch dimension mismatch between logits and targets");
6171

62-
TORCH_CHECK(
72+
STD_TORCH_CHECK(
6373
blank >= 0 && blank < logits.size(-1),
6474
"blank must be within [0, logits.shape[-1])");
6575

66-
TORCH_CHECK(
67-
logits.size(1) == at::max(logit_lengths).item().toInt(),
68-
"input length mismatch");
69-
TORCH_CHECK(
70-
logits.size(2) == at::max(target_lengths).item().toInt() + 1,
71-
"output length mismatch");
72-
TORCH_CHECK(
73-
targets.size(1) == at::max(target_lengths).item().toInt(),
74-
"target length mismatch");
76+
auto max_ivalue = [](const Tensor& t) {
77+
return reinterpret_cast<int32_t*>(torch::stable::amax(t, {}).data_ptr())[0];
78+
};
7579

80+
STD_TORCH_CHECK(
81+
logits.size(1) == max_ivalue(logit_lengths), "input length mismatch");
82+
STD_TORCH_CHECK(
83+
logits.size(2) == max_ivalue(target_lengths) + 1,
84+
"output length mismatch");
85+
STD_TORCH_CHECK(
86+
targets.size(1) + 1 == logits.size(2), "target length mismatch");
87+
// TODO: Use static_cast and check bounds when down-casting from
88+
// double to float (clamp_) and from int64 to int (blank_)
7689
Options options;
7790
options.batchSize_ = logit_lengths.size(0);
7891
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
@@ -82,67 +95,78 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> compute(
8295
options.blank_ = blank;
8396
options.clamp_ = clamp;
8497
options.fusedLogSmax_ = fused_log_softmax;
85-
86-
TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
8798
options.device_ = CPU;
8899

89-
torch::Tensor costs = torch::empty(
90-
options.batchSize_ * options.nHypos_,
91-
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
92-
std::optional<torch::Tensor> gradients = torch::zeros_like(logits);
93-
94-
torch::Tensor int_workspace = torch::empty(
95-
IntWorkspace::ComputeSizeFromOptions(options),
96-
torch::TensorOptions()
97-
.device(logits.device())
98-
.dtype(torch::ScalarType::Int));
100+
Tensor costs =
101+
torch::stable::new_empty(logits, {options.batchSize_ * options.nHypos_});
102+
Tensor gradients = torch::stable::empty_like(logits);
103+
torch::stable::fill_(gradients, 0.0);
99104

100-
torch::Tensor float_workspace = torch::empty(
101-
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
102-
torch::TensorOptions()
103-
.device(logits.device())
104-
.dtype(torch::ScalarType::Float));
105+
Tensor int_workspace = torch::stable::new_empty(
106+
logits, {IntWorkspace::ComputeSizeFromOptions(options)}, ScalarType::Int);
107+
Tensor float_workspace = torch::stable::new_empty(
108+
logits,
109+
{DtypeWorkspace<float>::ComputeSizeFromOptions(options)},
110+
ScalarType::Float);
105111

106112
Workspace<float> workspace(
107113
/*options=*/options,
108-
/*dtype_data=*/float_workspace.data_ptr<float>(),
114+
/*dtype_data=*/reinterpret_cast<float*>(float_workspace.data_ptr()),
109115
/*dtype_size=*/float_workspace.numel(),
110-
/*int_data=*/int_workspace.data_ptr<int>(),
116+
/*int_data=*/reinterpret_cast<int*>(int_workspace.data_ptr()),
111117
/*int_size=*/int_workspace.numel());
112118

113119
switch (logits.scalar_type()) {
114-
case torch::ScalarType::Float: {
120+
case ScalarType::Float: {
115121
Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
116122
/*workspace=*/workspace,
117-
/*logits=*/logits.data_ptr<float>(),
118-
/*targets=*/targets.data_ptr<int>(),
119-
/*srcLengths=*/logit_lengths.data_ptr<int>(),
120-
/*tgtLengths=*/target_lengths.data_ptr<int>(),
121-
/*costs=*/costs.data_ptr<float>(),
122-
/*gradients=*/gradients->data_ptr<float>());
123+
/*logits=*/reinterpret_cast<float*>(logits.data_ptr()),
124+
/*targets=*/reinterpret_cast<int*>(targets.data_ptr()),
125+
/*srcLengths=*/reinterpret_cast<int*>(logit_lengths.data_ptr()),
126+
/*tgtLengths=*/reinterpret_cast<int*>(target_lengths.data_ptr()),
127+
/*costs=*/reinterpret_cast<float*>(costs.data_ptr()),
128+
/*gradients=*/reinterpret_cast<float*>(gradients.data_ptr()));
123129
break;
124130
}
125-
case torch::ScalarType::Half: {
131+
case ScalarType::Half: {
126132
Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
127133
/*workspace=*/workspace,
128-
/*logits=*/logits.data_ptr<c10::Half>(),
129-
/*targets=*/targets.data_ptr<int>(),
130-
/*srcLengths=*/logit_lengths.data_ptr<int>(),
131-
/*tgtLengths=*/target_lengths.data_ptr<int>(),
132-
/*costs=*/costs.data_ptr<c10::Half>(),
133-
/*gradients=*/gradients->data_ptr<c10::Half>());
134+
/*logits=*/reinterpret_cast<c10::Half*>(logits.data_ptr()),
135+
/*targets=*/reinterpret_cast<int*>(targets.data_ptr()),
136+
/*srcLengths=*/reinterpret_cast<int*>(logit_lengths.data_ptr()),
137+
/*tgtLengths=*/reinterpret_cast<int*>(target_lengths.data_ptr()),
138+
/*costs=*/reinterpret_cast<c10::Half*>(costs.data_ptr()),
139+
/*gradients=*/reinterpret_cast<c10::Half*>(gradients.data_ptr()));
134140
break;
135141
}
136142
default: {
137-
break;
143+
STD_TORCH_CHECK(false, "unreachable");
138144
}
139145
};
140146

141147
return std::make_tuple(costs, gradients);
142148
}
143149

144-
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
145-
m.impl("rnnt_loss", &compute);
150+
void boxed_rnnt_loss(
151+
StableIValue* stack,
152+
uint64_t num_args,
153+
uint64_t num_outputs) {
154+
STD_TORCH_CHECK(num_args == 7, "num_args must be 7");
155+
STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
156+
std::tuple<Tensor, Tensor> res = compute(
157+
/*logits*/ torch::stable::detail::to<Tensor>(stack[0]),
158+
/*targets*/ torch::stable::detail::to<Tensor>(stack[1]),
159+
/*logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2]),
160+
/*target_lengths*/ torch::stable::detail::to<Tensor>(stack[3]),
161+
/*blank*/ float(torch::stable::detail::to<int64_t>(stack[4])),
162+
/*clamp*/ torch::stable::detail::to<double>(stack[5]),
163+
/*fused_log_softmax*/ torch::stable::detail::to<bool>(stack[6]));
164+
stack[0] = torch::stable::detail::from(std::get<0>(res));
165+
stack[1] = torch::stable::detail::from(std::get<1>(res));
166+
}
167+
168+
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
169+
m.impl("rnnt_loss_forward", &boxed_rnnt_loss);
146170
}
147171

148172
} // namespace cpu

Comments: 0 commit comments