
Commit bae56dd

Add toString(ScalarType)
1 parent bbf6f38 commit bae56dd

4 files changed, +76 -71 lines changed


src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 19 additions & 15 deletions
@@ -177,14 +177,18 @@ std::tuple<Tensor, Tensor> compute(
   STD_TORCH_CHECK(
       blank >= 0 && blank < logProbs.size(-1),
       "blank must be within [0, num classes)");
-
-  STD_TORCH_CHECK(
-      logProbs.size(1) == torchaudio::util::max<int>(inputLengths),
-      "input length mismatch");
-  STD_TORCH_CHECK(
-      targets.size(1) == torchaudio::util::max<int>(targetLengths),
-      "target length mismatch");
-
+  STABLE_DISPATCH_INDEX_TYPES(
+      inputLengths.scalar_type(), "forced_align_impl", [&] {
+        STD_TORCH_CHECK(
+            logProbs.size(1) == torchaudio::util::max<index_t>(inputLengths),
+            "input length mismatch");
+      });
+  STABLE_DISPATCH_INDEX_TYPES(
+      targetLengths.scalar_type(), "forced_align_impl", [&] {
+        STD_TORCH_CHECK(
+            targets.size(1) == torchaudio::util::max<index_t>(targetLengths),
+            "target length mismatch");
+      });
   const auto B = logProbs.size(0);
   const auto T = logProbs.size(1);
   Tensor paths = torchaudio::stable::new_zeros(targets, {B, T});

@@ -209,13 +213,13 @@ void boxed_forced_align_cpu(
   STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
   STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
   std::tuple<Tensor, Tensor> res = compute(
-      /*logProbs*/ to<Tensor>(stack[0]),
-      /*targets*/ to<Tensor>(stack[1]),
-      /*logit_lengths*/ to<Tensor>(stack[2]),
-      /*target_lengths*/ to<Tensor>(stack[3]),
-      /*blank*/ float(to<int64_t>(stack[4])));
-  stack[0] = from(std::get<0>(res));
-  stack[1] = from(std::get<1>(res));
+      /*logProbs*/ torch::stable::detail::to<Tensor>(stack[0]),
+      /*targets*/ torch::stable::detail::to<Tensor>(stack[1]),
+      /*logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2]),
+      /*target_lengths*/ torch::stable::detail::to<Tensor>(stack[3]),
+      /*blank*/ float(torch::stable::detail::to<int64_t>(stack[4])));
+  stack[0] = torch::stable::detail::from(std::get<0>(res));
+  stack[1] = torch::stable::detail::from(std::get<1>(res));
 }
 
 STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
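
The new length checks dispatch on the dtype of the length tensors instead of assuming int. As a rough sketch (not the literal preprocessor output), the first STABLE_DISPATCH_INDEX_TYPES call above behaves roughly like the switch below, with index_t bound per case via ScalarTypeToCPPTypeT and the error text supplied by STABLE_DISPATCH_SWITCH in stable/dispatch.h.

// Rough equivalent of STABLE_DISPATCH_INDEX_TYPES(inputLengths.scalar_type(), ...)
// above, assuming the Int/Long cases registered in stable/dispatch.h below.
[&] {
  switch (inputLengths.scalar_type()) {
    case ScalarType::Int: {
      using index_t = int32_t;  // typically ScalarTypeToCPPTypeT<ScalarType::Int>
      STD_TORCH_CHECK(
          logProbs.size(1) == torchaudio::util::max<index_t>(inputLengths),
          "input length mismatch");
      return;
    }
    case ScalarType::Long: {
      using index_t = int64_t;  // typically ScalarTypeToCPPTypeT<ScalarType::Long>
      STD_TORCH_CHECK(
          logProbs.size(1) == torchaudio::util::max<index_t>(inputLengths),
          "input length mismatch");
      return;
    }
    default:
      STD_TORCH_CHECK(
          false,
          "\"forced_align_impl\" not implemented for '",
          torchaudio::stable::impl::toString(inputLengths.scalar_type()),
          "'");
  }
}();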

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 8 additions & 8 deletions
@@ -209,7 +209,7 @@ void forced_align_impl(
     if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
       cpuDataTranferStream.synchronize();
       // GPU -> GPU copy
-      bufferCopy = torchaudio::stable::clone(backPtrBuffer);
+      bufferCopy = torch::stable::clone(backPtrBuffer);
       STD_TORCH_CHECK(bufferCopy.is_contiguous(), "unexpected fail, need to implement stable::Tensor::contiguous()")
       defaultStream.synchronize();
       at::cuda::setCurrentCUDAStream(cpuDataTranferStream);

@@ -316,13 +316,13 @@ void boxed_forced_align_gpu(StableIValue* stack, uint64_t num_args, uint64_t num
   STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
   STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
   std::tuple<Tensor, Tensor> res = compute(
-      /*logProbs*/to<Tensor>(stack[0]),
-      /*targets*/to<Tensor>(stack[1]),
-      /*logit_lengths*/to<Tensor>(stack[2]),
-      /*target_lengths*/to<Tensor>(stack[3]),
-      /*blank*/float(to<int64_t>(stack[4])));
-  stack[0] = from(std::get<0>(res));
-  stack[1] = from(std::get<1>(res));
+      /*logProbs*/torch::stable::detail::to<Tensor>(stack[0]),
+      /*targets*/torch::stable::detail::to<Tensor>(stack[1]),
+      /*logit_lengths*/torch::stable::detail::to<Tensor>(stack[2]),
+      /*target_lengths*/torch::stable::detail::to<Tensor>(stack[3]),
+      /*blank*/float(torch::stable::detail::to<int64_t>(stack[4])));
+  stack[0] = torch::stable::detail::from(std::get<0>(res));
+  stack[1] = torch::stable::detail::from(std::get<1>(res));
 }
 
 STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
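
Both boxed wrappers follow the same stable-ABI convention: arguments arrive as a flat StableIValue stack, are unboxed with torch::stable::detail::to<T>, and results are boxed back with torch::stable::detail::from into the leading stack slots; this commit only spells out the torch::stable::detail qualification instead of relying on unqualified to/from. A minimal sketch of that shape, with a hypothetical op standing in for compute:

// Hedged sketch of the boxed-wrapper pattern used above; `my_scale` is a
// hypothetical stand-in and not part of torchaudio.
Tensor my_scale(const Tensor& x, int64_t k);  // assumed to exist for the sketch

void boxed_my_scale(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  STD_TORCH_CHECK(num_args == 2, "num_args must be 2");
  STD_TORCH_CHECK(num_outputs == 1, "num_outputs must be 1");
  Tensor x = torch::stable::detail::to<Tensor>(stack[0]);    // unbox input tensor
  int64_t k = torch::stable::detail::to<int64_t>(stack[1]);  // unbox scalar arg
  Tensor out = my_scale(x, k);
  stack[0] = torch::stable::detail::from(out);               // box result in place
}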

src/libtorchaudio/stable/dispatch.h

Lines changed: 47 additions & 15 deletions
@@ -24,6 +24,19 @@ using torch::headeronly::ScalarType;
 
 namespace impl {
 
+inline const char* toString(ScalarType t) {
+#define DEFINE_CASE(_, name) \
+  case ScalarType::name:     \
+    return #name;
+
+  switch (t) {
+    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
+    default:
+      return "UNKNOWN_SCALAR";
+  }
+#undef DEFINE_CASE
+}
+
 template <ScalarType N>
 struct ScalarTypeToCPPType;
 

@@ -51,21 +64,28 @@ using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
     return __VA_ARGS__(); \
   }
 
-#define STABLE_DISPATCH_SWITCH(TYPE, NAME, ...)    \
-  [&] {                                            \
-    const auto& the_type = TYPE;                   \
-    constexpr const char* at_dispatch_name = NAME; \
-    switch (the_type) {                            \
-      __VA_ARGS__                                  \
-      default:                                     \
-        STD_TORCH_CHECK(                           \
-            false,                                 \
-            '"',                                   \
-            at_dispatch_name,                      \
-            "\" not implemented for '",            \
-            toString(the_type),                    \
-            "'");                                  \
-    }                                              \
+#define STABLE_DISPATCH_CASE_INDEX(enum_type, ...)                 \
+  case enum_type: {                                                \
+    using index_t [[maybe_unused]] =                               \
+        torchaudio::stable::impl::ScalarTypeToCPPTypeT<enum_type>; \
+    return __VA_ARGS__();                                          \
+  }
+
+#define STABLE_DISPATCH_SWITCH(TYPE, NAME, ...)           \
+  [&] {                                                   \
+    const auto& the_type = TYPE;                          \
+    constexpr const char* at_dispatch_name = NAME;        \
+    switch (the_type) {                                   \
+      __VA_ARGS__                                         \
+      default:                                            \
+        STD_TORCH_CHECK(                                  \
+            false,                                        \
+            '"',                                          \
+            at_dispatch_name,                             \
+            "\" not implemented for '",                   \
+            torchaudio::stable::impl::toString(the_type), \
+            "'");                                         \
+    }                                                     \
   }()
 
 #define STABLE_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \

@@ -76,3 +96,15 @@ using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
 #define STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
   STABLE_DISPATCH_SWITCH( \
       TYPE, NAME, STABLE_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
+
+#define STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
+  STABLE_DISPATCH_SWITCH( \
+      TYPE, NAME, STABLE_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
+
+#define STABLE_DISPATCH_CASE_INDEX_TYPES(...)              \
+  STABLE_DISPATCH_CASE_INDEX(ScalarType::Int, __VA_ARGS__) \
+  STABLE_DISPATCH_CASE_INDEX(ScalarType::Long, __VA_ARGS__)
+
+#define STABLE_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
+  STABLE_DISPATCH_SWITCH( \
+      TYPE, NAME, STABLE_DISPATCH_CASE_INDEX_TYPES(__VA_ARGS__))
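
toString is built with an X-macro: AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS invokes DEFINE_CASE once per (c-type, name) pair, so the switch ends up with one case per scalar type. An approximate sketch of the generated switch body (only a few of the cases are shown; the full set comes from the FORALL list):

// Approximate shape of the switch inside toString() after macro expansion;
// only a handful of the generated cases are shown here.
switch (t) {
  case ScalarType::Float:
    return "Float";
  case ScalarType::Long:
    return "Long";
  case ScalarType::Half:
    return "Half";
  // ... one case per remaining scalar type in the FORALL list ...
  default:
    return "UNKNOWN_SCALAR";
}

STABLE_DISPATCH_INDEX_TYPES then composes STABLE_DISPATCH_SWITCH with the Int and Long cases, which is what the CPU compute() above uses to validate the length tensors.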

src/libtorchaudio/stable/ops.h

Lines changed: 2 additions & 33 deletions
@@ -61,7 +61,6 @@ const T* const_data_ptr(const Tensor& t) {
 
 // TODO: When accessor is implemented in torch::stable, eliminate
 // accessor template below.
-
 template <typename T, size_t N>
 torchaudio::stable::TensorAccessor<T, N> accessor(const Tensor& t) {
   static_assert(

@@ -117,10 +116,6 @@ generic_packed_accessor(const Tensor& t) {
       sizes_.data(),
       strides_.data());
 }
-// template<typename T, size_t N, template <typename U> class PtrTraits =
-// torchaudio::stable::DefaultPtrTraits, typename index_t = int64_t>
-// GenericPackedTensorAccessor<T,N> generic_packed_accessor(const Tensor& t) &&
-// = delete;
 
 template <
     typename T,

@@ -134,9 +129,6 @@ torchaudio::stable::PackedTensorAccessor32<T, N, PtrTraits> packed_accessor32(
       "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");
   return generic_packed_accessor<T, N, PtrTraits, int32_t>(t);
 }
-// template<typename T, size_t N, template <typename U> class PtrTraits =
-// torchaudio::stable::DefaultPtrTraits> PackedTensorAccessor32<T,N,PtrTraits>
-// packed_accessor32(const Tensor& t) && = delete;
 
 template <
     typename T,

@@ -147,23 +139,6 @@ torchaudio::stable::PackedTensorAccessor64<T, N, PtrTraits> packed_accessor64(
     const Tensor& t) {
   return generic_packed_accessor<T, N, PtrTraits, int64_t>();
 }
-// template<typename T, size_t N, template <typename U> class PtrTraits =
-// DefaultPtrTraits> PackedTensorAccessor64<T,N,PtrTraits>
-// packed_accessor64(const Tensor& t) && = delete;
-
-// TODO: When https://github.com/pytorch/pytorch/pull/161895 lands, eliminate
-// copy_ function below.
-inline Tensor copy_(
-    Tensor& self,
-    const Tensor& src,
-    std::optional<bool> non_blocking = std::nullopt) {
-  const auto num_args = 3;
-  std::array<StableIValue, num_args> stack{
-      from(self), from(src), from(non_blocking.value_or(false))};
-  TORCH_ERROR_CODE_CHECK(
-      aoti_torch_call_dispatcher("aten::copy_", "", stack.data()));
-  return to<Tensor>(stack[0]);
-}
 
 // TODO: When cpu is implemented in torch::stable, eliminate
 // cpu function below.

@@ -224,7 +199,8 @@ inline Tensor new_zeros(
     std::optional<bool> pin_memory = std::nullopt) {
   int32_t target_dtype{};
   if (dtype.has_value()) {
-    target_dtype = to<int32_t>(from(dtype.value()));
+    target_dtype = torch::stable::detail::to<int32_t>(
+        torch::stable::detail::from(dtype.value()));
   } else {
     TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
   }

@@ -268,13 +244,6 @@ inline Tensor new_zeros(
   return result;
 }
 
-// TODO: https://github.com/pytorch/pytorch/pull/161896
-inline Tensor clone(const Tensor& self) {
-  AtenTensorHandle ret = nullptr;
-  TORCH_ERROR_CODE_CHECK(aoti_torch_clone(self.get(), &ret));
-  return Tensor(ret);
-}
-
 // An analog of item template function defined in
 // ATen/templates/TensorBody.h
 template <typename T>
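
The removed copy_ and clone helpers are superseded by their upstream torch::stable counterparts (the GPU kernel above now calls torch::stable::clone directly). The surviving two-line change in new_zeros converts the optional ScalarType into the int32_t dtype code that the aoti_torch_* C shim expects by round-tripping it through the stable boxing helpers. A minimal sketch of that conversion, assuming an example ScalarType value st:

// Hedged sketch of the dtype conversion used in new_zeros above: box the
// ScalarType into a StableIValue, then read it back as the int32_t code that
// the aoti_torch_* C shim functions take. `st` is an example value.
ScalarType st = ScalarType::Float;
int32_t dtype_code =
    torch::stable::detail::to<int32_t>(torch::stable::detail::from(st));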
