Skip to content

Commit

Permalink
Fix Div
Browse files Browse the repository at this point in the history
Change-Id: I66ca38746e19299c544b27ad4feab004adfae5d3
  • Loading branch information
WenjieZhou9 authored and Korbin-chen committed Feb 15, 2023
1 parent 5315aa6 commit 123549b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
9 changes: 7 additions & 2 deletions lib/Dialect/Top/Interfaces/Div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@ int64_t top::DivOp::getFLOPs() { return module::getNumElements(getOutput()); }

LogicalResult top::DivOp::init(InferenceParameter &p) {
auto binary = new Binary();
auto lhs_shape = module::getShape(getInputs()[0]);
auto rhs_shape = module::getShape(getInputs()[1]);
auto max_ndim = std::max(lhs_shape.size(), rhs_shape.size());
auto input0_shape = shape_expand_dim(lhs_shape, max_ndim);
auto input1_shape = shape_expand_dim(rhs_shape, max_ndim);
(*binary)
.lhs(p.inputs[0], module::getShape(getInputs()[0]))
.rhs(p.inputs[1], module::getShape(getInputs()[1]))
.lhs(p.inputs[0], input0_shape)
.rhs(p.inputs[1], input1_shape)
.dst(p.outputs[0], module::getShape(getOutput()))
.do_relu(getDoRelu())
.relu_limit(getReluLimit().convertToDouble())
Expand Down
9 changes: 7 additions & 2 deletions lib/Dialect/Tpu/Interfaces/Common/Div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@

LogicalResult tpu::DivOp::init(InferenceParameter &p) {
auto binary = new Binary();
auto lhs_shape = module::getShape(getInputs()[0]);
auto rhs_shape = module::getShape(getInputs()[1]);
auto max_ndim = std::max(lhs_shape.size(), rhs_shape.size());
auto input0_shape = shape_expand_dim(lhs_shape, max_ndim);
auto input1_shape = shape_expand_dim(rhs_shape, max_ndim);
(*binary)
.lhs(p.inputs[0], module::getShape(getInputs()[0]))
.rhs(p.inputs[1], module::getShape(getInputs()[1]))
.lhs(p.inputs[0], input0_shape)
.rhs(p.inputs[1], input1_shape)
.dst(p.outputs[0], module::getShape(getOutput()))
.do_relu(getDoRelu())
.relu_limit(getReluLimit().convertToDouble())
Expand Down

0 comments on commit 123549b

Please sign in to comment.