
[Bug][Relay][GraphExecutor] "set_input_zero_copy()" causes error result when input is also in outputs. #14978

@zhuwenxi


Problem Statement

We encountered this bug while using TVM as a PyTorch backend to accelerate training. In training, a weight tensor is both an input and an output of the forward graph, which is expected behavior. However, when we used "set_input_zero_copy" instead of "set_input" to reduce memcpy overhead, an error occurred and the loss was completely wrong.

How to reproduce

import tvm
from tvm import relay
import numpy as np
from tvm.contrib import graph_executor
from tvm import testing


dev = tvm.cuda(0)
target = tvm.target.Target("cuda")

# A simple relay func:
# 1. y = x + 1
# 2. return [x, y]
x = relay.var("x", shape=(2, 2), dtype="float32")
y = relay.add(x, relay.ones((2, 2), dtype="float32"))
func = relay.Function([x], relay.expr.Tuple([x, y]))

# Build two identical Relay modules.
def build_relay_module():

    mod = tvm.IRModule()
    mod["main"] = func
    lib = relay.build(mod, target=target)

    m = graph_executor.GraphModule(lib["default"](dev))

    return m

mod = build_relay_module()
mod_zero_copy = build_relay_module()

# Run both modules.

# Two identical inputs.
input_nd = tvm.nd.array(np.ones((2, 2), dtype="float32"), device=dev)
input_nd_zero_copy = tvm.nd.array(np.ones((2, 2), dtype="float32"), device=dev)

# set_input() vs. set_input_zero_copy()
mod.set_input("x", input_nd)

index = mod_zero_copy.get_input_index("x")
mod_zero_copy.module["set_input_zero_copy"](index, input_nd_zero_copy)

# Run
mod.run()
mod_zero_copy.run()

# We expect both modules to produce exactly the same output "x", however...
testing.assert_allclose(mod.get_output(0).numpy(), mod_zero_copy.get_output(0).numpy())

Output

[image: screenshot of the output]

Root Cause

In short, the implementation of SetInputZeroCopy is incorrect:

  1. It updates the DLTensors recorded in input_dltensors_ so that they point to the external DLTensor's data field: https://github.com/apache/tvm/blob/v0.12.0/src/runtime/graph_executor/graph_executor.cc#LL215C22-L215C38
  2. However, input_dltensors_ records DLTensors from op_args, which are only copies of data_entry_'s DLTensors rather than references to them. SetInputZeroCopy therefore leaves data_entry_'s DLTensors unchanged, so an output read from data_entry_ will always be 0.
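
The copy-versus-reference problem described above can be illustrated with a small pure-Python model. This is only a sketch: DLTensorModel, ExecutorModel, and set_input_zero_copy_fixed are hypothetical names invented for illustration, the zero-filled storage is an assumption, and none of this is TVM's actual C++ code or the fix adopted upstream.

```python
import copy
from dataclasses import dataclass


@dataclass
class DLTensorModel:
    """Hypothetical stand-in for a DLTensor: only the data pointer is modeled."""
    data: list


class ExecutorModel:
    """Toy model of the bookkeeping described above (not TVM's real code)."""

    def __init__(self):
        # data_entry_[0] stores graph input "x", which is also output 0.
        # Assumption: the executor's own storage starts out zero-filled.
        self.data_entry = [DLTensorModel(data=[0.0, 0.0, 0.0, 0.0])]
        # op_args holds a *copy* of the entry's DLTensor, not a reference to it.
        self.op_args = [copy.copy(self.data_entry[0])]
        # input_dltensors_ points at the copies stored in op_args.
        self.input_dltensors = [self.op_args[0]]

    def set_input_zero_copy(self, external):
        # Only the copied DLTensor is redirected to the external buffer.
        for tensor in self.input_dltensors:
            tensor.data = external.data

    def set_input_zero_copy_fixed(self, external):
        # One possible direction (an assumption): redirect the data entry too.
        self.set_input_zero_copy(external)
        self.data_entry[0].data = external.data

    def get_output(self, index):
        # Outputs are read from data_entry_, as described in the list above.
        return self.data_entry[index].data


external = DLTensorModel(data=[1.0, 1.0, 1.0, 1.0])

buggy = ExecutorModel()
buggy.set_input_zero_copy(external)
print(buggy.op_args[0].data)  # the operator sees the external buffer
print(buggy.get_output(0))    # the output still reads the executor's own storage

fixed = ExecutorModel()
fixed.set_input_zero_copy_fixed(external)
print(fixed.get_output(0))    # the output now aliases the external buffer
```

In this model the operator arguments see the new buffer while output 0 keeps reading the untouched entry, which mirrors the mismatch in the reproduction script.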

cc @shingjan


Labels: needs-triage, type: bug
