
Fix PyTorch matmul conversion when given (2-dim, N-dim) input pair #7845

Merged
13 commits merged on Apr 15, 2021
Lint fix
yuchaoli committed Mar 11, 2021
commit 49e9b60fc2b921895973c4b650aa49be85a9994a
12 changes: 10 additions & 2 deletions python/tvm/auto_scheduler/measure.py
@@ -868,7 +868,11 @@ def _timed_eval_func(
if arg in tensor_input_map:
tensor_name = tensor_input_map[arg]
if tensor_name in task_input_names:
args.append(ndarray.array(get_task_input_buffer(inp.task.workload_key, tensor_name), ctx))
args.append(
ndarray.array(
get_task_input_buffer(inp.task.workload_key, tensor_name), ctx
)
)
task_inputs_count += 1
else:
raise ValueError(
@@ -1079,7 +1083,11 @@ def _timed_rpc_run(
if arg in tensor_input_map:
tensor_name = tensor_input_map[arg]
if tensor_name in task_input_names:
args.append(ndarray.array(get_task_input_buffer(inp.task.workload_key, tensor_name), ctx))
args.append(
ndarray.array(
get_task_input_buffer(inp.task.workload_key, tensor_name), ctx
)
)
task_inputs_count += 1
else:
raise ValueError(
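For context on the two hunks above: the reformatted statement fetches a pre-registered input buffer by name and copies it onto the measurement device. A minimal sketch of that round trip follows, assuming the helper is importable from tvm.auto_scheduler.search_task (its location in current TVM) and that a `task` built with task_inputs and a device `ctx` already exist; both are illustrative, not part of this diff.

from tvm.runtime import ndarray
from tvm.auto_scheduler.search_task import get_task_input_buffer

# A SearchTask created with task_inputs={"Index": tvm.nd.array(...)}
# stores each buffer under the task's workload key; measurement later
# looks the buffer up by tensor name instead of using random data.
host_buf = get_task_input_buffer(task.workload_key, "Index")

# Copy the host buffer onto the measurement device, mirroring the
# reformatted call above with inp.task.workload_key and ctx.
dev_buf = ndarray.array(host_buf, ctx)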
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/relay_integration.py
@@ -117,7 +117,7 @@ def extract_tasks(
env = TracingEnvironment(
TracingMode.EXTRACT_TASK if include_simple_tasks else TracingMode.EXTRACT_COMPLEX_TASK_ONLY
)

dispatch_ctx = DispatchContext.current
old_verbose = dispatch_ctx.verbose
dispatch_ctx.verbose = 0
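The relay_integration.py hunk is a whitespace-only cleanup; for orientation, extract_tasks is the entry point that runs under the tracing environment shown above. A hedged usage sketch, with `mod` and `params` assumed to come from an existing Relay model (e.g. a frontend import):

from tvm import auto_scheduler

# extract_tasks traces a Relay build in task-extraction mode and
# returns the tunable tasks plus a parallel list of task weights.
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, "llvm")
for task, weight in zip(tasks, task_weights):
    print(task.workload_key, weight)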
36 changes: 34 additions & 2 deletions tests/python/unittest/test_auto_scheduler_measure.py
@@ -357,7 +357,7 @@ def test_measure_target_host():


@tvm.testing.requires_llvm
def test_measure_special_inputs_map_by_name():
def test_measure_special_inputs_map_by_name_local_runner():
@auto_scheduler.register_workload
def foo():
X = te.placeholder(shape=[10], dtype="int32")
@@ -384,6 +384,37 @@ def foo():
assert mress[0].error_no == 0


@tvm.testing.requires_llvm
def test_measure_special_inputs_map_by_name_rpc_runner():
@auto_scheduler.register_workload
def foo():
X = te.placeholder(shape=[10], dtype="int32")
Index = te.placeholder(shape=[1], dtype="int32", name="Index")
Y = te.compute((1,), lambda i: X[Index[i]])
return [X, Index, Y]

# This workload cannot use random input for the `Index` input
task = auto_scheduler.SearchTask(
func=foo,
target="llvm",
task_inputs={
"Index": tvm.nd.array(np.array([5], dtype="int32")),
},
)

minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
local_builder = auto_scheduler.LocalBuilder()
measure_ctx = auto_scheduler.LocalRPCMeasureContext(
timeout=60, enable_cpu_cache_flush=True
)
rpc_runner = measure_ctx.runner

bress = local_builder.build([minp])
assert bress[0].error_no == 0
mress = rpc_runner.run([minp], bress)
assert mress[0].error_no == 0


if __name__ == "__main__":
test_record_split_reorder_fuse_annotation()
test_record_compute_at_root_inline_cache_read_write()
@@ -395,4 +426,5 @@ def foo():
test_dag_measure_local_builder_runner()
test_measure_local_builder_rpc_runner()
test_measure_target_host()
test_measure_special_inputs_map_by_name()
test_measure_special_inputs_map_by_name_local_runner()
test_measure_special_inputs_map_by_name_rpc_runner()
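
One detail worth noting in the new RPC-runner test: LocalRPCMeasureContext spins up a local RPC tracker and server whose lifetime is tied to the Python object, so the context must stay referenced while the runner executes. A minimal sketch of that pattern, with `minp` and `bress` assumed to exist as in the test above:

from tvm import auto_scheduler

measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60)
rpc_runner = measure_ctx.runner        # runner talks to the local RPC server
mress = rpc_runner.run([minp], bress)  # the context must be alive here
del measure_ctx                        # shuts down the RPC tracker/server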