This repository has been archived by the owner on Aug 23, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement Rnnt decoding (k2-fsa#926)
* first working draft of rnnt decoding * FormatOutput works... * Different num frames for FormatOutput works * Update docs * Fix comments, break advance into several stages, add more docs * Add python wrapper * Add more docs * Minor fixes * Fix comments
- Loading branch information
Showing
28 changed files
with
2,550 additions
and
160 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/** | ||
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang) | ||
* | ||
* See LICENSE for clarification regarding multiple authors | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "k2/csrc/array_of_ragged.h" | ||
|
||
namespace k2 { | ||
|
||
Array1OfRaggedShape::Array1OfRaggedShape(RaggedShape *src, int32_t num_srcs) | ||
: num_srcs_(num_srcs) { | ||
K2_CHECK_GE(num_srcs, 1); | ||
K2_CHECK(src); | ||
num_axes_ = src[0].NumAxes(); | ||
c_ = src[0].Context(); | ||
|
||
row_splits_ = | ||
Array2<const int32_t *>(GetCpuContext(), num_axes_ - 1, num_srcs_); | ||
row_ids_ = Array2<const int32_t *>(GetCpuContext(), num_axes_ - 1, num_srcs_); | ||
tot_sizes_ = Array1<int32_t>(GetCpuContext(), num_axes_, 0); | ||
|
||
auto row_splits_acc = row_splits_.Accessor(), | ||
row_ids_acc = row_ids_.Accessor(); | ||
int32_t *tot_sizes_data = tot_sizes_.Data(); | ||
|
||
for (int32_t i = 0; i < num_srcs_; ++i) { | ||
K2_CHECK_EQ(src[i].NumAxes(), num_axes_); | ||
K2_CHECK(c_->IsCompatible(*(src[i].Context()))); | ||
for (int32_t j = 1; j < num_axes_; ++j) { | ||
row_splits_acc(j - 1, i) = src[i].RowSplits(j).Data(); | ||
row_ids_acc(j - 1, i) = src[i].RowIds(j).Data(); | ||
tot_sizes_data[j] += src[i].TotSize(j); | ||
} | ||
tot_sizes_data[0] += src[i].TotSize(0); | ||
} | ||
|
||
row_splits_ = row_splits_.To(c_); | ||
row_ids_ = row_ids_.To(c_); | ||
} | ||
|
||
} // namespace k2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
/** | ||
* Copyright 2022 Xiaomi Corporation (authors: Daniel Povey, Wei Kang) | ||
* | ||
* See LICENSE for clarification regarding multiple authors | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef K2_CSRC_ARRAY_OF_RAGGED_H_ | ||
#define K2_CSRC_ARRAY_OF_RAGGED_H_ | ||
|
||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "k2/csrc/array.h" | ||
#include "k2/csrc/context.h" | ||
#include "k2/csrc/log.h" | ||
#include "k2/csrc/ragged_ops.h" | ||
|
||
namespace k2 { | ||
/* | ||
Array1OfRaggedShape is a convenience function that gives you easy access | ||
to pointers-of-pointers for an array of ragged shapes. | ||
*/ | ||
class Array1OfRaggedShape { | ||
public: | ||
/* | ||
Constructor. | ||
Args: | ||
srcs: pointers to the source shapes, a CPU pointer | ||
num_srcs: the number of source shapes. All shapes must have the | ||
same NumAxes() and must be on the same device. | ||
TODO: we'll likely, later, add optional args which dictate which of | ||
the MetaRowSplits() and MetaRowIds() are to be pre-populated; this should | ||
enable us to save kernels by combining certain operations across the | ||
axes. | ||
*/ | ||
Array1OfRaggedShape(RaggedShape *srcs, int32_t num_srcs); | ||
Array1OfRaggedShape() = default; | ||
|
||
int32_t NumSrcs() const { return num_srcs_; } | ||
int32_t NumAxes() const { return num_axes_; } | ||
|
||
ContextPtr &Context() { return c_; } | ||
|
||
// Returns device-accessible array of row-splits for the individual shapes, | ||
// indexed [axis-1][src], with 0 <= src < num_srcs. The shape of this | ||
// Array2 is [NumAxes() - 1][NumSrcs()]. | ||
const Array2<const int32_t *> *RowSplits() const { return &row_splits_; } | ||
|
||
// Returns device-accessible vector of row-splits for a particular | ||
// axis, indexed by 0 <= src < num_srcs. | ||
const int32_t **RowSplits(int32_t axis) { | ||
return row_splits_.Row(axis - 1).Data(); | ||
} | ||
|
||
// Returns device-accessible array of row-ids for the individual shapes | ||
// indexed [axis-1][src], with 0 <= src < num_srcs. The shape of this | ||
// Array2 is [NumAxes() - 1][NumSrcs()]. | ||
const Array2<const int32_t *> *RowIds() const { return &row_ids_; } | ||
|
||
// Returns device-accessible vector of row-splits for a particular | ||
// axis, indexed by 0 <= src < num_srcs. | ||
const int32_t **RowIds(int32_t axis) { return row_ids_.Row(axis - 1).Data(); } | ||
|
||
/* Return the total size on this axis, which is the sum of the TotSize() of | ||
the individual shapes. Requires 0 <= axis < NumAxes() and | ||
for axis=0 the returned value is the same as Dim0(). | ||
*/ | ||
int32_t TotSize(int32_t axis) const { return tot_sizes_[axis]; } | ||
|
||
// equivalent to TotSize(0). | ||
int32_t Dim0() const { return TotSize(0); } | ||
|
||
/* Return the device-accessible meta-row-splits, which is the cumulative sum, | ||
along the src axis, of the tot-sizes of the individual arrays. | ||
This Array2 is of shape [NumAxes()][NumSrcs() + 1], indexed [axis][src]; | ||
caution, the indexing is different from RowSplits(), there is no offset. | ||
Also, the meta_row_splits0 is a thing, unlike with regular row-splits | ||
which start from 1. | ||
Caution: the lengths of the arrays pointed to by the elements of this | ||
Array2 (which contains pointers!) are of course all different, and | ||
these lengths are currently only available | ||
Implementation note: we can probably just populate this on CPU and transfer | ||
to GPU, this will be faster than invoking an extra kernel in normal cases | ||
when the NumSrcs() is small. [Also: see GetRowInfoMulti()]. | ||
*/ | ||
// TODO: implement it... | ||
Array2<int32_t> MetaRowSplits(); | ||
|
||
// could POSSIBLY add this so this code could be used in functions like | ||
// Stack(). would be like MetaRowSplits but with an extra 1st row containing | ||
// 0,1,2,... We could perhaps create it with 1 extra initial row so this is | ||
// always convenient to output. | ||
// TODO: implement it... | ||
Array2<int32_t> Offsets(); | ||
|
||
/* | ||
Returns the meta-row-splits for a particular axis, with 0 <= axis < | ||
NumAxes(); this is the cumulative sum of the TotSize(axis) for all of the | ||
sources, with MetaRowSplits(axis).Dim() == NumSrcs() + 1. | ||
Note: in ragged_ops.cu we refer to this as composed_row_splits | ||
*/ | ||
// TODO: implement it... | ||
Array1<int32_t> MetaRowSplits(int32_t axis); | ||
|
||
/* Return the device-accessible meta-row-ids, which are the row-ids | ||
corresponding to MetaRowSplits(); this tells us, for indexes into the | ||
appended/concatenated array, which source array they belong to, i.e. | ||
elements are in [0,NumSrcs()-1]. | ||
This cannot be an Array2 because unlike the MetaRowSplits(), all the | ||
row-ids arrays are of different lengths. | ||
Note: in ragged_ops.cu we refer to this as composed_row_ids. | ||
*/ | ||
// TODO: implement it... | ||
Array1<int32_t *> MetaRowIds(); | ||
|
||
/* | ||
Returns the meta-row-ids for a particular axis, with 0 <= axis < NumAxes(); | ||
this is the row-ids corresponding to MetaRowSplits(axis), and its elements | ||
gives, for indexes into the concatentated shape (concatenated on axis 0),m | ||
which source they come from. E.g. element 100 of MetaRowIds(2) | ||
would tell us which source an idx012 with value 100 into axis 2 of | ||
concatenated array would come from. | ||
*/ | ||
// TODO: implement it... | ||
Array1<int32_t> MetaRowIds(int32_t axis); | ||
|
||
private: | ||
ContextPtr c_; | ||
int32_t num_srcs_; | ||
int32_t num_axes_; | ||
Array2<const int32_t *> row_splits_; // shape [num_axes_ - 1][num_srcs_] | ||
Array2<const int32_t *> row_ids_; // shape [num_axes_ - 1][num_srcs_] | ||
Array1<int32_t> tot_sizes_; // dim num_axes_, this is on CPU | ||
}; | ||
|
||
/* | ||
Array1OfRagged<T> is a 1-dimensional array of Ragged<T>. | ||
It is intended for situations where you want to do some operations on | ||
arrays of ragged arrays, without explicitly concatenating them (e.g. to | ||
save time). This is a fairly low-level interface, intended to | ||
be used mostly by CUDA/C++ implementation code. It is a convenience | ||
wrapper that saves you the trouble of creating arrays of pointers. | ||
*/ | ||
template <typename T> | ||
struct Array1OfRagged { | ||
Array1OfRaggedShape shape; | ||
|
||
// Array of the individual values pointers of the source arrays, indexed by | ||
// shape | ||
Array1<T *> values; | ||
|
||
int32_t NumSrcs() const { return values.Dim(); } | ||
ContextPtr &Context() { return shape.Context(); } | ||
|
||
Array1OfRagged() = default; | ||
|
||
/* | ||
Constructor. | ||
Args: | ||
srcs: pointers to the source ragged tensors, a CPU pointer | ||
num_srcs: the number of source ragged tensors. All ragged tensors must | ||
have the same NumAxes() and must be on the same device. | ||
*/ | ||
Array1OfRagged(Ragged<T> *srcs, int32_t num_srcs) { | ||
K2_CHECK_GE(num_srcs, 1); | ||
K2_CHECK(srcs); | ||
values = Array1<T *>(GetCpuContext(), num_srcs); | ||
T **values_data = values.Data(); | ||
std::vector<RaggedShape> shapes(num_srcs); | ||
for (int32_t i = 0; i < num_srcs; ++i) { | ||
shapes[i] = srcs[i].shape; | ||
values_data[i] = srcs[i].values.Data(); | ||
} | ||
shape = Array1OfRaggedShape(shapes.data(), num_srcs); | ||
values = values.To(shape.Context()); | ||
} | ||
}; | ||
|
||
} // namespace k2 | ||
|
||
#endif // K2_CSRC_ARRAY_OF_RAGGED_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
/** | ||
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang) | ||
* | ||
* See LICENSE for clarification regarding multiple authors | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "gtest/gtest.h" | ||
#include "k2/csrc/array_of_ragged.h" | ||
#include "k2/csrc/ragged.h" | ||
#include "k2/csrc/ragged_ops.h" | ||
#include "k2/csrc/ragged_utils.h" | ||
#include "k2/csrc/test_utils.h" | ||
|
||
namespace k2 { | ||
|
||
template <typename T> | ||
void TestArray1OfRaggedConstruct() { | ||
int32_t num_srcs = 5; | ||
int32_t num_axes = 4; | ||
|
||
for (auto &c : {GetCpuContext(), GetCudaContext()}) { | ||
std::vector<Ragged<T>> raggeds; | ||
for (int32_t i = 0; i < num_srcs; ++i) { | ||
raggeds.emplace_back( | ||
RandomRagged<T>(0 /*min_value*/, 100 /*max_value*/, | ||
num_axes /*min_num_axes*/, num_axes /*max_num_axes*/, | ||
0 /*min_num_elements*/, 100 /*max_num_elements*/) | ||
.To(c, true /*copy_all*/)); | ||
} | ||
auto array_of_ragged = Array1OfRagged<T>(raggeds.data(), num_srcs); | ||
for (int32_t j = 1; j < num_axes; ++j) { | ||
const int32_t **row_splits = array_of_ragged.shape.RowSplits(j); | ||
const int32_t **row_ids = array_of_ragged.shape.RowIds(j); | ||
Array1<int32_t *> excepted_row_splits(GetCpuContext(), num_srcs); | ||
Array1<int32_t *> excepted_row_ids(GetCpuContext(), num_srcs); | ||
int32_t **excepted_row_splits_data = excepted_row_splits.Data(); | ||
int32_t **excepted_row_ids_data = excepted_row_ids.Data(); | ||
for (int32_t i = 0; i < num_srcs; ++i) { | ||
excepted_row_splits_data[i] = raggeds[i].RowSplits(j).Data(); | ||
excepted_row_ids_data[i] = raggeds[i].RowIds(j).Data(); | ||
} | ||
excepted_row_splits = excepted_row_splits.To(c); | ||
excepted_row_ids = excepted_row_ids.To(c); | ||
excepted_row_splits_data = excepted_row_splits.Data(); | ||
excepted_row_ids_data = excepted_row_ids.Data(); | ||
Array1<int32_t> flags(c, 2, 1); | ||
int32_t *flags_data = flags.Data(); | ||
K2_EVAL( | ||
c, num_srcs, lambda_check_pointer, (int32_t i) { | ||
if (row_splits[i] != excepted_row_splits_data[i]) flags_data[0] = 0; | ||
if (row_ids[i] != excepted_row_ids_data[i]) flags_data[1] = 0; | ||
}); | ||
K2_CHECK(Equal(flags, Array1<int32_t>(c, std::vector<int32_t>{1, 1}))); | ||
} | ||
for (int32_t i = 0; i < num_srcs; ++i) { | ||
K2_CHECK_EQ(array_of_ragged.values[i], raggeds[i].values.Data()); | ||
} | ||
} | ||
} | ||
|
||
TEST(Array1OfRagged, Construct) { | ||
TestArray1OfRaggedConstruct<int32_t>(); | ||
TestArray1OfRaggedConstruct<float>(); | ||
} | ||
|
||
} // namespace k2 |
Oops, something went wrong.