Skip to content

Commit 46c3a10

Browse files
altanhylc
authored and committed
[FIX][CI] hotfix check_grad perf regression (apache#8581)
* hotfix check_grad perf regression: lift compile out of hot loop
* hoist interpreter creation out of python closure, fix weird conv2d bug on arm cpu
* lint
* try one more fix
1 parent 2228b64 commit 46c3a10

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

python/tvm/relay/backend/interpreter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ def _make_executor(self, expr=None):
227227
if expr is None or isinstance(expr, GlobalVar):
228228
assert self.mod is not None
229229

230+
_intrp = _backend.CreateInterpreter(self.optimize(), self.device, self.target)
231+
230232
def _interp_wrapper(*args, **kwargs):
231233
if expr is None:
232234
args = self._convert_args(self.mod["main"], args, kwargs)
@@ -253,7 +255,6 @@ def _interp_wrapper(*args, **kwargs):
253255

254256
mod = self.optimize()
255257
opt_expr = Call(mod["main"], relay_args)
256-
_intrp = _backend.CreateInterpreter(mod, self.device, self.target)
257258
return _intrp(opt_expr)
258259

259260
return _interp_wrapper

python/tvm/relay/testing/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,16 @@ def check_grad(
154154
assert len(grads) > 0, "You must test at least one gradient."
155155

156156
# Get numeric gradients for each dimension of each param, using two-sided approximation.
157+
fwd_func_compiled = intrp.evaluate(fwd_func)
157158
approx_grads = []
158159
for x in test_inputs:
159160
approx_grad = np.zeros(x.shape)
160161
for i in np.ndindex(*x.shape):
161162
x_i = x[i]
162163
x[i] = x_i + eps
163-
fwd_plus = intrp.evaluate(fwd_func)(*inputs).numpy().astype("float64")
164+
fwd_plus = fwd_func_compiled(*inputs).numpy().astype("float64")
164165
x[i] = x_i - eps
165-
fwd_minus = intrp.evaluate(fwd_func)(*inputs).numpy().astype("float64")
166+
fwd_minus = fwd_func_compiled(*inputs).numpy().astype("float64")
166167
x[i] = x_i
167168
approx_grad[i] = np.sum((fwd_plus - fwd_minus) / (2 * eps))
168169
approx_grads.append(approx_grad)

python/tvm/topi/arm_cpu/group_conv2d.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def schedule_group_conv2d_nchw(outs):
4242
return schedule_group_conv2d_nchwc(outs)
4343

4444

45-
def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, layout="NCHW"):
45+
def _get_default_config(
46+
cfg, data, kernel, strides, padding, dilation, groups, out_dtype, layout="NCHW"
47+
):
4648
"""
4749
Get default schedule config for the workload
4850
"""
@@ -54,7 +56,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype,
5456
static_data_shape.append(dim)
5557
data = te.placeholder(static_data_shape, dtype=data.dtype)
5658

57-
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
59+
wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype, layout)
5860
_fallback_schedule(cfg, wkl)
5961

6062

@@ -158,6 +160,7 @@ def group_conv2d_nchw_spatial_pack(
158160
),
159161
strides,
160162
padding,
163+
dilation,
161164
groups,
162165
out_dtype,
163166
)

0 commit comments

Comments (0)