Skip to content

Commit 427fc97

Browse files
committed
[runtime] modified OpKernel to support dynamic shape
1 parent c4eeef0 commit 427fc97

File tree

13 files changed

+427
-49
lines changed

13 files changed

+427
-49
lines changed

runtime/include/brt/backends/cuda/device/utils/op_kernel_impl_helpers.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ using CurandOpKernelIfaceTraits =
143143
* struct ConcreateOpImpl {
144144
* ConcreateOpImpl(const OpAccessor&);
145145
* void Execute(args..., cudaStream_t);
146+
 * optional<void ProloguePerExecute(const OpAccessor&)>;
146147
* };
147148
* using ConcreteOp = CudaOpKernel<ConcreateOpImpl, Arguments...>;
148149
*/
@@ -153,6 +154,7 @@ BRT_DEF_OP_KERNEL_WRPPER(CudaOpKernel,
153154
* struct ConcreateOpImpl {
154155
* ConcreateOpImpl(const OpAccessor&);
155156
* void Execute(args..., cublasHandle_t, cudaStream_t);
157+
 * optional<void ProloguePerExecute(const OpAccessor&)>;
156158
* };
157159
* using ConcreteOp = CublasOpKernel<ConcreateOpImpl, Arguments...>;
158160
*/
@@ -163,6 +165,7 @@ BRT_DEF_OP_KERNEL_WRPPER(CublasOpKernel,
163165
* struct ConcreateOpImpl {
164166
* ConcreateOpImpl(const OpAccessor&);
165167
* void Execute(args..., cudnnHandle_t, cudaStream_t);
168+
 * optional<void ProloguePerExecute(const OpAccessor&)>;
166169
* };
167170
* using ConcreteOp = CudnnOpKernel<ConcreateOpImpl, Arguments...>;
168171
*/
@@ -173,6 +176,7 @@ BRT_DEF_OP_KERNEL_WRPPER(CudnnOpKernel,
173176
* struct ConcreateOpImpl {
174177
* ConcreateOpImpl(const OpAccessor&);
175178
* void Execute(args..., void* workspace, cudaStream_t);
179+
 * optional<void ProloguePerExecute(const OpAccessor&)>;
176180
* size_t GetWorkspaceSize(const ExecutionContext &);
177181
* };
178182
* using ConcreteOp = CudaOpKernelWithWorkspace<ConcreateOpImpl, Arguments...>;

runtime/include/brt/core/framework/op_kernel_impl_base.h

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "brt/core/context/work_queue.h"
2323
#include "brt/core/framework/op_accessor.h"
2424
#include "brt/core/framework/op_kernel.h"
25+
#include "mlir/IR/BuiltinTypeInterfaces.h"
2526

2627
namespace brt {
2728

@@ -169,6 +170,10 @@ template <typename... Arguments> struct OpKernelIfaceTraitsBase {
169170

170171
template <typename Impl>
171172
common::Status static inline Run(Impl *impl, const ExecutionContext &ctx) {
173+
auto status = impl->ProloguePerExecute(ctx);
174+
if (!status.IsOK()) {
175+
return status;
176+
}
172177
return impl->Execute(Arguments::Get(impl, ctx)...);
173178
}
174179

@@ -187,17 +192,63 @@ template <typename... Arguments> struct OpKernelIfaceTraitsBase {
187192

188193
template <typename... Arguments>
189194
struct NaiveOpKernelIfaceTraits : public OpKernelIfaceTraitsBase<Arguments...> {
195+
196+
template <typename T> struct TrueHelper : std::true_type {};
197+
198+
template <typename ClassType, typename... ArgType>
199+
struct HasProloguePerExecuteTraits {
200+
template <typename Impl, typename... Arg>
201+
static auto CheckPrologurePerExecute(int)
202+
-> TrueHelper<decltype(std::declval<Impl>().ProloguePerExecute(
203+
std::declval<Arg>()...))>;
204+
205+
template <typename Impl, typename... Arg>
206+
static auto CheckPrologurePerExecute(...) -> std::false_type;
207+
208+
public:
209+
enum {
210+
value =
211+
decltype(CheckPrologurePerExecute<ClassType, ArgType...>(0))::value
212+
};
213+
};
214+
190215
template <typename ImplBase> struct ImplMixin : public ImplBase {
191216
public:
192-
explicit ImplMixin(const OpKernelInfo &info)
193-
: ImplBase(info), info_(info) {}
217+
explicit ImplMixin(const OpKernelInfo &info) : ImplBase(info), info_(info) {
218+
// initialize `io_contain_dynamic_shape`
219+
io_contain_dynamic_shape = false;
220+
OpAccessor accessor(info);
221+
size_t num_args = accessor.GetNumArgs();
222+
for (size_t i = 0; i < accessor.GetNumArgs(); ++i) {
223+
auto shape = accessor.GetArgShape(i);
224+
if (mlir::ShapedType::isDynamicShape(shape)) {
225+
io_contain_dynamic_shape = true;
226+
}
227+
}
228+
for (size_t i = 0; i < accessor.GetNumResults(); ++i) {
229+
auto shape = accessor.GetArgShape(i + num_args);
230+
if (mlir::ShapedType::isDynamicShape(shape)) {
231+
io_contain_dynamic_shape = true;
232+
}
233+
}
234+
}
235+
236+
common::Status ProloguePerExecute(const ExecutionContext &ctx) {
237+
if constexpr (HasProloguePerExecuteTraits<ImplBase, OpAccessor>::value) {
238+
if (io_contain_dynamic_shape) {
239+
ImplBase::ProloguePerExecute(GetOpAccessor(ctx));
240+
}
241+
}
242+
return Status::OK();
243+
}
194244

195245
OpAccessor GetOpAccessor(const ExecutionContext &ctx) const {
196246
return OpAccessor(info_, ctx.exec_frame);
197247
}
198248

199249
private:
200250
const OpKernelInfo &info_;
251+
bool io_contain_dynamic_shape;
201252
};
202253
};
203254

runtime/include/brt/core/framework/op_kernel_info.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,16 @@ class OpKernelInfo {
143143

144144
// Utilities
145145

146-
// Get Tensor as uniuqe Index, from the ith argument of OpKernelInfo
146+
// Get Tensor as unique Index, from the ith argument of OpKernelInfo
147147
size_t GetTensorIndexFromOpArgIndex(const OpKernelInfo &, unsigned int i);
148148

149-
// Get Tensor as uniuqe Index, from MLIR Value
149+
// Get Tensor as unique Index, from MLIR Value
150150
size_t GetTensorIndexFromMLIRValue(const OpKernelInfo &, mlir::Value val);
151151

152-
// Get Scalar as uniuqe Index, from MLIR Value
152+
// Get Scalar as unique Index, from the ith argument of OpKernelInfo
153+
size_t GetScalarIndexFromOpArgIndex(const OpKernelInfo &, unsigned int i);
154+
155+
// Get Scalar as unique Index, from MLIR Value
153156
size_t GetScalarIndexFromMLIRValue(const OpKernelInfo &, mlir::Value val);
154157

155158
// Get Rank of MLIR Value, of ith argument of OpKernelInfo

runtime/lib/backends/cuda/providers/default/codegen/ptx.cc

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ using namespace mlir;
4545
#define BLOCK_SIZE_Z_ATTR "BlockSize.z"
4646
#define ARG_RANKS_ATTR "arg_ranks"
4747
#define CALL_CONVENTION_ATTR "call_convention"
48+
#define DYNAMIC_CONFIG "__byteir_dynamic_config__"
49+
#define KERNEL_LAUNCH_CONFIG_NUM 6
4850

4951
namespace brt {
5052
namespace cuda {
@@ -123,42 +125,50 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info)
123125
impl_->call_convention = "all";
124126
// static assignment for config
125127
// TODO extend to support dynamic
126-
if (!info.GetOperation()->hasAttrOfType<IntegerAttr>(GRID_SIZE_X_ATTR)) {
127-
BRT_THROW_EX(std::runtime_error, "no GridSize.x attr");
128+
bool dynamic_config_flag = false;
129+
if (info.GetOperation()->hasAttr(DYNAMIC_CONFIG)) {
130+
dynamic_config_flag = true;
128131
}
132+
int gx, gy, gz, bx, by, bz;
133+
gx = gy = gz = bx = by = bz = 1;
134+
if (!dynamic_config_flag) {
135+
if (!info.GetOperation()->hasAttrOfType<IntegerAttr>(GRID_SIZE_X_ATTR)) {
136+
BRT_THROW_EX(std::runtime_error, "no GridSize.x attr");
137+
}
129138

130-
if (!info.GetOperation()->hasAttrOfType<IntegerAttr>(BLOCK_SIZE_X_ATTR)) {
131-
BRT_THROW_EX(std::runtime_error, "no BlockSize.x attr");
132-
}
139+
if (!info.GetOperation()->hasAttrOfType<IntegerAttr>(BLOCK_SIZE_X_ATTR)) {
140+
BRT_THROW_EX(std::runtime_error, "no BlockSize.x attr");
141+
}
133142

134-
int gx = static_cast<int>(info.GetOperation()
135-
->getAttrOfType<IntegerAttr>(GRID_SIZE_X_ATTR)
136-
.getInt()),
137-
gy = 1, gz = 1;
138-
if (info.GetOperation()->hasAttrOfType<IntegerAttr>(GRID_SIZE_Y_ATTR)) {
139-
gy = static_cast<int>(info.GetOperation()
140-
->getAttrOfType<IntegerAttr>(GRID_SIZE_Y_ATTR)
141-
.getInt());
142-
}
143-
if (info.GetOperation()->hasAttrOfType<IntegerAttr>(GRID_SIZE_Z_ATTR)) {
144-
gz = static_cast<int>(info.GetOperation()
145-
->getAttrOfType<IntegerAttr>(GRID_SIZE_Z_ATTR)
146-
.getInt());
147-
}
143+
gx = static_cast<int>(info.GetOperation()
144+
->getAttrOfType<IntegerAttr>(GRID_SIZE_X_ATTR)
145+
.getInt()),
146+
gy = 1, gz = 1;
147+
if (info.GetOperation()->hasAttrOfType<IntegerAttr>(GRID_SIZE_Y_ATTR)) {
148+
gy = static_cast<int>(info.GetOperation()
149+
->getAttrOfType<IntegerAttr>(GRID_SIZE_Y_ATTR)
150+
.getInt());
151+
}
152+
if (info.GetOperation()->hasAttrOfType<IntegerAttr>(GRID_SIZE_Z_ATTR)) {
153+
gz = static_cast<int>(info.GetOperation()
154+
->getAttrOfType<IntegerAttr>(GRID_SIZE_Z_ATTR)
155+
.getInt());
156+
}
148157

149-
int bx = static_cast<int>(info.GetOperation()
150-
->getAttrOfType<IntegerAttr>(BLOCK_SIZE_X_ATTR)
151-
.getInt()),
152-
by = 1, bz = 1;
153-
if (info.GetOperation()->hasAttrOfType<IntegerAttr>(BLOCK_SIZE_Y_ATTR)) {
154-
by = static_cast<int>(info.GetOperation()
155-
->getAttrOfType<IntegerAttr>(BLOCK_SIZE_Y_ATTR)
156-
.getInt());
157-
}
158-
if (info.GetOperation()->hasAttrOfType<IntegerAttr>(BLOCK_SIZE_Z_ATTR)) {
159-
bz = static_cast<int>(info.GetOperation()
160-
->getAttrOfType<IntegerAttr>(BLOCK_SIZE_Z_ATTR)
161-
.getInt());
158+
bx = static_cast<int>(info.GetOperation()
159+
->getAttrOfType<IntegerAttr>(BLOCK_SIZE_X_ATTR)
160+
.getInt()),
161+
by = 1, bz = 1;
162+
if (info.GetOperation()->hasAttrOfType<IntegerAttr>(BLOCK_SIZE_Y_ATTR)) {
163+
by = static_cast<int>(info.GetOperation()
164+
->getAttrOfType<IntegerAttr>(BLOCK_SIZE_Y_ATTR)
165+
.getInt());
166+
}
167+
if (info.GetOperation()->hasAttrOfType<IntegerAttr>(BLOCK_SIZE_Z_ATTR)) {
168+
bz = static_cast<int>(info.GetOperation()
169+
->getAttrOfType<IntegerAttr>(BLOCK_SIZE_Z_ATTR)
170+
.getInt());
171+
}
162172
}
163173

164174
std::vector<int> ranks;
@@ -172,6 +182,10 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info)
172182
}
173183

174184
auto num_arg = GetOpArgNum(info_);
185+
// filter launch config in inputs
186+
// TODO: make `shared_size` an input operand in the compiler.
187+
if (dynamic_config_flag)
188+
num_arg -= KERNEL_LAUNCH_CONFIG_NUM;
175189
impl_->grid = dim3(gx, gy, gz);
176190
impl_->block = dim3(bx, by, bz);
177191
impl_->shared_size = 0;
@@ -198,20 +212,34 @@ common::Status PTXOpKernel::RunImpl(const ExecutionContext &ctx) {
198212
std::vector<void *> args;
199213
std::vector<MLIREngineMemRefDescriptor> descs;
200214
args.reserve(impl_->arg_reserve_size);
215+
bool dynamic_config_flag = false;
216+
if (info_.GetOperation()->hasAttr(DYNAMIC_CONFIG)) {
217+
dynamic_config_flag = true;
218+
auto num_arg = GetOpArgNum(info_);
219+
std::vector<int64_t> launch_config;
220+
launch_config.reserve(KERNEL_LAUNCH_CONFIG_NUM);
221+
for (size_t i = num_arg - KERNEL_LAUNCH_CONFIG_NUM; i < num_arg; ++i) {
222+
size_t idx = GetScalarIndexFromOpArgIndex(info_, i);
223+
launch_config.emplace_back(ctx.exec_frame->GetScalar<int64_t>(idx));
224+
}
225+
impl_->grid = dim3(launch_config[0], launch_config[1], launch_config[2]);
226+
impl_->block = dim3(launch_config[3], launch_config[4], launch_config[5]);
227+
}
228+
201229
args.push_back(&(impl_->grid));
202230
args.push_back(&(impl_->block));
203231
args.push_back(&(impl_->shared_size));
204232

205233
descs.reserve(impl_->tensor_ids.size());
206234
for (size_t i = 0; i < impl_->tensor_ids.size(); ++i) {
207235
descs.emplace_back(ctx.exec_frame->GetAsyncValueRef(impl_->tensor_ids[i]),
208-
impl_->tensor_ranks[i]);
236+
ctx.exec_frame->GetShapeRef(impl_->tensor_ids[i]));
209237
if (impl_->call_convention == "bare_ptr")
210238
args.push_back(&descs.back().data);
211-
else
239+
else {
212240
InsertMemDescToArgs(descs.back(), args);
241+
}
213242
}
214-
215243
auto work_queue = static_cast<CUDAWorkQueue *>(ctx.work_queue);
216244
auto cuda_env = work_queue->GetCudaEnv();
217245
BRT_ENFORCE(cuda_env.IsPrimaryContext(),

runtime/lib/backends/cuda/providers/default/math/matmul.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,24 @@ template <typename T> MatmulImpl<T>::MatmulImpl(const OpAccessor &accessor) {
6868
}
6969
}
7070

71+
template <typename T>
72+
void MatmulImpl<T>::ProloguePerExecute(const OpAccessor &accessor) {
73+
auto shape_a = accessor.GetArgShape(0);
74+
auto shape_b = accessor.GetArgShape(1);
75+
if (!lhs_transpose) {
76+
m = shape_a[0];
77+
k = shape_a[1];
78+
} else {
79+
m = shape_a[1];
80+
k = shape_a[0];
81+
}
82+
if (!rhs_transpose) {
83+
n = shape_b[1];
84+
} else {
85+
n = shape_b[0];
86+
}
87+
}
88+
7189
template <>
7290
void MatmulImpl<float>::Execute(const float *a_val, const float *b_val,
7391
float *c_val, cublasHandle_t handle,

runtime/lib/backends/cuda/providers/default/math/matmul.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ template <typename T> class MatmulImpl {
3030
public:
3131
explicit MatmulImpl(const OpAccessor &accessor);
3232

33+
void ProloguePerExecute(const OpAccessor &);
34+
3335
void Execute(const T *a_val, const T *b_val, T *c_val, cublasHandle_t handle,
3436
cudaStream_t stream);
3537

runtime/lib/core/context/execution_frame.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,11 @@ void BRTInferenceExecutionFrame::BindArg(size_t idx, const void *ptr) {
186186
}
187187

188188
void *BRTInferenceExecutionFrame::GetArg(size_t idx) {
189+
// Debug-only path: when idx indexes past the I/O arguments, return the weight pointer directly.
190+
if (idx >= info_.graph_info.io_count) {
191+
return ctx_.weights_and_ios[idx - info_.graph_info.io_count];
192+
}
193+
189194
BRT_ENFORCE(idx < info_.graph_info.io_count);
190195
int i = info_.weights.size() + idx;
191196

runtime/lib/core/framework/execution_plan.cc

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -337,16 +337,20 @@ common::Status StaticBRTExecutionPlan::ProloguePerSession(
337337
return WalkResult::interrupt();
338338
}
339339

340-
auto maybeSpace = brt::ir::GetSpace(op_arg);
341-
if (!maybeSpace.has_value()) {
342-
status_internal = Status(BRT, FAIL, "non-memref Arg of Op " + key);
343-
return WalkResult::interrupt();
344-
}
345-
346-
auto space = maybeSpace.value();
347-
IAllocator *cur_allocator = GetAllocator(allocators, space);
348-
last_alloc = cur_allocator;
340+
std::string space;
341+
IAllocator *cur_allocator;
342+
if (op_arg.getType().dyn_cast<MemRefType>()) {
343+
auto maybeSpace = brt::ir::GetSpace(op_arg);
344+
if (!maybeSpace.has_value()) {
345+
status_internal =
346+
Status(BRT, FAIL, "non-memref Arg of Op " + key);
347+
return WalkResult::interrupt();
348+
}
349349

350+
space = maybeSpace.value();
351+
cur_allocator = GetAllocator(allocators, space);
352+
last_alloc = cur_allocator;
353+
}
350354
// skip if visited
351355
if (visited_ptrs.count(arg_ptr) != 0) {
352356
continue;
@@ -366,6 +370,10 @@ common::Status StaticBRTExecutionPlan::ProloguePerSession(
366370
graph_info_.tensor_to_id.emplace(arg_ptr,
367371
graph_info_.tensors.size());
368372
graph_info_.tensors.push_back(arg_ptr);
373+
} else if (op_arg.getType().isa<IndexType>()) {
374+
int64_t scalar_index = graph_info_.scalars.size();
375+
graph_info_.scalar_to_id.emplace(arg_ptr, scalar_index);
376+
graph_info_.scalars.push_back(arg_ptr);
369377
} else {
370378
status_internal =
371379
Status(BRT, FAIL, " non-supported Arg Type of Op " + key);
@@ -473,6 +481,11 @@ common::Status StaticBRTExecutionPlan::ProloguePerSession(
473481
return WalkResult::interrupt();
474482
}
475483

484+
// Skip scalar (IndexType) operands — e.g. PTXOp dynamic launch config — they are not tensors.
485+
if (op_arg.getType().isa<IndexType>()) {
486+
continue;
487+
}
488+
476489
auto found_arg = graph_info_.tensor_to_id.find(arg_ptr);
477490
if (found_arg == graph_info_.tensor_to_id.end()) {
478491
status_internal = Status(BRT, FAIL, "cannot find arg");

0 commit comments

Comments
 (0)