Skip to content
This repository has been archived by the owner on Aug 23, 2023. It is now read-only.

Commit

Permalink
Wrap Array1<T> as torch::Tensor. (k2-fsa#173)
Browse files Browse the repository at this point in the history
* Wrap Array1<T> as torch::Tensor.

Fix k2host test cases.

* interpret arc.weight from a float to an int.

* update the comment for torch.h/torch.cu

* fix linker errors for release build.
  • Loading branch information
csukuangfj authored Sep 23, 2020
1 parent c0be6f5 commit 3c21c9d
Show file tree
Hide file tree
Showing 86 changed files with 2,226 additions and 1,631 deletions.
8 changes: 8 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,11 @@
show-source=true
statistics=true
max-line-length=80
exclude =
.git,
build,
k2/python/host

ignore =
# E127 continuation line over-indented for visual indent
E127,
7 changes: 3 additions & 4 deletions .github/workflows/style_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.5, 3.6, 3.7, 3.8]
python-version: [3.7, 3.8]

steps:
- uses: actions/checkout@v2
Expand All @@ -32,16 +32,15 @@ jobs:
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip
python3 -m pip install --upgrade flake8
python3 -m pip install --upgrade flake8==3.8.3
- name: Run flake8
shell: bash
working-directory: ${{github.workspace}}
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings.
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=79 --statistics
flake8 .
# TODO(fangjun): build a docker for style check
# - name: Install cppcheck
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ enable_testing()
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
include(pybind11)
if(USE_PYTORCH)
add_definitions(-DK2_USE_PYTORCH)
include(torch)
endif()
include(cub)
include(googletest)


add_subdirectory(k2)
6 changes: 6 additions & 0 deletions cmake/pybind11.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@ function(download_pybind11)
set(pybind11_URL "https://github.com/pybind/pybind11/archive/v2.5.0.tar.gz")
set(pybind11_HASH "SHA256=97504db65640570f32d3fdf701c25a340c8643037c3b69aec469c10c93dc8504")

set(double_quotes "\"")
set(dollar "\$")
set(semicolon "\;")
FetchContent_Declare(pybind11
URL ${pybind11_URL}
URL_HASH ${pybind11_HASH}
PATCH_COMMAND
sed -i s/\\${double_quotes}-flto\\\\${dollar}/\\${double_quotes}-Xcompiler=-flto${dollar}/g "tools/pybind11Tools.cmake" &&
sed -i s/${seimcolon}-fno-fat-lto-objects/${seimcolon}-Xcompiler=-fno-fat-lto-objects/g "tools/pybind11Tools.cmake"
)

FetchContent_GetProperties(pybind11)
Expand Down
2 changes: 1 addition & 1 deletion k2/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ else()
endif()

# the target
add_library(context STATIC ${context_srcs})
add_library(context SHARED ${context_srcs})
set_target_properties(context PROPERTIES CUDA_SEPARABLE_COMPILATION ON)

# lib deps
Expand Down
4 changes: 2 additions & 2 deletions k2/csrc/default_context.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class CpuContext : public Context {
int32_t ret = posix_memalign(&p, kAlignment, bytes);
K2_CHECK_EQ(ret, 0);
}
if (deleter_context) *deleter_context = nullptr;
if (deleter_context != nullptr) *deleter_context = nullptr;
return p;
}

Expand Down Expand Up @@ -75,7 +75,7 @@ class CudaContext : public Context {
auto ret = cudaMalloc(&p, bytes);
K2_CHECK_CUDA_ERROR(ret);
}
if (deleter_context) *deleter_context = nullptr;
if (deleter_context != nullptr) *deleter_context = nullptr;
return p;
}

Expand Down
2 changes: 1 addition & 1 deletion k2/csrc/host/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# the target
# please sort the source files alphabetically
add_library(fsa
add_library(fsa SHARED
arcsort.cc
aux_labels.cc
connect.cc
Expand Down
3 changes: 2 additions & 1 deletion k2/csrc/host/fsa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ inline std::size_t AlignTo(std::size_t b, std::size_t alignment) {
namespace k2host {

std::ostream &operator<<(std::ostream &os, const Arc &arc) {
os << arc.src_state << " " << arc.dest_state << " " << arc.label;
os << arc.src_state << " " << arc.dest_state << " " << arc.label << " "
<< arc.weight;
return os;
}

Expand Down
23 changes: 23 additions & 0 deletions k2/csrc/pytorch_context.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,27 @@ ContextPtr GetCudaContext(int32_t gpu_id /*= -1*/) {
return std::make_shared<PytorchCudaContext>(gpu_id);
}

RegionPtr NewRegion(torch::Tensor &tensor) {
auto ans = std::make_shared<Region>();
if (tensor.device().type() == torch::kCPU) {
ans->context = GetCpuContext();
} else if (tensor.is_cuda()) {
ans->context = GetCudaContext(tensor.device().index());
} else {
K2_LOG(FATAL) << "Unsupported device: " << tensor.device()
<< "\nOnly CPU and CUDA are supported";
}

// NOTE: the tensor is passed from Python and we have
// to retain it to avoid potential segmentation fault.
//
// It will be freed in `Context::Deallocate`.
auto *managed_tensor = new ManagedTensor(tensor);
ans->data = tensor.data_ptr();
ans->deleter_context = managed_tensor;
ans->num_bytes = tensor.nbytes();
ans->bytes_used = ans->num_bytes;
return ans;
}

} // namespace k2
39 changes: 33 additions & 6 deletions k2/csrc/pytorch_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,21 @@
#include <memory>

#include "c10/cuda/CUDACachingAllocator.h"
#include "c10/cuda/CUDAFunctions.h"
#include "k2/csrc/context.h"
#include "k2/csrc/log.h"
#include "torch/torch.h"

namespace k2 {

class ManagedTensor {
public:
explicit ManagedTensor(torch::Tensor &tensor) : handle_(tensor) {}

private:
torch::Tensor handle_; // retain a copy of the tensor passed from Python
};

class PytorchCpuContext : public Context {
private:
PytorchCpuContext() {
Expand All @@ -46,12 +55,18 @@ class PytorchCpuContext : public Context {

void *Allocate(std::size_t bytes, void **deleter_context) override {
void *p = allocator_->raw_allocate(bytes);
if (deleter_context) *deleter_context = nullptr;
if (deleter_context != nullptr) *deleter_context = nullptr;
return p;
}

void Deallocate(void *data, void * /*deleter_context*/) override {
allocator_->raw_deallocate(data);
void Deallocate(void *data, void *deleter_context) override {
if (deleter_context != nullptr) {
// a non-empty `deleter_context` indicates that
// the memory is passed from a `torch::Tensor`
delete reinterpret_cast<ManagedTensor *>(deleter_context);
} else {
allocator_->raw_deallocate(data);
}
}

bool IsCompatible(const Context &other) const override {
Expand Down Expand Up @@ -94,12 +109,18 @@ class PytorchCudaContext : public Context {

void *Allocate(std::size_t bytes, void **deleter_context) override {
void *p = allocator_->raw_allocate(bytes);
if (deleter_context) *deleter_context = nullptr;
if (deleter_context != nullptr) *deleter_context = nullptr;
return p;
}

void Deallocate(void *data, void * /*deleter_context*/) override {
allocator_->raw_deallocate(data);
void Deallocate(void *data, void *deleter_context) override {
if (deleter_context != nullptr) {
// a non-empty `deleter_context` indicates that
// the memory is passed from a `torch::Tensor`
delete reinterpret_cast<ManagedTensor *>(deleter_context);
} else {
allocator_->raw_deallocate(data);
}
}

bool IsCompatible(const Context &other) const override {
Expand All @@ -116,6 +137,12 @@ class PytorchCudaContext : public Context {
int32_t gpu_id_;
};

// Construct a region from a `torch::Tensor`.
//
// The resulting region shares the underlying memory with
// the given tensor.
RegionPtr NewRegion(torch::Tensor &tensor);

} // namespace k2

#endif // K2_CSRC_PYTORCH_CONTEXT_H_
1 change: 1 addition & 0 deletions k2/python/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(csrc)
add_subdirectory(tests)
add_subdirectory(host)
33 changes: 20 additions & 13 deletions k2/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
# please sort the files alphabetically
pybind11_add_module(_k2
array.cc
aux_labels.cc
fsa.cc
fsa_algo.cc
fsa_equivalent.cc
fsa_util.cc
k2.cc
properties.cc
tensor.cc
weights.cc
# please keep the list sorted
set(k2_srcs
k2.cu
torch.cu
)

target_include_directories(_k2 PRIVATE ${CMAKE_SOURCE_DIR})
if(USE_PYTORCH)
add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H)
add_subdirectory(torch)
set(k2_srcs ${k2_srcs} ${torch_srcs})
set(k2_deps
${TORCH_LIBRARIES}
${TORCH_DIR}/lib/libtorch_python.so
)
else()
message(FATAL_ERROR "Please select a framework.")
endif()

pybind11_add_module(_k2 ${k2_srcs})
target_link_libraries(_k2 PRIVATE ${k2_deps})
target_link_libraries(_k2 PRIVATE context)
target_link_libraries(_k2 PRIVATE fsa)
target_include_directories(_k2 PRIVATE ${CMAKE_SOURCE_DIR})
14 changes: 0 additions & 14 deletions k2/python/csrc/aux_labels.h

This file was deleted.

14 changes: 0 additions & 14 deletions k2/python/csrc/fsa_algo.h

This file was deleted.

14 changes: 0 additions & 14 deletions k2/python/csrc/fsa_equivalent.h

This file was deleted.

14 changes: 0 additions & 14 deletions k2/python/csrc/fsa_util.h

This file was deleted.

30 changes: 0 additions & 30 deletions k2/python/csrc/k2.cc

This file was deleted.

18 changes: 18 additions & 0 deletions k2/python/csrc/k2.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/**
* @brief python wrappers for k2.
*
* @copyright
* Copyright (c) 2020 Mobvoi AI Lab, Beijing, China (authors: Fangjun Kuang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*/

#include "k2/python/csrc/k2.h"

#include "k2/python/csrc/torch.h"

PYBIND11_MODULE(_k2, m) {
m.doc() = "pybind11 binding of k2";
PybindTorch(m);
}
15 changes: 9 additions & 6 deletions k2/python/csrc/k2.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
// k2/python/csrc/k2.h

// Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com)

// See ../../../LICENSE for clarification regarding multiple authors
/**
* @brief python wrappers for k2.
*
* @copyright
* Copyright (c) 2020 Mobvoi AI Lab, Beijing, China (authors: Fangjun Kuang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*/

#ifndef K2_PYTHON_CSRC_K2_H_
#define K2_PYTHON_CSRC_K2_H_

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "k2/csrc/log.h"

namespace py = pybind11;

Expand Down
Loading

0 comments on commit 3c21c9d

Please sign in to comment.