Problem Statement
This bug was encountered while trying to use TVM as a PyTorch backend to accelerate training. In training, a weight tensor is both an input and an output of the forward graph, which is expected behavior. However, when we use "set_input_zero_copy" instead of "set_input" to reduce memcpy overhead, an error occurs and the loss becomes completely wrong.
How to reproduce
import tvm
from tvm import relay
import numpy as np
from tvm.contrib import graph_executor
from tvm import testing
dev = tvm.cuda(0)
target = tvm.target.Target("cuda")
# A simple relay func:
# 1. y = x + 1
# 2. return [x, y]
x = relay.var("x", shape=(2, 2), dtype="float32")
y = relay.add(x, relay.ones((2, 2), dtype="float32"))
func = relay.Function([x], relay.expr.Tuple([x, y]))
# Build 2 identical relay modules.
def build_relay_module():
    mod = tvm.IRModule()
    mod["main"] = func
    lib = relay.build(mod, target=target)
    m = graph_executor.GraphModule(lib["default"](dev))
    return m
mod = build_relay_module()
mod_zero_copy = build_relay_module()
# Run these 2 modules.
# Two identical inputs.
input_nd = tvm.nd.array(np.ones((2, 2), dtype="float32"), device=dev)
input_nd_zero_copy = tvm.nd.array(np.ones((2, 2), dtype="float32"), device=dev)
# set_input() vs. set_input_zero_copy()
mod.set_input("x", input_nd)
index = mod_zero_copy.get_input_index("x")
mod_zero_copy.module["set_input_zero_copy"](index, input_nd_zero_copy)
# Run
mod.run()
mod_zero_copy.run()
# We expect the 2 modules to produce exactly the same output "x", however...
testing.assert_allclose(mod.get_output(0).numpy(), mod_zero_copy.get_output(0).numpy())
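Continuing the script above, the mismatch can also be seen by inspecting the outputs directly. This is a sketch: the concrete values are inferred from the root-cause analysis below (the zero-copy module keeps returning the executor's initial zero buffer), not copied from a captured log.

out_copy = mod.get_output(0).numpy()
out_zero_copy = mod_zero_copy.get_output(0).numpy()
print("set_input:          ", out_copy)       # all ones: the x we fed in is echoed back
print("set_input_zero_copy:", out_zero_copy)  # per the analysis below: stays all zeros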
Root Cause
In short, the implementation of SetInputZeroCopy is incorrect:
- It updates the DLTensors recorded in input_dltensors_, making them point to the external DLTensor's data field: https://github.com/apache/tvm/blob/v0.12.0/src/runtime/graph_executor/graph_executor.cc#LL215C22-L215C38
- However, input_dltensors_ records DLTensors taken from op_args, which are only "copies" of data_entry_'s DLTensors (rather than "references"). Thus, data_entry_'s DLTensor is never updated by SetInputZeroCopy, and the output read from data_entry_ will always be 0 (see the sketch below).
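The copy-vs-reference distinction is easiest to see in a small analogy. The following is a Python sketch only: DLTensorLike, the pointer values, and the two lists are illustrative stand-ins for the C++ structures in graph_executor.cc, not the real runtime types.

from dataclasses import dataclass

@dataclass
class DLTensorLike:
    # Stand-in for DLTensor, keeping only the field that matters here.
    data: int  # "pointer" to the underlying buffer

# The executor's own view of each storage entry.
data_entry_ = [DLTensorLike(data=0xAAAA)]

# input_dltensors_ was populated from op_args, i.e. from copies of the
# DLTensor structs rather than references into data_entry_.
input_dltensors_ = [DLTensorLike(data=data_entry_[0].data)]

# SetInputZeroCopy rewrites only the copies...
external_data = 0xBBBB
input_dltensors_[0].data = external_data

# ...so the executor's own entry still points at the old (zero-initialized) buffer.
print(hex(input_dltensors_[0].data))  # 0xbbbb  (updated)
print(hex(data_entry_[0].data))       # 0xaaaa  (unchanged -> outputs read stale data)

In the reference case there would be only one DLTensor to update, so updating data_entry_ as well (or making input_dltensors_ point into it) would keep both views consistent, which is presumably what a fix needs to guarantee.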
cc @shingjan