From 3c21c9d019fe4fdb62c286606f323f3549ea9b4d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 23 Sep 2020 13:44:27 +0800 Subject: [PATCH] Wrap Array1 as torch::Tensor. (#173) * Wrap Array1 as torch::Tensor. Fix k2host test cases. * interpret arc.weight from a float to an int. * update the comment for torch.h/torch.cu * fix linker errors for release build. --- .flake8 | 8 + .github/workflows/style_check.yml | 7 +- CMakeLists.txt | 2 +- cmake/pybind11.cmake | 6 + k2/csrc/CMakeLists.txt | 2 +- k2/csrc/default_context.cu | 4 +- k2/csrc/host/CMakeLists.txt | 2 +- k2/csrc/host/fsa.cc | 3 +- k2/csrc/pytorch_context.cu | 23 ++ k2/csrc/pytorch_context.h | 39 ++- k2/python/CMakeLists.txt | 1 + k2/python/csrc/CMakeLists.txt | 33 +- k2/python/csrc/aux_labels.h | 14 - k2/python/csrc/fsa_algo.h | 14 - k2/python/csrc/fsa_equivalent.h | 14 - k2/python/csrc/fsa_util.h | 14 - k2/python/csrc/k2.cc | 30 -- k2/python/csrc/k2.cu | 18 + k2/python/csrc/k2.h | 15 +- k2/python/csrc/properties.h | 14 - k2/python/csrc/torch.cu | 23 ++ k2/python/csrc/torch.h | 18 + k2/python/csrc/torch/CMakeLists.txt | 12 + k2/python/csrc/torch/array.cu | 77 ++++ k2/python/csrc/torch/array.h | 18 + k2/python/csrc/torch/torch_util.cu | 27 ++ k2/python/csrc/torch/torch_util.h | 67 ++++ k2/python/csrc/weights.h | 14 - k2/python/host/CMakeLists.txt | 2 + k2/python/host/csrc/CMakeLists.txt | 16 + k2/python/{ => host}/csrc/CPPLINT.cfg | 0 k2/python/{ => host}/csrc/README.md | 0 k2/python/{ => host}/csrc/array.cc | 5 +- k2/python/{ => host}/csrc/array.h | 10 +- k2/python/{ => host}/csrc/aux_labels.cc | 4 +- k2/python/host/csrc/aux_labels.h | 14 + k2/python/{ => host}/csrc/dlpack.h | 0 k2/python/{ => host}/csrc/fsa.cc | 7 +- k2/python/{ => host}/csrc/fsa.h | 10 +- k2/python/{ => host}/csrc/fsa_algo.cc | 6 +- k2/python/host/csrc/fsa_algo.h | 14 + k2/python/{ => host}/csrc/fsa_equivalent.cc | 4 +- k2/python/host/csrc/fsa_equivalent.h | 14 + k2/python/{ => host}/csrc/fsa_util.cc | 4 +- k2/python/host/csrc/fsa_util.h | 14 + k2/python/host/csrc/k2.cc | 30 ++ k2/python/host/csrc/k2.h | 16 + k2/python/{ => host}/csrc/properties.cc | 4 +- k2/python/host/csrc/properties.h | 14 + k2/python/{ => host}/csrc/tensor.cc | 4 +- k2/python/{ => host}/csrc/tensor.h | 12 +- k2/python/{ => host}/csrc/weights.cc | 4 +- k2/python/host/csrc/weights.h | 14 + k2/python/host/k2host/__init__.py | 10 + k2/python/host/k2host/array.py | 107 ++++++ k2/python/{k2 => host/k2host}/aux_labels.py | 8 +- k2/python/{k2 => host/k2host}/fsa.py | 25 +- k2/python/{k2 => host/k2host}/fsa_algo.py | 48 +-- .../{k2 => host/k2host}/fsa_equivalent.py | 32 +- k2/python/{k2 => host/k2host}/fsa_util.py | 15 +- k2/python/{k2 => host/k2host}/properties.py | 18 +- k2/python/{k2 => host/k2host}/weights.py | 10 +- k2/python/host/tests/CMakeLists.txt | 36 ++ k2/python/host/tests/arcsort_test.py | 110 ++++++ k2/python/host/tests/array_test.py | 107 ++++++ k2/python/{ => host}/tests/aux_labels_test.py | 152 ++++---- k2/python/host/tests/connect_test.py | 132 +++++++ k2/python/host/tests/determinize_test.py | 105 ++++++ .../{ => host}/tests/fsa_equivalent_test.py | 144 ++++---- k2/python/{ => host}/tests/fsa_test.py | 38 +- k2/python/host/tests/intersect_test.py | 91 +++++ k2/python/host/tests/properties_test.py | 330 ++++++++++++++++++ k2/python/host/tests/rmepsilon_test.py | 110 ++++++ k2/python/host/tests/topsort_test.py | 110 ++++++ k2/python/{ => host}/tests/weights_test.py | 31 +- k2/python/k2/__init__.py | 13 +- k2/python/k2/array.py | 158 +++------ k2/python/tests/CMakeLists.txt | 13 +- k2/python/tests/arcsort_test.py | 101 ------ k2/python/tests/array_test.py | 186 +++++----- k2/python/tests/connect_test.py | 131 ------- k2/python/tests/determinize_test.py | 106 ------ k2/python/tests/intersect_test.py | 90 ----- k2/python/tests/properties_test.py | 330 ------------------ k2/python/tests/rmepsilon_test.py | 110 ------ k2/python/tests/topsort_test.py | 109 ------ 86 files changed, 2226 insertions(+), 1631 deletions(-) delete mode 100644 k2/python/csrc/aux_labels.h delete mode 100644 k2/python/csrc/fsa_algo.h delete mode 100644 k2/python/csrc/fsa_equivalent.h delete mode 100644 k2/python/csrc/fsa_util.h delete mode 100644 k2/python/csrc/k2.cc create mode 100644 k2/python/csrc/k2.cu delete mode 100644 k2/python/csrc/properties.h create mode 100644 k2/python/csrc/torch.cu create mode 100644 k2/python/csrc/torch.h create mode 100644 k2/python/csrc/torch/CMakeLists.txt create mode 100644 k2/python/csrc/torch/array.cu create mode 100644 k2/python/csrc/torch/array.h create mode 100644 k2/python/csrc/torch/torch_util.cu create mode 100644 k2/python/csrc/torch/torch_util.h delete mode 100644 k2/python/csrc/weights.h create mode 100644 k2/python/host/CMakeLists.txt create mode 100644 k2/python/host/csrc/CMakeLists.txt rename k2/python/{ => host}/csrc/CPPLINT.cfg (100%) rename k2/python/{ => host}/csrc/README.md (100%) rename k2/python/{ => host}/csrc/array.cc (98%) rename k2/python/{ => host}/csrc/array.h (53%) rename k2/python/{ => host}/csrc/aux_labels.cc (95%) create mode 100644 k2/python/host/csrc/aux_labels.h rename k2/python/{ => host}/csrc/dlpack.h (100%) rename k2/python/{ => host}/csrc/fsa.cc (95%) rename k2/python/{ => host}/csrc/fsa.h (53%) rename k2/python/{ => host}/csrc/fsa_algo.cc (97%) create mode 100644 k2/python/host/csrc/fsa_algo.h rename k2/python/{ => host}/csrc/fsa_equivalent.cc (96%) create mode 100644 k2/python/host/csrc/fsa_equivalent.h rename k2/python/{ => host}/csrc/fsa_util.cc (77%) create mode 100644 k2/python/host/csrc/fsa_util.h create mode 100644 k2/python/host/csrc/k2.cc create mode 100644 k2/python/host/csrc/k2.h rename k2/python/{ => host}/csrc/properties.cc (93%) create mode 100644 k2/python/host/csrc/properties.h rename k2/python/{ => host}/csrc/tensor.cc (98%) rename k2/python/{ => host}/csrc/tensor.h (91%) rename k2/python/{ => host}/csrc/weights.cc (95%) create mode 100644 k2/python/host/csrc/weights.h create mode 100644 k2/python/host/k2host/__init__.py create mode 100644 k2/python/host/k2host/array.py rename k2/python/{k2 => host/k2host}/aux_labels.py (91%) rename k2/python/{k2 => host/k2host}/fsa.py (70%) rename k2/python/{k2 => host/k2host}/fsa_algo.py (71%) rename k2/python/{k2 => host/k2host}/fsa_equivalent.py (62%) rename k2/python/{k2 => host/k2host}/fsa_util.py (79%) rename k2/python/{k2 => host/k2host}/properties.py (75%) rename k2/python/{k2 => host/k2host}/weights.py (71%) create mode 100644 k2/python/host/tests/CMakeLists.txt create mode 100644 k2/python/host/tests/arcsort_test.py create mode 100644 k2/python/host/tests/array_test.py rename k2/python/{ => host}/tests/aux_labels_test.py (55%) create mode 100644 k2/python/host/tests/connect_test.py create mode 100644 k2/python/host/tests/determinize_test.py rename k2/python/{ => host}/tests/fsa_equivalent_test.py (53%) rename k2/python/{ => host}/tests/fsa_test.py (55%) create mode 100644 k2/python/host/tests/intersect_test.py create mode 100644 k2/python/host/tests/properties_test.py create mode 100644 k2/python/host/tests/rmepsilon_test.py create mode 100644 k2/python/host/tests/topsort_test.py rename k2/python/{ => host}/tests/weights_test.py (67%) delete mode 100644 k2/python/tests/arcsort_test.py delete mode 100644 k2/python/tests/connect_test.py delete mode 100644 k2/python/tests/determinize_test.py delete mode 100644 k2/python/tests/intersect_test.py delete mode 100644 k2/python/tests/properties_test.py delete mode 100644 k2/python/tests/rmepsilon_test.py delete mode 100644 k2/python/tests/topsort_test.py diff --git a/.flake8 b/.flake8 index cc56566f7..636561126 100644 --- a/.flake8 +++ b/.flake8 @@ -2,3 +2,11 @@ show-source=true statistics=true max-line-length=80 +exclude = + .git, + build, + k2/python/host + +ignore = + # E127 continuation line over-indented for visual indent + E127, diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 34752981a..aab9c07d1 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.5, 3.6, 3.7, 3.8] + python-version: [3.7, 3.8] steps: - uses: actions/checkout@v2 @@ -32,7 +32,7 @@ jobs: - name: Install Python dependencies run: | python3 -m pip install --upgrade pip - python3 -m pip install --upgrade flake8 + python3 -m pip install --upgrade flake8==3.8.3 - name: Run flake8 shell: bash @@ -40,8 +40,7 @@ jobs: run: | # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=79 --statistics + flake8 . # TODO(fangjun): build a docker for style check # - name: Install cppcheck diff --git a/CMakeLists.txt b/CMakeLists.txt index 00bc05437..178f342dc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -97,10 +97,10 @@ enable_testing() list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) include(pybind11) if(USE_PYTORCH) + add_definitions(-DK2_USE_PYTORCH) include(torch) endif() include(cub) include(googletest) - add_subdirectory(k2) diff --git a/cmake/pybind11.cmake b/cmake/pybind11.cmake index 78cb28a1e..36b480a7c 100644 --- a/cmake/pybind11.cmake +++ b/cmake/pybind11.cmake @@ -11,9 +11,15 @@ function(download_pybind11) set(pybind11_URL "https://github.com/pybind/pybind11/archive/v2.5.0.tar.gz") set(pybind11_HASH "SHA256=97504db65640570f32d3fdf701c25a340c8643037c3b69aec469c10c93dc8504") + set(double_quotes "\"") + set(dollar "\$") + set(semicolon "\;") FetchContent_Declare(pybind11 URL ${pybind11_URL} URL_HASH ${pybind11_HASH} + PATCH_COMMAND + sed -i s/\\${double_quotes}-flto\\\\${dollar}/\\${double_quotes}-Xcompiler=-flto${dollar}/g "tools/pybind11Tools.cmake" && + sed -i s/${seimcolon}-fno-fat-lto-objects/${seimcolon}-Xcompiler=-fno-fat-lto-objects/g "tools/pybind11Tools.cmake" ) FetchContent_GetProperties(pybind11) diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt index f63c783bb..177fb2b31 100644 --- a/k2/csrc/CMakeLists.txt +++ b/k2/csrc/CMakeLists.txt @@ -28,7 +28,7 @@ else() endif() # the target -add_library(context STATIC ${context_srcs}) +add_library(context SHARED ${context_srcs}) set_target_properties(context PROPERTIES CUDA_SEPARABLE_COMPILATION ON) # lib deps diff --git a/k2/csrc/default_context.cu b/k2/csrc/default_context.cu index 8e60bde6a..80d84cc56 100644 --- a/k2/csrc/default_context.cu +++ b/k2/csrc/default_context.cu @@ -39,7 +39,7 @@ class CpuContext : public Context { int32_t ret = posix_memalign(&p, kAlignment, bytes); K2_CHECK_EQ(ret, 0); } - if (deleter_context) *deleter_context = nullptr; + if (deleter_context != nullptr) *deleter_context = nullptr; return p; } @@ -75,7 +75,7 @@ class CudaContext : public Context { auto ret = cudaMalloc(&p, bytes); K2_CHECK_CUDA_ERROR(ret); } - if (deleter_context) *deleter_context = nullptr; + if (deleter_context != nullptr) *deleter_context = nullptr; return p; } diff --git a/k2/csrc/host/CMakeLists.txt b/k2/csrc/host/CMakeLists.txt index 3616d2358..640ab85f5 100644 --- a/k2/csrc/host/CMakeLists.txt +++ b/k2/csrc/host/CMakeLists.txt @@ -4,7 +4,7 @@ # the target # please sort the source files alphabetically -add_library(fsa +add_library(fsa SHARED arcsort.cc aux_labels.cc connect.cc diff --git a/k2/csrc/host/fsa.cc b/k2/csrc/host/fsa.cc index 78f436949..85ca278d2 100644 --- a/k2/csrc/host/fsa.cc +++ b/k2/csrc/host/fsa.cc @@ -30,7 +30,8 @@ inline std::size_t AlignTo(std::size_t b, std::size_t alignment) { namespace k2host { std::ostream &operator<<(std::ostream &os, const Arc &arc) { - os << arc.src_state << " " << arc.dest_state << " " << arc.label; + os << arc.src_state << " " << arc.dest_state << " " << arc.label << " " + << arc.weight; return os; } diff --git a/k2/csrc/pytorch_context.cu b/k2/csrc/pytorch_context.cu index 25bd5dd91..304108203 100644 --- a/k2/csrc/pytorch_context.cu +++ b/k2/csrc/pytorch_context.cu @@ -23,4 +23,27 @@ ContextPtr GetCudaContext(int32_t gpu_id /*= -1*/) { return std::make_shared(gpu_id); } +RegionPtr NewRegion(torch::Tensor &tensor) { + auto ans = std::make_shared(); + if (tensor.device().type() == torch::kCPU) { + ans->context = GetCpuContext(); + } else if (tensor.is_cuda()) { + ans->context = GetCudaContext(tensor.device().index()); + } else { + K2_LOG(FATAL) << "Unsupported device: " << tensor.device() + << "\nOnly CPU and CUDA are supported"; + } + + // NOTE: the tensor is passed from Python and we have + // to retain it to avoid potential segmentation fault. + // + // It will be freed in `Context::Deallocate`. + auto *managed_tensor = new ManagedTensor(tensor); + ans->data = tensor.data_ptr(); + ans->deleter_context = managed_tensor; + ans->num_bytes = tensor.nbytes(); + ans->bytes_used = ans->num_bytes; + return ans; +} + } // namespace k2 diff --git a/k2/csrc/pytorch_context.h b/k2/csrc/pytorch_context.h index 9d4f4de57..3f7682045 100644 --- a/k2/csrc/pytorch_context.h +++ b/k2/csrc/pytorch_context.h @@ -16,12 +16,21 @@ #include #include "c10/cuda/CUDACachingAllocator.h" +#include "c10/cuda/CUDAFunctions.h" #include "k2/csrc/context.h" #include "k2/csrc/log.h" #include "torch/torch.h" namespace k2 { +class ManagedTensor { + public: + explicit ManagedTensor(torch::Tensor &tensor) : handle_(tensor) {} + + private: + torch::Tensor handle_; // retain a copy of the tensor passed from Python +}; + class PytorchCpuContext : public Context { private: PytorchCpuContext() { @@ -46,12 +55,18 @@ class PytorchCpuContext : public Context { void *Allocate(std::size_t bytes, void **deleter_context) override { void *p = allocator_->raw_allocate(bytes); - if (deleter_context) *deleter_context = nullptr; + if (deleter_context != nullptr) *deleter_context = nullptr; return p; } - void Deallocate(void *data, void * /*deleter_context*/) override { - allocator_->raw_deallocate(data); + void Deallocate(void *data, void *deleter_context) override { + if (deleter_context != nullptr) { + // a non-empty `deleter_context` indicates that + // the memory is passed from a `torch::Tensor` + delete reinterpret_cast(deleter_context); + } else { + allocator_->raw_deallocate(data); + } } bool IsCompatible(const Context &other) const override { @@ -94,12 +109,18 @@ class PytorchCudaContext : public Context { void *Allocate(std::size_t bytes, void **deleter_context) override { void *p = allocator_->raw_allocate(bytes); - if (deleter_context) *deleter_context = nullptr; + if (deleter_context != nullptr) *deleter_context = nullptr; return p; } - void Deallocate(void *data, void * /*deleter_context*/) override { - allocator_->raw_deallocate(data); + void Deallocate(void *data, void *deleter_context) override { + if (deleter_context != nullptr) { + // a non-empty `deleter_context` indicates that + // the memory is passed from a `torch::Tensor` + delete reinterpret_cast(deleter_context); + } else { + allocator_->raw_deallocate(data); + } } bool IsCompatible(const Context &other) const override { @@ -116,6 +137,12 @@ class PytorchCudaContext : public Context { int32_t gpu_id_; }; +// Construct a region from a `torch::Tensor`. +// +// The resulting region shares the underlying memory with +// the given tensor. +RegionPtr NewRegion(torch::Tensor &tensor); + } // namespace k2 #endif // K2_CSRC_PYTORCH_CONTEXT_H_ diff --git a/k2/python/CMakeLists.txt b/k2/python/CMakeLists.txt index 60d6382f6..64f7b3d23 100644 --- a/k2/python/CMakeLists.txt +++ b/k2/python/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(csrc) add_subdirectory(tests) +add_subdirectory(host) diff --git a/k2/python/csrc/CMakeLists.txt b/k2/python/csrc/CMakeLists.txt index e32f0205b..ee0da564b 100644 --- a/k2/python/csrc/CMakeLists.txt +++ b/k2/python/csrc/CMakeLists.txt @@ -1,16 +1,23 @@ -# please sort the files alphabetically -pybind11_add_module(_k2 - array.cc - aux_labels.cc - fsa.cc - fsa_algo.cc - fsa_equivalent.cc - fsa_util.cc - k2.cc - properties.cc - tensor.cc - weights.cc +# please keep the list sorted +set(k2_srcs + k2.cu + torch.cu ) -target_include_directories(_k2 PRIVATE ${CMAKE_SOURCE_DIR}) +if(USE_PYTORCH) + add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H) + add_subdirectory(torch) + set(k2_srcs ${k2_srcs} ${torch_srcs}) + set(k2_deps + ${TORCH_LIBRARIES} + ${TORCH_DIR}/lib/libtorch_python.so + ) +else() + message(FATAL_ERROR "Please select a framework.") +endif() + +pybind11_add_module(_k2 ${k2_srcs}) +target_link_libraries(_k2 PRIVATE ${k2_deps}) +target_link_libraries(_k2 PRIVATE context) target_link_libraries(_k2 PRIVATE fsa) +target_include_directories(_k2 PRIVATE ${CMAKE_SOURCE_DIR}) diff --git a/k2/python/csrc/aux_labels.h b/k2/python/csrc/aux_labels.h deleted file mode 100644 index e2f0883e1..000000000 --- a/k2/python/csrc/aux_labels.h +++ /dev/null @@ -1,14 +0,0 @@ -// k2/python/csrc/aux_labels.h - -// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) - -// See ../../../LICENSE for clarification regarding multiple authors - -#ifndef K2_PYTHON_CSRC_AUX_LABELS_H_ -#define K2_PYTHON_CSRC_AUX_LABELS_H_ - -#include "k2/python/csrc/k2.h" - -void PybindAuxLabels(py::module &m); - -#endif // K2_PYTHON_CSRC_AUX_LABELS_H_ diff --git a/k2/python/csrc/fsa_algo.h b/k2/python/csrc/fsa_algo.h deleted file mode 100644 index 531f86590..000000000 --- a/k2/python/csrc/fsa_algo.h +++ /dev/null @@ -1,14 +0,0 @@ -// k2/python/csrc/fsa_algo.h - -// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) - -// See ../../../LICENSE for clarification regarding multiple authors - -#ifndef K2_PYTHON_CSRC_FSA_ALGO_H_ -#define K2_PYTHON_CSRC_FSA_ALGO_H_ - -#include "k2/python/csrc/k2.h" - -void PybindFsaAlgo(py::module &m); - -#endif // K2_PYTHON_CSRC_FSA_ALGO_H_ diff --git a/k2/python/csrc/fsa_equivalent.h b/k2/python/csrc/fsa_equivalent.h deleted file mode 100644 index 11c1c8aa5..000000000 --- a/k2/python/csrc/fsa_equivalent.h +++ /dev/null @@ -1,14 +0,0 @@ -// k2/python/csrc/fsa_equivalent.h - -// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) - -// See ../../../LICENSE for clarification regarding multiple authors - -#ifndef K2_PYTHON_CSRC_FSA_EQUIVALENT_H_ -#define K2_PYTHON_CSRC_FSA_EQUIVALENT_H_ - -#include "k2/python/csrc/k2.h" - -void PybindFsaEquivalent(py::module &m); - -#endif // K2_PYTHON_CSRC_FSA_EQUIVALENT_H_ diff --git a/k2/python/csrc/fsa_util.h b/k2/python/csrc/fsa_util.h deleted file mode 100644 index c97d8b8a8..000000000 --- a/k2/python/csrc/fsa_util.h +++ /dev/null @@ -1,14 +0,0 @@ -// k2/python/csrc/fsa_util.h - -// Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com) - -// See ../../../LICENSE for clarification regarding multiple authors - -#ifndef K2_PYTHON_CSRC_FSA_UTIL_H_ -#define K2_PYTHON_CSRC_FSA_UTIL_H_ - -#include "k2/python/csrc/k2.h" - -void PybindFsaUtil(py::module &m); - -#endif // K2_PYTHON_CSRC_FSA_UTIL_H_ diff --git a/k2/python/csrc/k2.cc b/k2/python/csrc/k2.cc deleted file mode 100644 index 5968ce227..000000000 --- a/k2/python/csrc/k2.cc +++ /dev/null @@ -1,30 +0,0 @@ -// k2/python/csrc/k2.cc - -// Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com) - -// See ../../../LICENSE for clarification regarding multiple authors - -#include "k2/python/csrc/k2.h" - -#include "k2/python/csrc/array.h" -#include "k2/python/csrc/aux_labels.h" -#include "k2/python/csrc/fsa.h" -#include "k2/python/csrc/fsa_algo.h" -#include "k2/python/csrc/fsa_equivalent.h" -#include "k2/python/csrc/fsa_util.h" -#include "k2/python/csrc/properties.h" -#include "k2/python/csrc/weights.h" - -PYBIND11_MODULE(_k2, m) { - m.doc() = "pybind11 binding of k2"; - PybindArc(m); - PybindArray(m); - PybindArray2Size(m); - PybindFsa(m); - PybindFsaUtil(m); - PybindFsaAlgo(m); - PybindFsaEquivalent(m); - PybindProperties(m); - PybindAuxLabels(m); - PybindWeights(m); -} diff --git a/k2/python/csrc/k2.cu b/k2/python/csrc/k2.cu new file mode 100644 index 000000000..1467a4912 --- /dev/null +++ b/k2/python/csrc/k2.cu @@ -0,0 +1,18 @@ +/** + * @brief python wrappers for k2. + * + * @copyright + * Copyright (c) 2020 Mobvoi AI Lab, Beijing, China (authors: Fangjun Kuang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + */ + +#include "k2/python/csrc/k2.h" + +#include "k2/python/csrc/torch.h" + +PYBIND11_MODULE(_k2, m) { + m.doc() = "pybind11 binding of k2"; + PybindTorch(m); +} diff --git a/k2/python/csrc/k2.h b/k2/python/csrc/k2.h index 764299d32..2aaac5489 100644 --- a/k2/python/csrc/k2.h +++ b/k2/python/csrc/k2.h @@ -1,15 +1,18 @@ -// k2/python/csrc/k2.h - -// Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com) - -// See ../../../LICENSE for clarification regarding multiple authors +/** + * @brief python wrappers for k2. + * + * @copyright + * Copyright (c) 2020 Mobvoi AI Lab, Beijing, China (authors: Fangjun Kuang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + */ #ifndef K2_PYTHON_CSRC_K2_H_ #define K2_PYTHON_CSRC_K2_H_ #include "pybind11/pybind11.h" #include "pybind11/stl.h" -#include "k2/csrc/log.h" namespace py = pybind11; diff --git a/k2/python/csrc/properties.h b/k2/python/csrc/properties.h deleted file mode 100644 index 04cd2148e..000000000 --- a/k2/python/csrc/properties.h +++ /dev/null @@ -1,14 +0,0 @@ -// k2/python/csrc/properties.h - -// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) - -// See ../../../LICENSE for clarification regarding multiple authors - -#ifndef K2_PYTHON_CSRC_PROPERTIES_H_ -#define K2_PYTHON_CSRC_PROPERTIES_H_ - -#include "k2/python/csrc/k2.h" - -void PybindProperties(py::module &m); - -#endif // K2_PYTHON_CSRC_PROPERTIES_H_ diff --git a/k2/python/csrc/torch.cu b/k2/python/csrc/torch.cu new file mode 100644 index 000000000..5ecd1d170 --- /dev/null +++ b/k2/python/csrc/torch.cu @@ -0,0 +1,23 @@ +/** + * @brief Everything related to PyTorch for k2 Python wrappers. + * + * @copyright + * Copyright (c) 2020 Mobvoi AI Lab, Beijing, China (authors: Fangjun Kuang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + */ + +#include "k2/python/csrc/torch.h" + +#if defined(K2_USE_PYTORCH) + +#include "k2/python/csrc/torch/array.h" + +void PybindTorch(py::module &m) { PybindArray(m); } + +#else + +void PybindTorch(py::module &) {} + +#endif diff --git a/k2/python/csrc/torch.h b/k2/python/csrc/torch.h new file mode 100644 index 000000000..1f85b2d1b --- /dev/null +++ b/k2/python/csrc/torch.h @@ -0,0 +1,18 @@ +/** + * @brief Everything related to PyTorch for k2 Python wrappers. + * + * @copyright + * Copyright (c) 2020 Mobvoi AI Lab, Beijing, China (authors: Fangjun Kuang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + */ + +#ifndef K2_PYTHON_CSRC_TORCH_H_ +#define K2_PYTHON_CSRC_TORCH_H_ + +#include "k2/python/csrc/k2.h" + +void PybindTorch(py::module &m); + +#endif // K2_PYTHON_CSRC_TORCH_H_ diff --git a/k2/python/csrc/torch/CMakeLists.txt b/k2/python/csrc/torch/CMakeLists.txt new file mode 100644 index 000000000..166e8e0b8 --- /dev/null +++ b/k2/python/csrc/torch/CMakeLists.txt @@ -0,0 +1,12 @@ +# please keep the list sorted +set(torch_srcs + array.cu + torch_util.cu +) + +set(torch_srcs_with_prefix) +foreach(src IN LISTS torch_srcs) + list(APPEND torch_srcs_with_prefix "torch/${src}") +endforeach() + +set(torch_srcs ${torch_srcs_with_prefix} PARENT_SCOPE) diff --git a/k2/python/csrc/torch/array.cu b/k2/python/csrc/torch/array.cu new file mode 100644 index 000000000..5ef86dd97 --- /dev/null +++ b/k2/python/csrc/torch/array.cu @@ -0,0 +1,77 @@ +/** + * @brief python wrappers for Array. + * + * @copyright + * Copyright (c) 2020 Mobvoi AI Lab, Beijing, China (authors: Fangjun Kuang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + */ + +#include + +#include "c10/core/ScalarType.h" +#include "k2/csrc/array.h" +#include "k2/csrc/pytorch_context.h" +#include "k2/python/csrc/torch/array.h" +#include "k2/python/csrc/torch/torch_util.h" +#include "torch/extension.h" + +namespace k2 { + +template +static void PybindArray1Tpl(py::module &m, const char *name) { + using PyClass = Array1; + py::class_ pyclass(m, name); + pyclass.def(py::init<>()); + pyclass.def("tensor", [](PyClass &self) { return ToTensor(self); }); + + pyclass.def_static( + "from_tensor", + [](torch::Tensor &tensor) { return FromTensor(tensor); }, + py::arg("tensor")); + + // the following functions are for testing only + pyclass.def( + "get", [](const PyClass &self, int32_t i) { return self[i]; }, + py::arg("i")); + pyclass.def("__str__", [](const PyClass &self) { + std::ostringstream os; + os << self; + return os.str(); + }); +} + +static void PybindArrayImpl(py::module &m) { + // users should not use classes with prefix `_` in Python. + PybindArray1Tpl(m, "_FloatArray1"); + PybindArray1Tpl(m, "_Int32Array1"); + + // the following functions are for testing purposes + // and they can be removed later. + m.def("get_cpu_float_array1", []() { + return Array1(GetCpuContext(), {1, 2, 3, 4}); + }); + + m.def("get_cpu_int_array1", []() { + return Array1(GetCpuContext(), {1, 2, 3, 4}); + }); + + m.def( + "get_cuda_float_array1", + [](int32_t gpu_id = -1) { + return Array1(GetCudaContext(gpu_id), {0, 1, 2, 3}); + }, + py::arg("gpu_id") = -1); + + m.def( + "get_cuda_int_array1", + [](int32_t gpu_id = -1) { + return Array1(GetCudaContext(gpu_id), {0, 1, 2, 3}); + }, + py::arg("gpu_id") = -1); +} + +} // namespace k2 + +void PybindArray(py::module &m) { k2::PybindArrayImpl(m); } diff --git a/k2/python/csrc/torch/array.h b/k2/python/csrc/torch/array.h new file mode 100644 index 000000000..eef25e043 --- /dev/null +++ b/k2/python/csrc/torch/array.h @@ -0,0 +1,18 @@ +/** + * @brief python wrappers for Array. + * + * @copyright + * Copyright (c) 2020 Mobvoi AI Lab, Beijing, China (authors: Fangjun Kuang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + */ + +#ifndef K2_PYTHON_CSRC_TORCH_ARRAY_H_ +#define K2_PYTHON_CSRC_TORCH_ARRAY_H_ + +#include "k2/python/csrc/k2.h" + +void PybindArray(py::module &m); + +#endif // K2_PYTHON_CSRC_TORCH_ARRAY_H_ diff --git a/k2/python/csrc/torch/torch_util.cu b/k2/python/csrc/torch/torch_util.cu new file mode 100644 index 000000000..db70fc86a --- /dev/null +++ b/k2/python/csrc/torch/torch_util.cu @@ -0,0 +1,27 @@ +/** + * @copyright + * Copyright (c) 2020 Mobvoi AI Lab, Beijing, China (authors: Fangjun Kuang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + */ + +#include "k2/python/csrc/torch/torch_util.h" +#include "torch/extension.h" + +namespace k2 { + +torch::DeviceType ToTorchDeviceType(DeviceType type) { + switch (type) { + case kCuda: + return torch::kCUDA; + case kCpu: + return torch::kCPU; + case kUnk: // fall-through + default: + K2_LOG(FATAL) << "kUnk is not supported!"; + return torch::kCPU; // unreachable code + } +} + +} // namespace k2 diff --git a/k2/python/csrc/torch/torch_util.h b/k2/python/csrc/torch/torch_util.h new file mode 100644 index 000000000..b1acfdb6a --- /dev/null +++ b/k2/python/csrc/torch/torch_util.h @@ -0,0 +1,67 @@ +/** + * @copyright + * Copyright (c) 2020 Mobvoi AI Lab, Beijing, China (authors: Fangjun Kuang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + */ + +#ifndef K2_PYTHON_CSRC_TORCH_TORCH_UTIL_H_ +#define K2_PYTHON_CSRC_TORCH_TORCH_UTIL_H_ + +#include "k2/csrc/array.h" +#include "k2/csrc/log.h" +#include "k2/csrc/pytorch_context.h" +#include "torch/extension.h" + +namespace k2 { + +torch::DeviceType ToTorchDeviceType(DeviceType type); + +// Some versions of PyTorch do not have `c10::CppTypeToScalarType`, +// so we implement our own here. +template +struct ToScalarType; + +#define TO_SCALAR_TYPE(cpp_type, scalar_type) \ + template <> \ + struct ToScalarType \ + : std::integral_constant {}; + +// TODO(fangjun): add other types if needed +TO_SCALAR_TYPE(float, torch::kFloat); +TO_SCALAR_TYPE(int, torch::kInt); + +#undef TO_SCALAR_TYPE + +template +torch::Tensor ToTensor(Array1 &array) { + auto device_type = ToTorchDeviceType(array.Context()->GetDeviceType()); + int32_t device_id = array.Context()->GetDeviceId(); + auto device = torch::Device(device_type, device_id); + auto scalar_type = ToScalarType::value; + auto options = torch::device(device).dtype(scalar_type); + + // NOTE: we keep a copy of `array` inside the lambda + // so that `torch::Tensor` always accesses valid memory. + return torch::from_blob( + array.Data(), array.Dim(), [array](void *p) {}, options); +} + +template +Array1 FromTensor(torch::Tensor &tensor) { + K2_CHECK_EQ(tensor.dim(), 1) << "Expected dim: 1. Given: " << tensor.dim(); + K2_CHECK_EQ(tensor.scalar_type(), ToScalarType::value) + << "Expected scalar type: " << ToScalarType::value + << ". Given: " << tensor.scalar_type(); + K2_CHECK_EQ(tensor.strides()[0], 1) + << "Expected stride: 1. Given: " << tensor.strides()[0]; + + auto region = NewRegion(tensor); + Array1 ans(tensor.numel(), region, 0); + return ans; +} + +} // namespace k2 + +#endif // K2_PYTHON_CSRC_TORCH_TORCH_UTIL_H_ diff --git a/k2/python/csrc/weights.h b/k2/python/csrc/weights.h deleted file mode 100644 index 7c1bd04cd..000000000 --- a/k2/python/csrc/weights.h +++ /dev/null @@ -1,14 +0,0 @@ -// k2/python/csrc/weights.h - -// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) - -// See ../../../LICENSE for clarification regarding multiple authors - -#ifndef K2_PYTHON_CSRC_WEIGHTS_H_ -#define K2_PYTHON_CSRC_WEIGHTS_H_ - -#include "k2/python/csrc/k2.h" - -void PybindWeights(py::module &m); - -#endif // K2_PYTHON_CSRC_WEIGHTS_H_ diff --git a/k2/python/host/CMakeLists.txt b/k2/python/host/CMakeLists.txt new file mode 100644 index 000000000..60d6382f6 --- /dev/null +++ b/k2/python/host/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(csrc) +add_subdirectory(tests) diff --git a/k2/python/host/csrc/CMakeLists.txt b/k2/python/host/csrc/CMakeLists.txt new file mode 100644 index 000000000..b6bc530d1 --- /dev/null +++ b/k2/python/host/csrc/CMakeLists.txt @@ -0,0 +1,16 @@ +# please sort the files alphabetically +pybind11_add_module(_k2host + array.cc + aux_labels.cc + fsa.cc + fsa_algo.cc + fsa_equivalent.cc + fsa_util.cc + k2.cc + properties.cc + tensor.cc + weights.cc +) + +target_include_directories(_k2host PRIVATE ${CMAKE_SOURCE_DIR}) +target_link_libraries(_k2host PRIVATE fsa) diff --git a/k2/python/csrc/CPPLINT.cfg b/k2/python/host/csrc/CPPLINT.cfg similarity index 100% rename from k2/python/csrc/CPPLINT.cfg rename to k2/python/host/csrc/CPPLINT.cfg diff --git a/k2/python/csrc/README.md b/k2/python/host/csrc/README.md similarity index 100% rename from k2/python/csrc/README.md rename to k2/python/host/csrc/README.md diff --git a/k2/python/csrc/array.cc b/k2/python/host/csrc/array.cc similarity index 98% rename from k2/python/csrc/array.cc rename to k2/python/host/csrc/array.cc index 3f9ffa418..04aee9a92 100644 --- a/k2/python/csrc/array.cc +++ b/k2/python/host/csrc/array.cc @@ -1,15 +1,16 @@ +// k2/python/host/csrc/array.cc // Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) // See ../../../LICENSE for clarification regarding multiple authors -#include "k2/python/csrc/array.h" +#include "k2/python/host/csrc/array.h" #include #include #include "k2/csrc/host/array.h" #include "k2/csrc/host/determinize_impl.h" -#include "k2/python/csrc/tensor.h" +#include "k2/python/host/csrc/tensor.h" namespace k2host { diff --git a/k2/python/csrc/array.h b/k2/python/host/csrc/array.h similarity index 53% rename from k2/python/csrc/array.h rename to k2/python/host/csrc/array.h index 14d5d267f..76cba74d1 100644 --- a/k2/python/csrc/array.h +++ b/k2/python/host/csrc/array.h @@ -1,15 +1,15 @@ -// k2/python/csrc/array.h +// k2/python/host/csrc/array.h // Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) // See ../../../LICENSE for clarification regarding multiple authors -#ifndef K2_PYTHON_CSRC_ARRAY_H_ -#define K2_PYTHON_CSRC_ARRAY_H_ +#ifndef K2_PYTHON_HOST_CSRC_ARRAY_H_ +#define K2_PYTHON_HOST_CSRC_ARRAY_H_ -#include "k2/python/csrc/k2.h" +#include "k2/python/host/csrc/k2.h" void PybindArray(py::module &m); void PybindArray2Size(py::module &m); -#endif // K2_PYTHON_CSRC_ARRAY_H_ +#endif // K2_PYTHON_HOST_CSRC_ARRAY_H_ diff --git a/k2/python/csrc/aux_labels.cc b/k2/python/host/csrc/aux_labels.cc similarity index 95% rename from k2/python/csrc/aux_labels.cc rename to k2/python/host/csrc/aux_labels.cc index b73b03676..c4db30f20 100644 --- a/k2/python/csrc/aux_labels.cc +++ b/k2/python/host/csrc/aux_labels.cc @@ -1,10 +1,10 @@ -// k2/python/csrc/aux_labels.cc +// k2/python/host/csrc/aux_labels.cc // Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) // See ../../../LICENSE for clarification regarding multiple authors -#include "k2/python/csrc/aux_labels.h" +#include "k2/python/host/csrc/aux_labels.h" #include "k2/csrc/host/aux_labels.h" diff --git a/k2/python/host/csrc/aux_labels.h b/k2/python/host/csrc/aux_labels.h new file mode 100644 index 000000000..b331c5e42 --- /dev/null +++ b/k2/python/host/csrc/aux_labels.h @@ -0,0 +1,14 @@ +// k2/python/host/csrc/aux_labels.h + +// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) + +// See ../../../LICENSE for clarification regarding multiple authors + +#ifndef K2_PYTHON_HOST_CSRC_AUX_LABELS_H_ +#define K2_PYTHON_HOST_CSRC_AUX_LABELS_H_ + +#include "k2/python/host/csrc/k2.h" + +void PybindAuxLabels(py::module &m); + +#endif // K2_PYTHON_HOST_CSRC_AUX_LABELS_H_ diff --git a/k2/python/csrc/dlpack.h b/k2/python/host/csrc/dlpack.h similarity index 100% rename from k2/python/csrc/dlpack.h rename to k2/python/host/csrc/dlpack.h diff --git a/k2/python/csrc/fsa.cc b/k2/python/host/csrc/fsa.cc similarity index 95% rename from k2/python/csrc/fsa.cc rename to k2/python/host/csrc/fsa.cc index 3f13a5bb2..35b9564ae 100644 --- a/k2/python/csrc/fsa.cc +++ b/k2/python/host/csrc/fsa.cc @@ -1,17 +1,17 @@ -// k2/python/csrc/fsa.cc +// k2/python/host/csrc/fsa.cc // Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com) // Xiaomi Corporation (author: Haowen Qiu) // See ../../../LICENSE for clarification regarding multiple authors -#include "k2/python/csrc/fsa.h" +#include "k2/python/host/csrc/fsa.h" #include #include #include "k2/csrc/host/fsa.h" -#include "k2/python/csrc/tensor.h" +#include "k2/python/host/csrc/tensor.h" namespace k2host { @@ -58,6 +58,7 @@ void PybindArc(py::module &m) { .def_readwrite("src_state", &PyClass::src_state) .def_readwrite("dest_state", &PyClass::dest_state) .def_readwrite("label", &PyClass::label) + .def_readwrite("weight", &PyClass::weight) .def("__str__", [](const PyClass &self) { std::ostringstream os; os << self; diff --git a/k2/python/csrc/fsa.h b/k2/python/host/csrc/fsa.h similarity index 53% rename from k2/python/csrc/fsa.h rename to k2/python/host/csrc/fsa.h index 9090a76c8..10f891649 100644 --- a/k2/python/csrc/fsa.h +++ b/k2/python/host/csrc/fsa.h @@ -1,15 +1,15 @@ -// k2/python/csrc/fsa.h +// k2/python/host/csrc/fsa.h // Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com) // See ../../../LICENSE for clarification regarding multiple authors -#ifndef K2_PYTHON_CSRC_FSA_H_ -#define K2_PYTHON_CSRC_FSA_H_ +#ifndef K2_PYTHON_HOST_CSRC_FSA_H_ +#define K2_PYTHON_HOST_CSRC_FSA_H_ -#include "k2/python/csrc/k2.h" +#include "k2/python/host/csrc/k2.h" void PybindArc(py::module &m); void PybindFsa(py::module &m); -#endif // K2_PYTHON_CSRC_FSA_H_ +#endif // K2_PYTHON_HOST_CSRC_FSA_H_ diff --git a/k2/python/csrc/fsa_algo.cc b/k2/python/host/csrc/fsa_algo.cc similarity index 97% rename from k2/python/csrc/fsa_algo.cc rename to k2/python/host/csrc/fsa_algo.cc index 205d6470f..a3618633a 100644 --- a/k2/python/csrc/fsa_algo.cc +++ b/k2/python/host/csrc/fsa_algo.cc @@ -1,10 +1,10 @@ -// k2/python/csrc/fsa_algo.cc +// k2/python/host/csrc/fsa_algo.cc // Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) // See ../../../LICENSE for clarification regarding multiple authors -#include "k2/python/csrc/fsa_algo.h" +#include "k2/python/host/csrc/fsa_algo.h" #include #include @@ -19,7 +19,7 @@ #include "k2/csrc/host/rmepsilon.h" #include "k2/csrc/host/topsort.h" #include "k2/csrc/host/weights.h" -#include "k2/python/csrc/array.h" +#include "k2/python/host/csrc/array.h" void PyBindArcSort(py::module &m) { using PyClass = k2host::ArcSorter; diff --git a/k2/python/host/csrc/fsa_algo.h b/k2/python/host/csrc/fsa_algo.h new file mode 100644 index 000000000..34a294a79 --- /dev/null +++ b/k2/python/host/csrc/fsa_algo.h @@ -0,0 +1,14 @@ +// k2/python/host/csrc/fsa_algo.h + +// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) + +// See ../../../LICENSE for clarification regarding multiple authors + +#ifndef K2_PYTHON_HOST_CSRC_FSA_ALGO_H_ +#define K2_PYTHON_HOST_CSRC_FSA_ALGO_H_ + +#include "k2/python/host/csrc/k2.h" + +void PybindFsaAlgo(py::module &m); + +#endif // K2_PYTHON_HOST_CSRC_FSA_ALGO_H_ diff --git a/k2/python/csrc/fsa_equivalent.cc b/k2/python/host/csrc/fsa_equivalent.cc similarity index 96% rename from k2/python/csrc/fsa_equivalent.cc rename to k2/python/host/csrc/fsa_equivalent.cc index 8281a2f7b..eb0917c17 100644 --- a/k2/python/csrc/fsa_equivalent.cc +++ b/k2/python/host/csrc/fsa_equivalent.cc @@ -1,10 +1,10 @@ -// k2/python/csrc/fsa_equivalent.cc +// k2/python/host/csrc/fsa_equivalent.cc // Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) // See ../../../LICENSE for clarification regarding multiple authors -#include "k2/python/csrc/fsa_equivalent.h" +#include "k2/python/host/csrc/fsa_equivalent.h" #include "k2/csrc/host/array.h" #include "k2/csrc/host/fsa_equivalent.h" diff --git a/k2/python/host/csrc/fsa_equivalent.h b/k2/python/host/csrc/fsa_equivalent.h new file mode 100644 index 000000000..c0b9ea4b9 --- /dev/null +++ b/k2/python/host/csrc/fsa_equivalent.h @@ -0,0 +1,14 @@ +// k2/python/host/csrc/fsa_equivalent.h + +// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) + +// See ../../../LICENSE for clarification regarding multiple authors + +#ifndef K2_PYTHON_HOST_CSRC_FSA_EQUIVALENT_H_ +#define K2_PYTHON_HOST_CSRC_FSA_EQUIVALENT_H_ + +#include "k2/python/host/csrc/k2.h" + +void PybindFsaEquivalent(py::module &m); + +#endif // K2_PYTHON_HOST_CSRC_FSA_EQUIVALENT_H_ diff --git a/k2/python/csrc/fsa_util.cc b/k2/python/host/csrc/fsa_util.cc similarity index 77% rename from k2/python/csrc/fsa_util.cc rename to k2/python/host/csrc/fsa_util.cc index 4e3161d10..b4d6bd6d9 100644 --- a/k2/python/csrc/fsa_util.cc +++ b/k2/python/host/csrc/fsa_util.cc @@ -1,10 +1,10 @@ -// k2/python/csrc/fsa_util.cc +// k2/python/host/csrc/fsa_util.cc // Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com) // See ../../../LICENSE for clarification regarding multiple authors -#include "k2/python/csrc/fsa_util.h" +#include "k2/python/host/csrc/fsa_util.h" #include "k2/csrc/host/fsa_util.h" diff --git a/k2/python/host/csrc/fsa_util.h b/k2/python/host/csrc/fsa_util.h new file mode 100644 index 000000000..f685f3d13 --- /dev/null +++ b/k2/python/host/csrc/fsa_util.h @@ -0,0 +1,14 @@ +// k2/python/host/csrc/fsa_util.h + +// Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com) + +// See ../../../LICENSE for clarification regarding multiple authors + +#ifndef K2_PYTHON_HOST_CSRC_FSA_UTIL_H_ +#define K2_PYTHON_HOST_CSRC_FSA_UTIL_H_ + +#include "k2/python/host/csrc/k2.h" + +void PybindFsaUtil(py::module &m); + +#endif // K2_PYTHON_HOST_CSRC_FSA_UTIL_H_ diff --git a/k2/python/host/csrc/k2.cc b/k2/python/host/csrc/k2.cc new file mode 100644 index 000000000..47fb04b5e --- /dev/null +++ b/k2/python/host/csrc/k2.cc @@ -0,0 +1,30 @@ +// k2/python/host/csrc/k2.cc + +// Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com) + +// See ../../../LICENSE for clarification regarding multiple authors + +#include "k2/python/host/csrc/k2.h" + +#include "k2/python/host/csrc/array.h" +#include "k2/python/host/csrc/aux_labels.h" +#include "k2/python/host/csrc/fsa.h" +#include "k2/python/host/csrc/fsa_algo.h" +#include "k2/python/host/csrc/fsa_equivalent.h" +#include "k2/python/host/csrc/fsa_util.h" +#include "k2/python/host/csrc/properties.h" +#include "k2/python/host/csrc/weights.h" + +PYBIND11_MODULE(_k2host, m) { + m.doc() = "pybind11 binding of k2host"; + PybindArc(m); + PybindArray(m); + PybindArray2Size(m); + PybindFsa(m); + PybindFsaUtil(m); + PybindFsaAlgo(m); + PybindFsaEquivalent(m); + PybindProperties(m); + PybindAuxLabels(m); + PybindWeights(m); +} diff --git a/k2/python/host/csrc/k2.h b/k2/python/host/csrc/k2.h new file mode 100644 index 000000000..40831ffff --- /dev/null +++ b/k2/python/host/csrc/k2.h @@ -0,0 +1,16 @@ +// k2/python/host/csrc/k2.h + +// Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com) + +// See ../../../LICENSE for clarification regarding multiple authors + +#ifndef K2_PYTHON_HOST_CSRC_K2_H_ +#define K2_PYTHON_HOST_CSRC_K2_H_ + +#include "k2/csrc/log.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +#endif // K2_PYTHON_HOST_CSRC_K2_H_ diff --git a/k2/python/csrc/properties.cc b/k2/python/host/csrc/properties.cc similarity index 93% rename from k2/python/csrc/properties.cc rename to k2/python/host/csrc/properties.cc index 7df721bb6..e73191233 100644 --- a/k2/python/csrc/properties.cc +++ b/k2/python/host/csrc/properties.cc @@ -4,14 +4,14 @@ // See ../../../LICENSE for clarification regarding multiple authors -#include "k2/python/csrc/properties.h" +#include "k2/python/host/csrc/properties.h" #include #include "k2/csrc/host/array.h" #include "k2/csrc/host/fsa.h" #include "k2/csrc/host/properties.h" -#include "k2/python/csrc/array.h" +#include "k2/python/host/csrc/array.h" // We would never pass `order` parameter to k2host::IsAcyclic in Python code. // We can make it accept `None` with `std::optional` in pybind11, but diff --git a/k2/python/host/csrc/properties.h b/k2/python/host/csrc/properties.h new file mode 100644 index 000000000..0a5a057ea --- /dev/null +++ b/k2/python/host/csrc/properties.h @@ -0,0 +1,14 @@ +// k2/python/host/csrc/properties.h + +// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) + +// See ../../../LICENSE for clarification regarding multiple authors + +#ifndef K2_PYTHON_HOST_CSRC_PROPERTIES_H_ +#define K2_PYTHON_HOST_CSRC_PROPERTIES_H_ + +#include "k2/python/host/csrc/k2.h" + +void PybindProperties(py::module &m); + +#endif // K2_PYTHON_HOST_CSRC_PROPERTIES_H_ diff --git a/k2/python/csrc/tensor.cc b/k2/python/host/csrc/tensor.cc similarity index 98% rename from k2/python/csrc/tensor.cc rename to k2/python/host/csrc/tensor.cc index f6a072ec8..be43424c3 100644 --- a/k2/python/csrc/tensor.cc +++ b/k2/python/host/csrc/tensor.cc @@ -1,11 +1,11 @@ -// k2/python/csrc/tensor.cc +// k2/python/host/csrc/tensor.cc // Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com) // Xiaomi Corporation (author: Haowen Qiu) // See ../../../LICENSE for clarification regarding multiple authors -#include "k2/python/csrc/tensor.h" +#include "k2/python/host/csrc/tensor.h" namespace k2host { diff --git a/k2/python/csrc/tensor.h b/k2/python/host/csrc/tensor.h similarity index 91% rename from k2/python/csrc/tensor.h rename to k2/python/host/csrc/tensor.h index f72014d18..7a544499b 100644 --- a/k2/python/csrc/tensor.h +++ b/k2/python/host/csrc/tensor.h @@ -1,14 +1,14 @@ -// k2/python/csrc/tensor.h +// k2/python/host/csrc/tensor.h // Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com) // See ../../../LICENSE for clarification regarding multiple authors -#ifndef K2_PYTHON_CSRC_TENSOR_H_ -#define K2_PYTHON_CSRC_TENSOR_H_ +#ifndef K2_PYTHON_HOST_CSRC_TENSOR_H_ +#define K2_PYTHON_HOST_CSRC_TENSOR_H_ -#include "k2/python/csrc/dlpack.h" -#include "k2/python/csrc/k2.h" +#include "k2/python/host/csrc/dlpack.h" +#include "k2/python/host/csrc/k2.h" namespace k2host { @@ -98,4 +98,4 @@ class Tensor { } // namespace k2host -#endif // K2_PYTHON_CSRC_TENSOR_H_ +#endif // K2_PYTHON_HOST_CSRC_TENSOR_H_ diff --git a/k2/python/csrc/weights.cc b/k2/python/host/csrc/weights.cc similarity index 95% rename from k2/python/csrc/weights.cc rename to k2/python/host/csrc/weights.cc index e1081a2c5..55d1e54a7 100644 --- a/k2/python/csrc/weights.cc +++ b/k2/python/host/csrc/weights.cc @@ -1,10 +1,10 @@ -// k2/python/csrc/weights.cc +// k2/python/host/csrc/weights.cc // Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) // See ../../../LICENSE for clarification regarding multiple authors -#include "k2/python/csrc/weights.h" +#include "k2/python/host/csrc/weights.h" #include diff --git a/k2/python/host/csrc/weights.h b/k2/python/host/csrc/weights.h new file mode 100644 index 000000000..fe59920ba --- /dev/null +++ b/k2/python/host/csrc/weights.h @@ -0,0 +1,14 @@ +// k2/python/host/csrc/weights.h + +// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) + +// See ../../../LICENSE for clarification regarding multiple authors + +#ifndef K2_PYTHON_HOST_CSRC_WEIGHTS_H_ +#define K2_PYTHON_HOST_CSRC_WEIGHTS_H_ + +#include "k2/python/host/csrc/k2.h" + +void PybindWeights(py::module &m); + +#endif // K2_PYTHON_HOST_CSRC_WEIGHTS_H_ diff --git a/k2/python/host/k2host/__init__.py b/k2/python/host/k2host/__init__.py new file mode 100644 index 000000000..32b477121 --- /dev/null +++ b/k2/python/host/k2host/__init__.py @@ -0,0 +1,10 @@ +from _k2host import IntArray2Size +from _k2host import FbWeightType +from .array import * +from .aux_labels import * +from .fsa import * +from .fsa_algo import * +from .fsa_equivalent import * +from .fsa_util import str_to_fsa +from .properties import * +from .weights import * diff --git a/k2/python/host/k2host/array.py b/k2/python/host/k2host/array.py new file mode 100644 index 000000000..821eba64f --- /dev/null +++ b/k2/python/host/k2host/array.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) + +# See ../../../LICENSE for clarification regarding multiple authors + +import torch +from torch.utils.dlpack import to_dlpack + +from _k2host import IntArray2Size +from _k2host import DLPackIntArray2 +from _k2host import DLPackIntArray1 +from _k2host import DLPackStridedIntArray1 +from _k2host import DLPackFloatArray1 +from _k2host import DLPackDoubleArray1 +from _k2host import DLPackLogSumArcDerivs + + +class IntArray1(DLPackIntArray1): + + def __init__(self, data: torch.Tensor, check_dtype: bool = True): + if check_dtype: + assert data.dtype == torch.int32 + self.data = data + super().__init__(to_dlpack(self.data)) + + @staticmethod + def from_float_tensor(data: torch.Tensor) -> 'IntArray1': + assert data.dtype == torch.float + return IntArray1(data, False) + + @staticmethod + def create_array_with_size(size: int) -> 'IntArray1': + data = torch.zeros(size, dtype=torch.int32) + return IntArray1(data) + + +class StridedIntArray1(DLPackStridedIntArray1): + + def __init__(self, data: torch.Tensor, check_dtype: bool = True): + if check_dtype: + assert data.dtype == torch.int32 + self.data = data + super().__init__(to_dlpack(self.data)) + + @staticmethod + def from_float_tensor(data: torch.Tensor) -> 'StridedIntArray1': + assert data.dtype == torch.float + return StridedIntArray1(data, False) + + +class FloatArray1(DLPackFloatArray1): + + def __init__(self, data: torch.Tensor): + assert data.dtype == torch.float + self.data = data + super().__init__(to_dlpack(self.data)) + + @staticmethod + def create_array_with_size(size: int) -> 'FloatArray1': + data = torch.zeros(size, dtype=torch.float) + return FloatArray1(data) + + +class DoubleArray1(DLPackDoubleArray1): + + def __init__(self, data: torch.Tensor): + assert data.dtype == torch.double + self.data = data + super().__init__(to_dlpack(self.data)) + + @staticmethod + def create_array_with_size(size: int) -> 'DoubleArray1': + data = torch.zeros(size, dtype=torch.double) + return DoubleArray1(data) + + +class IntArray2(DLPackIntArray2): + + def __init__(self, indexes: torch.Tensor, data: torch.Tensor): + assert indexes.dtype == torch.int32 + assert data.dtype == torch.int32 + self.indexes = indexes + self.data = data + super().__init__(to_dlpack(self.indexes), to_dlpack(self.data)) + + @staticmethod + def create_array_with_size(array_size: IntArray2Size) -> 'IntArray2': + indexes = torch.zeros(array_size.size1 + 1, dtype=torch.int32) + data = torch.zeros(array_size.size2, dtype=torch.int32) + return IntArray2(indexes, data) + + +class LogSumArcDerivs(DLPackLogSumArcDerivs): + + def __init__(self, indexes: torch.Tensor, data: torch.Tensor): + assert indexes.dtype == torch.int32 + assert data.dtype == torch.float32 + assert data.shape[1] == 2 + self.indexes = indexes + self.data = data + super().__init__(to_dlpack(self.indexes), to_dlpack(self.data)) + + @staticmethod + def create_arc_derivs_with_size(array_size: IntArray2Size + ) -> 'LogSumArcDerivs': + indexes = torch.zeros(array_size.size1 + 1, dtype=torch.int32) + data = torch.zeros([array_size.size2, 2], dtype=torch.float32) + return LogSumArcDerivs(indexes, data) diff --git a/k2/python/k2/aux_labels.py b/k2/python/host/k2host/aux_labels.py similarity index 91% rename from k2/python/k2/aux_labels.py rename to k2/python/host/k2host/aux_labels.py index 91a8fe5bb..0b010cc1d 100644 --- a/k2/python/k2/aux_labels.py +++ b/k2/python/host/k2host/aux_labels.py @@ -5,10 +5,10 @@ import torch from torch.utils.dlpack import to_dlpack -from _k2 import IntArray2Size -from _k2 import _AuxLabels1Mapper -from _k2 import _AuxLabels2Mapper -from _k2 import _FstInverter +from _k2host import IntArray2Size +from _k2host import _AuxLabels1Mapper +from _k2host import _AuxLabels2Mapper +from _k2host import _FstInverter from .fsa import Fsa from .array import IntArray1 diff --git a/k2/python/k2/fsa.py b/k2/python/host/k2host/fsa.py similarity index 70% rename from k2/python/k2/fsa.py rename to k2/python/host/k2host/fsa.py index 0ac10a610..a93c0409e 100644 --- a/k2/python/k2/fsa.py +++ b/k2/python/host/k2host/fsa.py @@ -5,24 +5,27 @@ import torch from torch.utils.dlpack import to_dlpack -from _k2 import IntArray2Size -from _k2 import _Arc -from _k2 import DLPackFsa -from _k2 import IntArray2Size +from _k2host import IntArray2Size +from _k2host import _Arc +from _k2host import DLPackFsa +from _k2host import IntArray2Size class Arc(_Arc): - def __init__(self, src_state: int, dest_state: int, label: int): - super().__init__(src_state, dest_state, label) + def __init__(self, src_state: int, dest_state: int, label: int, + weight: float): + super().__init__(src_state, dest_state, label, weight) def to_tensor(self): - return torch.tensor([self.src_state, self.dest_state, self.label], - dtype=torch.int32) + # TODO(fangjun): weight will be truncted to an int. + return torch.tensor( + [self.src_state, self.dest_state, self.label, self.weight], + dtype=torch.int32) @staticmethod def from_tensor(tensor: torch.Tensor) -> 'Arc': - assert tensor.shape == torch.Size([3]) + assert tensor.shape == torch.Size([4]) assert tensor.dtype == torch.int32 return Arc(*tensor.tolist()) @@ -41,7 +44,7 @@ class Fsa(DLPackFsa): def __init__(self, indexes: torch.Tensor, data: torch.Tensor): assert indexes.dtype == torch.int32 assert data.dtype == torch.int32 - assert data.shape[1] == 3 + assert data.shape[1] == 4 self.indexes = indexes self.data = data super().__init__(to_dlpack(self.indexes), to_dlpack(self.data)) @@ -49,5 +52,5 @@ def __init__(self, indexes: torch.Tensor, data: torch.Tensor): @staticmethod def create_fsa_with_size(array_size: IntArray2Size) -> 'Fsa': indexes = torch.zeros(array_size.size1 + 1, dtype=torch.int32) - data = torch.zeros([array_size.size2, 3], dtype=torch.int32) + data = torch.zeros([array_size.size2, 4], dtype=torch.int32) return Fsa(indexes, data) diff --git a/k2/python/k2/fsa_algo.py b/k2/python/host/k2host/fsa_algo.py similarity index 71% rename from k2/python/k2/fsa_algo.py rename to k2/python/host/k2host/fsa_algo.py index f9423f97a..e286adf83 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/host/k2host/fsa_algo.py @@ -11,16 +11,16 @@ from .array import FloatArray1 from .array import LogSumArcDerivs from .weights import WfsaWithFbWeights -from _k2 import IntArray2Size -from _k2 import _ArcSorter -from _k2 import _arc_sort -from _k2 import _TopSorter -from _k2 import _Connection -from _k2 import _Intersection -from _k2 import _DeterminizerMax -from _k2 import _DeterminizerLogSum -from _k2 import _EpsilonsRemoverMax -from _k2 import _EpsilonsRemoverLogSum +from _k2host import IntArray2Size +from _k2host import _ArcSorter +from _k2host import _arc_sort +from _k2host import _TopSorter +from _k2host import _Connection +from _k2host import _Intersection +from _k2host import _DeterminizerMax +from _k2host import _DeterminizerLogSum +from _k2host import _EpsilonsRemoverMax +from _k2host import _EpsilonsRemoverLogSum class ArcSorter(_ArcSorter): @@ -97,11 +97,8 @@ def get_sizes(self, fsa_size: IntArray2Size, arc_derivs_size: IntArray2Size) -> None: return super().get_sizes(fsa_size, arc_derivs_size) - def get_output(self, fsa_out: Fsa, arc_weights_out: FloatArray1, - arc_derivs: IntArray2) -> float: - return super().get_output(fsa_out.get_base(), - arc_weights_out.get_base(), - arc_derivs.get_base()) + def get_output(self, fsa_out: Fsa, arc_derivs: IntArray2) -> float: + return super().get_output(fsa_out.get_base(), arc_derivs.get_base()) class DeterminizerLogSum(_DeterminizerLogSum): @@ -113,11 +110,8 @@ def get_sizes(self, fsa_size: IntArray2Size, arc_derivs_size: IntArray2Size) -> None: return super().get_sizes(fsa_size, arc_derivs_size) - def get_output(self, fsa_out: Fsa, arc_weights_out: FloatArray1, - arc_derivs: LogSumArcDerivs) -> float: - return super().get_output(fsa_out.get_base(), - arc_weights_out.get_base(), - arc_derivs.get_base()) + def get_output(self, fsa_out: Fsa, arc_derivs: LogSumArcDerivs) -> float: + return super().get_output(fsa_out.get_base(), arc_derivs.get_base()) class EpsilonsRemoverMax(_EpsilonsRemoverMax): @@ -129,11 +123,8 @@ def get_sizes(self, fsa_size: IntArray2Size, arc_derivs_size: IntArray2Size) -> None: return super().get_sizes(fsa_size, arc_derivs_size) - def get_output(self, fsa_out: Fsa, arc_weights_out: FloatArray1, - arc_derivs: IntArray2) -> None: - return super().get_output(fsa_out.get_base(), - arc_weights_out.get_base(), - arc_derivs.get_base()) + def get_output(self, fsa_out: Fsa, arc_derivs: IntArray2) -> None: + return super().get_output(fsa_out.get_base(), arc_derivs.get_base()) class EpsilonsRemoverLogSum(_EpsilonsRemoverLogSum): @@ -145,8 +136,5 @@ def get_sizes(self, fsa_size: IntArray2Size, arc_derivs_size: IntArray2Size) -> None: return super().get_sizes(fsa_size, arc_derivs_size) - def get_output(self, fsa_out: Fsa, arc_weights_out: FloatArray1, - arc_derivs: LogSumArcDerivs) -> None: - return super().get_output(fsa_out.get_base(), - arc_weights_out.get_base(), - arc_derivs.get_base()) + def get_output(self, fsa_out: Fsa, arc_derivs: LogSumArcDerivs) -> None: + return super().get_output(fsa_out.get_base(), arc_derivs.get_base()) diff --git a/k2/python/k2/fsa_equivalent.py b/k2/python/host/k2host/fsa_equivalent.py similarity index 62% rename from k2/python/k2/fsa_equivalent.py rename to k2/python/host/k2host/fsa_equivalent.py index 2a83c6056..b8608a5e2 100644 --- a/k2/python/k2/fsa_equivalent.py +++ b/k2/python/host/k2host/fsa_equivalent.py @@ -8,12 +8,12 @@ from .fsa import Fsa from .array import IntArray1 from .array import FloatArray1 -from _k2 import IntArray2Size -from _k2 import _RandPath -from _k2 import _is_rand_equivalent -from _k2 import _is_rand_equivalent_max_weight -from _k2 import _is_rand_equivalent_logsum_weight -from _k2 import _is_rand_equivalent_after_rmeps_pruned_logsum +from _k2host import IntArray2Size +from _k2host import _RandPath +from _k2host import _is_rand_equivalent +from _k2host import _is_rand_equivalent_max_weight +from _k2host import _is_rand_equivalent_logsum_weight +from _k2host import _is_rand_equivalent_after_rmeps_pruned_logsum class RandPath(_RandPath): @@ -35,42 +35,30 @@ def is_rand_equivalent(fsa_a: Fsa, fsa_b: Fsa, npath: int = 100) -> bool: def is_rand_equivalent_max_weight(fsa_a: Fsa, - a_weights: FloatArray1, fsa_b: Fsa, - b_weights: FloatArray1, beam: float = float('inf'), delta: float = 1e-6, top_sorted: bool = True, npath: int = 100) -> bool: - return _is_rand_equivalent_max_weight(fsa_a.get_base(), - a_weights.get_base(), - fsa_b.get_base(), - b_weights.get_base(), beam, delta, - top_sorted, npath) + return _is_rand_equivalent_max_weight(fsa_a.get_base(), fsa_b.get_base(), + beam, delta, top_sorted, npath) def is_rand_equivalent_logsum_weight(fsa_a: Fsa, - a_weights: FloatArray1, fsa_b: Fsa, - b_weights: FloatArray1, beam: float = float('inf'), delta: float = 1e-6, top_sorted: bool = True, npath: int = 100) -> bool: return _is_rand_equivalent_logsum_weight(fsa_a.get_base(), - a_weights.get_base(), - fsa_b.get_base(), - b_weights.get_base(), beam, delta, + fsa_b.get_base(), beam, delta, top_sorted, npath) def is_rand_equivalent_after_rmeps_pruned_logsum(fsa_a: Fsa, - a_weights: FloatArray1, fsa_b: Fsa, - b_weights: FloatArray1, beam: float, top_sorted: bool = True, npath: int = 100) -> bool: return _is_rand_equivalent_after_rmeps_pruned_logsum( - fsa_a.get_base(), a_weights.get_base(), fsa_b.get_base(), - b_weights.get_base(), beam, top_sorted, npath) + fsa_a.get_base(), fsa_b.get_base(), beam, top_sorted, npath) diff --git a/k2/python/k2/fsa_util.py b/k2/python/host/k2host/fsa_util.py similarity index 79% rename from k2/python/k2/fsa_util.py rename to k2/python/host/k2host/fsa_util.py index d2cfbd595..d096dcf7b 100644 --- a/k2/python/k2/fsa_util.py +++ b/k2/python/host/k2host/fsa_util.py @@ -3,6 +3,7 @@ # See ../../../LICENSE for clarification regarding multiple authors import re +import struct from collections import defaultdict import torch @@ -10,13 +11,18 @@ from .fsa import Fsa +def float_to_int(f): + f = struct.pack('f', f) + return int.from_bytes(f, 'little') + + def str_to_fsa(s: str) -> Fsa: '''Create an FSA from a string. The input string `s` consists of several lines; every line except the last line has the following format: -