Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI] Minor perf improvement for GPU scatter #7233

Merged
16 commits merged on Jan 19, 2021
Previous commit
Next commit
skip random_fill when a tuning workload is from scatter
This reverts commit 1fed883.
  • Loading branch information
masahi committed Jan 18, 2021
commit a876443e47387e7e3bf4b3168b6d26c419a398f7
6 changes: 3 additions & 3 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,9 +561,9 @@ def run_through_rpc(
"Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
)
args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info]
for arg in args:
random_fill(arg)
ctx.sync()
if "scatter" not in measure_input.task.name:
for arg in args:
random_fill(arg)
masahi marked this conversation as resolved.
Show resolved Hide resolved

costs = time_f(*args).results

Expand Down
6 changes: 2 additions & 4 deletions python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,7 @@ def update_func(dst_ptr, dst_index, update):
out_shape = data.shape
out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")

cfg.define_knob("dummy", [1])
cfg.add_flop(1) # dummy value to satisfy AutoTVM
cfg.add_flop(1) # A dummy value to satisfy AutoTVM

out = te.extern(
[out_shape],
Expand Down Expand Up @@ -593,8 +592,7 @@ def scatter_via_sort(cfg, data, indices, updates, axis=0):
assert axis == 0 and len(data.shape) == 1, "sorting based scatter only supported for 1d input"
assert is_thrust_available(), "Thrust is required for this op"

cfg.define_knob("dummy", [1])
cfg.add_flop(1)
cfg.add_flop(1) # A dummy value to satisfy AutoTVM

out_shape = data.shape
out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
Expand Down
2 changes: 1 addition & 1 deletion tutorials/autotvm/test_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def tune_tasks(
def tune_and_evaluate(tuning_opt):
# extract workloads from relay program
print("Extract tasks...")
size = (5000, )
size = (10000, )
dshape = ishape = size
axis = 0
mod, params = simple_mod(size, size, axis)
Expand Down