
Add rnnt loss (k2-fsa#891)
* Add cpp code of mutual information

* mutual information working

* Add rnnt loss

* Add pruned rnnt loss

* Minor Fixes

* Minor fixes & fix code style

* Fix cpp style

* Fix code style

* Fix s_begin values in padding positions

* Fix bugs related to boundary; Fix s_begin padding value; Add more tests

* Minor fixes

* Fix comments

* Add boundary to pruned loss tests
pkufool authored Jan 17, 2022
1 parent e799928 commit d6323d5
Showing 13 changed files with 3,258 additions and 1 deletion.
5 changes: 5 additions & 0 deletions .flake8
@@ -2,6 +2,11 @@
show-source=true
statistics=true
max-line-length=80
per-file-ignores =
# line too long E501
# line break before operator W503
k2/python/k2/rnnt_loss.py: E501, W503
k2/python/tests/rnnt_loss_test.py: W503
exclude =
.git,
setup.py,
2 changes: 2 additions & 0 deletions k2/python/csrc/torch.cu
@@ -31,6 +31,7 @@
#include "k2/python/csrc/torch/fsa_algo.h"
#include "k2/python/csrc/torch/index_add.h"
#include "k2/python/csrc/torch/index_select.h"
#include "k2/python/csrc/torch/mutual_information.h"
#include "k2/python/csrc/torch/nbest.h"
#include "k2/python/csrc/torch/ragged.h"
#include "k2/python/csrc/torch/ragged_ops.h"
@@ -44,6 +45,7 @@ void PybindTorch(py::module &m) {
PybindFsaAlgo(m);
PybindIndexAdd(m);
PybindIndexSelect(m);
PybindMutualInformation(m);
PybindNbest(m);
PybindRagged(m);
PybindRaggedOps(m);
6 changes: 6 additions & 0 deletions k2/python/csrc/torch/CMakeLists.txt
@@ -7,6 +7,8 @@ set(torch_srcs
fsa_algo.cu
index_add.cu
index_select.cu
mutual_information.cu
mutual_information_cpu.cu
nbest.cu
ragged.cu
ragged_ops.cu
@@ -19,6 +21,10 @@ set(torch_srcs
v2/ragged_shape.cu
)

if (K2_WITH_CUDA)
list(APPEND torch_srcs mutual_information_cuda.cu)
endif()

set(torch_srcs_with_prefix)
foreach(src IN LISTS torch_srcs)
list(APPEND torch_srcs_with_prefix "torch/${src}")
68 changes: 68 additions & 0 deletions k2/python/csrc/torch/mutual_information.cu
@@ -0,0 +1,68 @@
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "k2/csrc/device_guard.h"
#include "k2/python/csrc/torch/mutual_information.h"
#include "k2/python/csrc/torch/torch_util.h"

void PybindMutualInformation(py::module &m) {
m.def(
"mutual_information_forward",
[](torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary,
torch::Tensor p) -> torch::Tensor {
k2::DeviceGuard guard(k2::GetContext(px));
if (px.device().is_cpu()) {
return k2::MutualInformationCpu(px, py, boundary, p);
} else {
#ifdef K2_WITH_CUDA
return k2::MutualInformationCuda(px, py, boundary, p);
#else
K2_LOG(FATAL) << "Failed to find native CUDA module, make sure "
<< "that you compiled the code with K2_WITH_CUDA.";
return torch::Tensor();
#endif
}
},
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"));

m.def(
"mutual_information_backward",
[](torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary, torch::Tensor p,
torch::Tensor ans_grad) -> std::vector<torch::Tensor> {
k2::DeviceGuard guard(k2::GetContext(px));
if (px.device().is_cpu()) {
return k2::MutualInformationBackwardCpu(px, py, boundary, p,
ans_grad);
} else {
#ifdef K2_WITH_CUDA
return k2::MutualInformationBackwardCuda(px, py, boundary, p,
ans_grad, true);
#else
K2_LOG(FATAL) << "Failed to find native CUDA module, make sure "
<< "that you compiled the code with K2_WITH_CUDA.";
return std::vector<torch::Tensor>();
#endif
}
},
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
py::arg("ans_grad"));
}
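
For orientation, below is a minimal sketch of how these bindings might be driven from Python once the extension is built. The tensor shapes follow the documentation in mutual_information.h below; the `_k2` module name and the random test values are assumptions for illustration, not part of this commit.

```python
import torch
import _k2  # assumed name of the compiled k2 extension exposing these bindings

B, S, T = 2, 5, 7
px = torch.randn(B, S, T + 1)        # log-odds of emitting the next x symbol
py = torch.randn(B, S + 1, T)        # log-odds of emitting the next y symbol
p = torch.empty(B, S + 1, T + 1)     # filled in by the forward pass
boundary = torch.tensor([[0, 0, S, T]] * B,
                        dtype=torch.int64)  # [s_begin, t_begin, s_end, t_end]

# Forward: fills p and returns ans[b] = p[b][s_end][t_end].
ans = _k2.mutual_information_forward(px, py, boundary, p)

# Backward: given d(loss)/d(ans), returns (grad_px, grad_py).
ans_grad = torch.ones(B)
px_grad, py_grad = _k2.mutual_information_backward(px, py, boundary, p, ans_grad)
```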
107 changes: 107 additions & 0 deletions k2/python/csrc/torch/mutual_information.h
@@ -0,0 +1,107 @@
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef K2_PYTHON_CSRC_TORCH_MUTUAL_INFORMATION_H_
#define K2_PYTHON_CSRC_TORCH_MUTUAL_INFORMATION_H_

#include <torch/extension.h>

#include <vector>

#include "k2/python/csrc/torch.h"

namespace k2 {
/*
Forward of mutual_information. See also the comment for `mutual_information`
in mutual_information.py. This is the core recursion
in the sequence-to-sequence mutual information computation.
@param px Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
generating the next x in the sequence, i.e.
px[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of
lengths (s, t), divided by the prior probability of generating
x_s. (See mutual_information.py for more info).
@param py The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
@param p This function writes to p[b][s][t] the mutual information between
sub-sequences of x and y of length s and t respectively, from the
b'th sequences in the batch. Its shape is [B][S + 1][T + 1].
Concretely, this function implements the following recursion,
in the case where s_begin == t_begin == 0:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if s > 0 or t > 0,
treating values with any -1 index as -infinity.
If `boundary` is set, we instead start from p[b,s_begin,t_begin] = 0.0.
@param boundary If set, a tensor of shape [B][4] of type int64_t, which
contains, for each batch element b,
boundary[b] = [s_begin, t_begin, s_end, t_end],
i.e. the beginning and end (one-past-the-last)
of the x and y sequences that we should process.
Alternatively, may be a tensor of shape [0][0] and type
int64_t; the elements will default to (0, 0, S, T).
@return A tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was not
specified, and (boundary[b][2], boundary[b][3]) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
supplied directly to this function).
For the CUDA version, the block-dim and grid-dim must both be 1-dimensional,
and the block-dim must be at least 128.
*/
torch::Tensor MutualInformationCpu(
torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output

torch::Tensor MutualInformationCuda(
torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output

/*
Backward of mutual_information; returns (grad_px, grad_py).
If overwrite_ans_grad == true, this function will overwrite ans_grad with a
value that, if the computation worked correctly, should be identical to or
very close to the value of ans_grad at entry. This can be used
to validate the correctness of this code.
*/
std::vector<torch::Tensor> MutualInformationBackwardCpu(
torch::Tensor px, torch::Tensor py, torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad);

std::vector<torch::Tensor> MutualInformationBackwardCuda(
torch::Tensor px, torch::Tensor py, torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad, bool overwrite_ans_grad);

} // namespace k2

void PybindMutualInformation(py::module &m);

#endif // K2_PYTHON_CSRC_TORCH_MUTUAL_INFORMATION_H_
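
To make the recursion documented in the header above concrete, here is a small, unoptimized PyTorch reference for the no-boundary case (s_begin = t_begin = 0, s_end = S, t_end = T). It is an illustration of the math only, not the code added by this commit; the CPU/CUDA kernels compute the same quantity efficiently.

```python
import torch

def mutual_information_ref(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    # px: [B][S][T+1], py: [B][S+1][T].  Returns ans of shape [B] with
    # ans[b] = p[b][S][T], i.e. boundary defaulting to (0, 0, S, T).
    B, S, _ = px.shape
    T = py.shape[2]
    neg_inf = torch.tensor(float("-inf"), dtype=px.dtype)
    p = torch.full((B, S + 1, T + 1), float("-inf"), dtype=px.dtype)
    for b in range(B):
        p[b, 0, 0] = 0.0
        for s in range(S + 1):
            for t in range(T + 1):
                if s == 0 and t == 0:
                    continue
                # p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                #                    p[b,s,t-1] + py[b,s,t-1])
                up = p[b, s - 1, t] + px[b, s - 1, t] if s > 0 else neg_inf
                left = p[b, s, t - 1] + py[b, s, t - 1] if t > 0 else neg_inf
                p[b, s, t] = torch.logaddexp(up, left)
    return p[:, S, T]

# Example: px = torch.randn(2, 3, 5), py = torch.randn(2, 4, 4)
# gives ans = mutual_information_ref(px, py) of shape [2].
```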