Skip to content
This repository has been archived by the owner on Aug 23, 2023. It is now read-only.

Commit

Permalink
Implement Rnnt decoding (k2-fsa#926)
Browse files Browse the repository at this point in the history
* first working draft of rnnt decoding

* FormatOutput works...

* Different num frames for FormatOutput works

* Update docs

* Fix comments, break advance into several stages, add more docs

* Add python wrapper

* Add more docs

* Minor fixes

* Fix comments
  • Loading branch information
pkufool authored Mar 16, 2022
1 parent 36e2b8d commit f4b4247
Show file tree
Hide file tree
Showing 28 changed files with 2,550 additions and 160 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ per-file-ignores =
# line break before operator W503
k2/python/k2/rnnt_loss.py: E501, W503
k2/python/tests/rnnt_loss_test.py: W503
k2/python/tests/rnnt_decode_test.py: W503
exclude =
.git,
setup.py,
Expand Down
4 changes: 4 additions & 0 deletions k2/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ add_subdirectory(host)
# please keep it sorted
set(context_srcs
algorithms.cu
array_of_ragged.cu
array_ops.cu
connect.cu
context.cu
Expand All @@ -65,6 +66,7 @@ set(context_srcs
ragged_utils.cu
rand.cu
rm_epsilon.cu
rnnt_decode.cu
tensor.cu
tensor_ops.cu
thread_pool.cu
Expand Down Expand Up @@ -142,6 +144,7 @@ target_link_libraries(test_utils PUBLIC context gtest)
# please sort the source files alphabetically
set(cuda_test_srcs
algorithms_test.cu
array_of_ragged_test.cu
array_ops_test.cu
array_test.cu
connect_test.cu
Expand All @@ -163,6 +166,7 @@ set(cuda_test_srcs
ragged_utils_test.cu
rand_test.cu
rm_epsilon_test.cu
rnnt_decode_test.cu
tensor_ops_test.cu
tensor_test.cu
thread_pool_test.cu
Expand Down
9 changes: 3 additions & 6 deletions k2/csrc/algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,9 @@ class Renumbering {
return new2old_;
}

/* Return a mapping from new index to old index, with one extra element
containing the total number of kept elements if extra_element == true.
If Keep() can be interpreted as a tails vector, i.e. with 1 at the end
of sub-lists of elements, then New2Old(true) would corresponds to a
row-splits array and Old2New(false) would correspond to a row-ids
array.
/*
Return a mapping from new index to old index, with one extra element
containing the total number of kept elements if extra_element == true.
*/
Array1<int32_t> New2Old(bool extra_element) {
Array1<int32_t> &new2old_part = New2Old();
Expand Down
54 changes: 54 additions & 0 deletions k2/csrc/array_of_ragged.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "k2/csrc/array_of_ragged.h"

namespace k2 {

// Gathers, for every source shape and every axis >= 1, the raw row-splits
// and row-ids data pointers into two pointer tables, and accumulates the
// per-axis total sizes over all sources.  The pointer tables are filled via
// CPU arrays and then transferred with To(c_); the total sizes stay in the
// CPU array they were accumulated in.
Array1OfRaggedShape::Array1OfRaggedShape(RaggedShape *src, int32_t num_srcs)
    : num_srcs_(num_srcs) {
  K2_CHECK_GE(num_srcs, 1);
  K2_CHECK(src);

  // The first source defines the context and the number of axes that every
  // other source is checked against below.
  c_ = src[0].Context();
  num_axes_ = src[0].NumAxes();

  const int32_t num_table_rows = num_axes_ - 1;
  tot_sizes_ = Array1<int32_t>(GetCpuContext(), num_axes_, 0);
  row_splits_ =
      Array2<const int32_t *>(GetCpuContext(), num_table_rows, num_srcs_);
  row_ids_ =
      Array2<const int32_t *>(GetCpuContext(), num_table_rows, num_srcs_);

  int32_t *totals = tot_sizes_.Data();
  auto splits_table = row_splits_.Accessor();
  auto ids_table = row_ids_.Accessor();

  for (int32_t s = 0; s < num_srcs_; ++s) {
    RaggedShape &shape = src[s];
    K2_CHECK_EQ(shape.NumAxes(), num_axes_);
    K2_CHECK(c_->IsCompatible(*(shape.Context())));

    totals[0] += shape.TotSize(0);
    for (int32_t axis = 1; axis < num_axes_; ++axis) {
      // Table rows are indexed by axis - 1, because there are no
      // row-splits / row-ids for axis 0.
      const int32_t row = axis - 1;
      splits_table(row, s) = shape.RowSplits(axis).Data();
      ids_table(row, s) = shape.RowIds(axis).Data();
      totals[axis] += shape.TotSize(axis);
    }
  }

  // Transfer the pointer tables to the context of the sources.
  row_splits_ = row_splits_.To(c_);
  row_ids_ = row_ids_.To(c_);
}

} // namespace k2
200 changes: 200 additions & 0 deletions k2/csrc/array_of_ragged.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/**
* Copyright 2022 Xiaomi Corporation (authors: Daniel Povey, Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef K2_CSRC_ARRAY_OF_RAGGED_H_
#define K2_CSRC_ARRAY_OF_RAGGED_H_

#include <string>
#include <utility>
#include <vector>

#include "k2/csrc/array.h"
#include "k2/csrc/context.h"
#include "k2/csrc/log.h"
#include "k2/csrc/ragged_ops.h"

namespace k2 {
/*
Array1OfRaggedShape is a convenience class that gives you easy access
to pointers-of-pointers for an array of ragged shapes.
*/
class Array1OfRaggedShape {
 public:
  /*
    Constructor.
    Args:
      srcs: pointers to the source shapes, a CPU pointer
      num_srcs: the number of source shapes.  All shapes must have the
                same NumAxes() and must be on the same device.

    TODO: we'll likely, later, add optional args which dictate which of
    the MetaRowSplits() and MetaRowIds() are to be pre-populated; this should
    enable us to save kernels by combining certain operations across the
    axes.
  */
  Array1OfRaggedShape(RaggedShape *srcs, int32_t num_srcs);

  // Creates an empty object; NumSrcs() and NumAxes() return 0 until a
  // populated object is assigned to it.
  Array1OfRaggedShape() = default;

  int32_t NumSrcs() const { return num_srcs_; }
  int32_t NumAxes() const { return num_axes_; }

  ContextPtr &Context() { return c_; }

  // Returns device-accessible array of row-splits for the individual shapes,
  // indexed [axis-1][src], with 0 <= src < num_srcs.  The shape of this
  // Array2 is [NumAxes() - 1][NumSrcs()].
  const Array2<const int32_t *> *RowSplits() const { return &row_splits_; }

  // Returns device-accessible vector of row-splits for a particular
  // axis, indexed by 0 <= src < num_srcs.
  const int32_t **RowSplits(int32_t axis) {
    return row_splits_.Row(axis - 1).Data();
  }

  // Returns device-accessible array of row-ids for the individual shapes
  // indexed [axis-1][src], with 0 <= src < num_srcs.  The shape of this
  // Array2 is [NumAxes() - 1][NumSrcs()].
  const Array2<const int32_t *> *RowIds() const { return &row_ids_; }

  // Returns device-accessible vector of row-ids for a particular
  // axis, indexed by 0 <= src < num_srcs.
  const int32_t **RowIds(int32_t axis) { return row_ids_.Row(axis - 1).Data(); }

  /* Return the total size on this axis, which is the sum of the TotSize() of
     the individual shapes.  Requires 0 <= axis < NumAxes() and
     for axis=0 the returned value is the same as Dim0().
  */
  int32_t TotSize(int32_t axis) const { return tot_sizes_[axis]; }

  // equivalent to TotSize(0).
  int32_t Dim0() const { return TotSize(0); }

  /* Return the device-accessible meta-row-splits, which is the cumulative sum,
     along the src axis, of the tot-sizes of the individual arrays.
     This Array2 is of shape [NumAxes()][NumSrcs() + 1], indexed [axis][src];
     caution, the indexing is different from RowSplits(), there is no offset.
     Also, the meta_row_splits0 is a thing, unlike with regular row-splits
     which start from 1.

     Caution: the lengths of the arrays pointed to by the elements of this
     Array2 (which contains pointers!) are of course all different, and
     these lengths are currently only available
     [NOTE(review): the sentence above was left unfinished in the original
     comment; complete it once this function is implemented.]

     Implementation note: we can probably just populate this on CPU and transfer
     to GPU, this will be faster than invoking an extra kernel in normal cases
     when the NumSrcs() is small. [Also: see GetRowInfoMulti()].
  */
  // TODO: implement it...
  Array2<int32_t> MetaRowSplits();

  // could POSSIBLY add this so this code could be used in functions like
  // Stack(). would be like MetaRowSplits but with an extra 1st row containing
  // 0,1,2,... We could perhaps create it with 1 extra initial row so this is
  // always convenient to output.
  // TODO: implement it...
  Array2<int32_t> Offsets();

  /*
    Returns the meta-row-splits for a particular axis, with 0 <= axis <
    NumAxes(); this is the cumulative sum of the TotSize(axis) for all of the
    sources, with MetaRowSplits(axis).Dim() == NumSrcs() + 1.

    Note: in ragged_ops.cu we refer to this as composed_row_splits
  */
  // TODO: implement it...
  Array1<int32_t> MetaRowSplits(int32_t axis);

  /* Return the device-accessible meta-row-ids, which are the row-ids
     corresponding to MetaRowSplits(); this tells us, for indexes into the
     appended/concatenated array, which source array they belong to, i.e.
     elements are in [0,NumSrcs()-1].

     This cannot be an Array2 because unlike the MetaRowSplits(), all the
     row-ids arrays are of different lengths.

     Note: in ragged_ops.cu we refer to this as composed_row_ids.
  */
  // TODO: implement it...
  Array1<int32_t *> MetaRowIds();

  /*
    Returns the meta-row-ids for a particular axis, with 0 <= axis < NumAxes();
    this is the row-ids corresponding to MetaRowSplits(axis), and its elements
    give, for indexes into the concatenated shape (concatenated on axis 0),
    which source they come from.  E.g. element 100 of MetaRowIds(2)
    would tell us which source an idx012 with value 100 into axis 2 of
    concatenated array would come from.
  */
  // TODO: implement it...
  Array1<int32_t> MetaRowIds(int32_t axis);

 private:
  ContextPtr c_;
  // Both counters are zero-initialized so that a default-constructed object
  // never exposes indeterminate values through NumSrcs() / NumAxes().
  int32_t num_srcs_ = 0;
  int32_t num_axes_ = 0;
  Array2<const int32_t *> row_splits_;  // shape [num_axes_ - 1][num_srcs_]
  Array2<const int32_t *> row_ids_;     // shape [num_axes_ - 1][num_srcs_]
  Array1<int32_t> tot_sizes_;           // dim num_axes_, this is on CPU
};

/*
Array1OfRagged<T> is a 1-dimensional array of Ragged<T>.
It is intended for situations where you want to do some operations on
arrays of ragged arrays, without explicitly concatenating them (e.g. to
save time). This is a fairly low-level interface, intended to
be used mostly by CUDA/C++ implementation code. It is a convenience
wrapper that saves you the trouble of creating arrays of pointers.
*/
template <typename T>
struct Array1OfRagged {
  // Row-splits / row-ids pointer tables and total sizes of the sources.
  Array1OfRaggedShape shape;

  // Values pointers of the individual source arrays, one entry per source
  // (same source order as in `shape`).
  Array1<T *> values;

  Array1OfRagged() = default;

  /*
    Constructor.
    Args:
      srcs: pointers to the source ragged tensors, a CPU pointer
      num_srcs: the number of source ragged tensors.  All ragged tensors must
                have the same NumAxes() and must be on the same device.
  */
  Array1OfRagged(Ragged<T> *srcs, int32_t num_srcs) {
    K2_CHECK_GE(num_srcs, 1);
    K2_CHECK(srcs);

    // Collect the values pointers in a CPU array first; it is transferred to
    // the context of the sources once `shape` has been built.
    values = Array1<T *>(GetCpuContext(), num_srcs);
    T **values_ptrs = values.Data();

    std::vector<RaggedShape> src_shapes;
    src_shapes.reserve(num_srcs);
    for (int32_t s = 0; s < num_srcs; ++s) {
      Ragged<T> &ragged = srcs[s];
      src_shapes.push_back(ragged.shape);
      values_ptrs[s] = ragged.values.Data();
    }

    shape = Array1OfRaggedShape(src_shapes.data(), num_srcs);
    values = values.To(shape.Context());
  }

  int32_t NumSrcs() const { return values.Dim(); }
  ContextPtr &Context() { return shape.Context(); }
};

} // namespace k2

#endif // K2_CSRC_ARRAY_OF_RAGGED_H_
78 changes: 78 additions & 0 deletions k2/csrc/array_of_ragged_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "gtest/gtest.h"
#include "k2/csrc/array_of_ragged.h"
#include "k2/csrc/ragged.h"
#include "k2/csrc/ragged_ops.h"
#include "k2/csrc/ragged_utils.h"
#include "k2/csrc/test_utils.h"

namespace k2 {

template <typename T>
void TestArray1OfRaggedConstruct() {
int32_t num_srcs = 5;
int32_t num_axes = 4;

for (auto &c : {GetCpuContext(), GetCudaContext()}) {
std::vector<Ragged<T>> raggeds;
for (int32_t i = 0; i < num_srcs; ++i) {
raggeds.emplace_back(
RandomRagged<T>(0 /*min_value*/, 100 /*max_value*/,
num_axes /*min_num_axes*/, num_axes /*max_num_axes*/,
0 /*min_num_elements*/, 100 /*max_num_elements*/)
.To(c, true /*copy_all*/));
}
auto array_of_ragged = Array1OfRagged<T>(raggeds.data(), num_srcs);
for (int32_t j = 1; j < num_axes; ++j) {
const int32_t **row_splits = array_of_ragged.shape.RowSplits(j);
const int32_t **row_ids = array_of_ragged.shape.RowIds(j);
Array1<int32_t *> excepted_row_splits(GetCpuContext(), num_srcs);
Array1<int32_t *> excepted_row_ids(GetCpuContext(), num_srcs);
int32_t **excepted_row_splits_data = excepted_row_splits.Data();
int32_t **excepted_row_ids_data = excepted_row_ids.Data();
for (int32_t i = 0; i < num_srcs; ++i) {
excepted_row_splits_data[i] = raggeds[i].RowSplits(j).Data();
excepted_row_ids_data[i] = raggeds[i].RowIds(j).Data();
}
excepted_row_splits = excepted_row_splits.To(c);
excepted_row_ids = excepted_row_ids.To(c);
excepted_row_splits_data = excepted_row_splits.Data();
excepted_row_ids_data = excepted_row_ids.Data();
Array1<int32_t> flags(c, 2, 1);
int32_t *flags_data = flags.Data();
K2_EVAL(
c, num_srcs, lambda_check_pointer, (int32_t i) {
if (row_splits[i] != excepted_row_splits_data[i]) flags_data[0] = 0;
if (row_ids[i] != excepted_row_ids_data[i]) flags_data[1] = 0;
});
K2_CHECK(Equal(flags, Array1<int32_t>(c, std::vector<int32_t>{1, 1})));
}
for (int32_t i = 0; i < num_srcs; ++i) {
K2_CHECK_EQ(array_of_ragged.values[i], raggeds[i].values.Data());
}
}
}

// Runs the construction check for an integer and a floating-point element
// type.
TEST(Array1OfRagged, Construct) {
TestArray1OfRaggedConstruct<int32_t>();
TestArray1OfRaggedConstruct<float>();
}

} // namespace k2
Loading

0 comments on commit f4b4247

Please sign in to comment.