Skip to content

Commit

Permalink
Updated pybind11 bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
COM8 committed Jun 13, 2022
1 parent 5017676 commit 3425a6f
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 65 deletions.
120 changes: 59 additions & 61 deletions python/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@ opAlgoDispatchPyInit(std::shared_ptr<kp::Algorithm>& algorithm,
push_consts.size(),
std::string(py::str(push_consts.dtype())));

if (push_consts.dtype() == py::dtype::of<std::float_t>()) {
if (push_consts.dtype().is(py::dtype::of<std::float_t>())) {
std::vector<float> dataVec((float*)info.ptr,
((float*)info.ptr) + info.size);
return std::unique_ptr<kp::OpAlgoDispatch>{ new kp::OpAlgoDispatch(
algorithm, dataVec) };
} else if (push_consts.dtype() == py::dtype::of<std::uint32_t>()) {
} else if (push_consts.dtype().is(py::dtype::of<std::uint32_t>())) {
std::vector<uint32_t> dataVec((uint32_t*)info.ptr,
((uint32_t*)info.ptr) + info.size);
return std::unique_ptr<kp::OpAlgoDispatch>{ new kp::OpAlgoDispatch(
algorithm, dataVec) };
} else if (push_consts.dtype() == py::dtype::of<std::int32_t>()) {
} else if (push_consts.dtype().is(py::dtype::of<std::int32_t>())) {
std::vector<int32_t> dataVec((int32_t*)info.ptr,
((int32_t*)info.ptr) + info.size);
return std::unique_ptr<kp::OpAlgoDispatch>{ new kp::OpAlgoDispatch(
algorithm, dataVec) };
} else if (push_consts.dtype() == py::dtype::of<std::double_t>()) {
} else if (push_consts.dtype().is(py::dtype::of<std::double_t>())) {
std::vector<double> dataVec((double*)info.ptr,
((double*)info.ptr) + info.size);
return std::unique_ptr<kp::OpAlgoDispatch>{ new kp::OpAlgoDispatch(
Expand Down Expand Up @@ -76,29 +76,29 @@ PYBIND11_MODULE(kp, m)
py::class_<kp::OpBase, std::shared_ptr<kp::OpBase>>(
m, "OpBase", DOC(kp, OpBase));

py::class_<kp::OpTensorSyncDevice, std::shared_ptr<kp::OpTensorSyncDevice>>(
m,
"OpTensorSyncDevice",
py::base<kp::OpBase>(),
DOC(kp, OpTensorSyncDevice))
py::class_<kp::OpTensorSyncDevice,
kp::OpBase,
std::shared_ptr<kp::OpTensorSyncDevice>>(
m, "OpTensorSyncDevice", DOC(kp, OpTensorSyncDevice))
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>(),
DOC(kp, OpTensorSyncDevice, OpTensorSyncDevice));

py::class_<kp::OpTensorSyncLocal, std::shared_ptr<kp::OpTensorSyncLocal>>(
m,
"OpTensorSyncLocal",
py::base<kp::OpBase>(),
DOC(kp, OpTensorSyncLocal))
py::class_<kp::OpTensorSyncLocal,
kp::OpBase,
std::shared_ptr<kp::OpTensorSyncLocal>>(
m, "OpTensorSyncLocal", DOC(kp, OpTensorSyncLocal))
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>(),
DOC(kp, OpTensorSyncLocal, OpTensorSyncLocal));

py::class_<kp::OpTensorCopy, std::shared_ptr<kp::OpTensorCopy>>(
m, "OpTensorCopy", py::base<kp::OpBase>(), DOC(kp, OpTensorCopy))
py::class_<kp::OpTensorCopy, kp::OpBase, std::shared_ptr<kp::OpTensorCopy>>(
m, "OpTensorCopy", DOC(kp, OpTensorCopy))
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>(),
DOC(kp, OpTensorCopy, OpTensorCopy));

py::class_<kp::OpAlgoDispatch, std::shared_ptr<kp::OpAlgoDispatch>>(
m, "OpAlgoDispatch", py::base<kp::OpBase>(), DOC(kp, OpAlgoDispatch))
py::class_<kp::OpAlgoDispatch,
kp::OpBase,
std::shared_ptr<kp::OpAlgoDispatch>>(
m, "OpAlgoDispatch", DOC(kp, OpAlgoDispatch))
.def(py::init<const std::shared_ptr<kp::Algorithm>&,
const std::vector<float>&>(),
DOC(kp, OpAlgoDispatch, OpAlgoDispatch),
Expand All @@ -109,8 +109,8 @@ PYBIND11_MODULE(kp, m)
py::arg("algorithm"),
py::arg("push_consts"));

py::class_<kp::OpMult, std::shared_ptr<kp::OpMult>>(
m, "OpMult", py::base<kp::OpBase>(), DOC(kp, OpMult))
py::class_<kp::OpMult, kp::OpBase, std::shared_ptr<kp::OpMult>>(
m, "OpMult", DOC(kp, OpMult))
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&,
const std::shared_ptr<kp::Algorithm>&>(),
DOC(kp, OpMult, OpMult));
Expand Down Expand Up @@ -253,31 +253,31 @@ PYBIND11_MODULE(kp, m)
"size {} dtype {}",
flatdata.size(),
std::string(py::str(flatdata.dtype())));
if (flatdata.dtype() == py::dtype::of<std::float_t>()) {
if (flatdata.dtype().is(py::dtype::of<std::float_t>())) {
return self.tensor(info.ptr,
flatdata.size(),
sizeof(float),
kp::Tensor::TensorDataTypes::eFloat,
tensor_type);
} else if (flatdata.dtype() == py::dtype::of<std::uint32_t>()) {
} else if (flatdata.dtype().is(py::dtype::of<std::uint32_t>())) {
return self.tensor(info.ptr,
flatdata.size(),
sizeof(uint32_t),
kp::Tensor::TensorDataTypes::eUnsignedInt,
tensor_type);
} else if (flatdata.dtype() == py::dtype::of<std::int32_t>()) {
} else if (flatdata.dtype().is(py::dtype::of<std::int32_t>())) {
return self.tensor(info.ptr,
flatdata.size(),
sizeof(int32_t),
kp::Tensor::TensorDataTypes::eInt,
tensor_type);
} else if (flatdata.dtype() == py::dtype::of<std::double_t>()) {
} else if (flatdata.dtype().is(py::dtype::of<std::double_t>())) {
return self.tensor(info.ptr,
flatdata.size(),
sizeof(double),
kp::Tensor::TensorDataTypes::eDouble,
tensor_type);
} else if (flatdata.dtype() == py::dtype::of<bool>()) {
} else if (flatdata.dtype().is(py::dtype::of<bool>())) {
return self.tensor(info.ptr,
flatdata.size(),
sizeof(bool),
Expand Down Expand Up @@ -340,10 +340,10 @@ PYBIND11_MODULE(kp, m)

// We have to iterate across a combination of parameters due to the
// lack of support for templating
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
std::vector<float> specConstsVec(
(float*)specInfo.ptr, ((float*)specInfo.ptr) + specInfo.size);
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
std::vector<float> pushConstsVec((float*)pushInfo.ptr,
((float*)pushInfo.ptr) +
pushInfo.size);
Expand All @@ -352,8 +352,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specConstsVec,
pushConstsVec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::int32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::int32_t>())) {
std::vector<int32_t> pushConstsVec(
(int32_t*)pushInfo.ptr,
((int32_t*)pushInfo.ptr) + pushInfo.size);
Expand All @@ -362,8 +362,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specConstsVec,
pushConstsVec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::uint32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::uint32_t>())) {
std::vector<uint32_t> pushConstsVec(
(uint32_t*)pushInfo.ptr,
((uint32_t*)pushInfo.ptr) + pushInfo.size);
Expand All @@ -372,8 +372,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specConstsVec,
pushConstsVec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::double_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::double_t>())) {
std::vector<double> pushConstsVec((double*)pushInfo.ptr,
((double*)pushInfo.ptr) +
pushInfo.size);
Expand All @@ -383,11 +383,11 @@ PYBIND11_MODULE(kp, m)
specConstsVec,
pushConstsVec);
}
} else if (spec_consts.dtype() == py::dtype::of<std::int32_t>()) {
} else if (spec_consts.dtype().is(py::dtype::of<std::int32_t>())) {
std::vector<int32_t> specconstsvec((int32_t*)specInfo.ptr,
((int32_t*)specInfo.ptr) +
specInfo.size);
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
std::vector<float> pushconstsvec((float*)pushInfo.ptr,
((float*)pushInfo.ptr) +
pushInfo.size);
Expand All @@ -396,8 +396,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::int32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::int32_t>())) {
std::vector<int32_t> pushconstsvec(
(int32_t*)pushInfo.ptr,
((int32_t*)pushInfo.ptr) + pushInfo.size);
Expand All @@ -406,8 +406,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::uint32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::uint32_t>())) {
std::vector<uint32_t> pushconstsvec(
(uint32_t*)pushInfo.ptr,
((uint32_t*)pushInfo.ptr) + pushInfo.size);
Expand All @@ -416,8 +416,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::double_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::double_t>())) {
std::vector<double> pushconstsvec((double*)pushInfo.ptr,
((double*)pushInfo.ptr) +
pushInfo.size);
Expand All @@ -427,11 +427,11 @@ PYBIND11_MODULE(kp, m)
specconstsvec,
pushconstsvec);
}
} else if (spec_consts.dtype() == py::dtype::of<std::uint32_t>()) {
} else if (spec_consts.dtype().is(py::dtype::of<std::uint32_t>())) {
std::vector<uint32_t> specconstsvec((uint32_t*)specInfo.ptr,
((uint32_t*)specInfo.ptr) +
specInfo.size);
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
std::vector<float> pushconstsvec((float*)pushInfo.ptr,
((float*)pushInfo.ptr) +
pushInfo.size);
Expand All @@ -440,8 +440,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::int32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::int32_t>())) {
std::vector<int32_t> pushconstsvec(
(int32_t*)pushInfo.ptr,
((int32_t*)pushInfo.ptr) + pushInfo.size);
Expand All @@ -450,8 +450,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::uint32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::uint32_t>())) {
std::vector<uint32_t> pushconstsvec(
(uint32_t*)pushInfo.ptr,
((uint32_t*)pushInfo.ptr) + pushInfo.size);
Expand All @@ -460,8 +460,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::double_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::double_t>())) {
std::vector<double> pushconstsvec((double*)pushInfo.ptr,
((double*)pushInfo.ptr) +
pushInfo.size);
Expand All @@ -471,11 +471,11 @@ PYBIND11_MODULE(kp, m)
specconstsvec,
pushconstsvec);
}
} else if (spec_consts.dtype() == py::dtype::of<std::double_t>()) {
} else if (spec_consts.dtype().is(py::dtype::of<std::double_t>())) {
std::vector<double> specconstsvec((double*)specInfo.ptr,
((double*)specInfo.ptr) +
specInfo.size);
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
std::vector<float> pushconstsvec((float*)pushInfo.ptr,
((float*)pushInfo.ptr) +
pushInfo.size);
Expand All @@ -484,8 +484,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::int32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::int32_t>())) {
std::vector<float> pushconstsvec((int32_t*)pushInfo.ptr,
((int32_t*)pushInfo.ptr) +
pushInfo.size);
Expand All @@ -494,8 +494,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::uint32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::uint32_t>())) {
std::vector<float> pushconstsvec((uint32_t*)pushInfo.ptr,
((uint32_t*)pushInfo.ptr) +
pushInfo.size);
Expand All @@ -504,8 +504,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::double_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::double_t>())) {
std::vector<float> pushconstsvec((double*)pushInfo.ptr,
((double*)pushInfo.ptr) +
pushInfo.size);
Expand All @@ -515,11 +515,9 @@ PYBIND11_MODULE(kp, m)
specconstsvec,
pushconstsvec);
}
} else {
// If reach then no valid dtype supported
throw std::runtime_error(
"Kompute Python no valid dtype supported");
}
// If reach then no valid dtype supported
throw std::runtime_error("Kompute Python no valid dtype supported");
},
DOC(kp, Manager, algorithm),
py::arg("tensors"),
Expand Down
14 changes: 10 additions & 4 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,18 @@ endif()
if(KOMPUTE_OPT_ANDROID_BUILD)
target_link_libraries(kompute PUBLIC kompute_vk_ndk_wrapper
android
PRIVATE kp_logger
fmt::fmt)
kp_logger
PRIVATE fmt::fmt)
else()
target_link_libraries(kompute PUBLIC Vulkan::Vulkan
PRIVATE fmt::fmt
kp_logger)
kp_logger
PRIVATE fmt::fmt)
endif()

if(KOMPUTE_OPT_BUILD_PYTHON)
include_directories(${PYTHON_INCLUDE_DIRS})

target_link_libraries(kompute PRIVATE pybind11::headers ${PYTHON_LIBRARIES})
endif()

if(KOMPUTE_OPT_USE_BUILD_IN_VULKAN_HEADER)
Expand Down

0 comments on commit 3425a6f

Please sign in to comment.