Skip to content

Commit e421068

Browse files
committed
Fix tests
1 parent cd3beef commit e421068

File tree

3 files changed

+28
-8
lines changed

3 files changed

+28
-8
lines changed

src/relay/transforms/device_domains.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,13 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) {
236236
args_and_result.emplace_back(ForVirtualDevice(device_copy_props.body->checked_type(),
237237
device_copy_props.dst_virtual_device));
238238
} else if (call->op == alloc_storage_op) {
239-
ICHECK_EQ(call->args.size(), 2U);
240-
// alloc_storage(size, alignment, virtual_device=<t>)
241-
// alloc_storage: fn(<cpu>, <cpu>):<t>
239+
ICHECK_EQ(call->args.size(), 3U);
240+
// alloc_storage(size, shape, alignment, virtual_device=<t>)
241+
// alloc_storage: fn(<cpu>, <cpu>, <cpu>):<t>
242242
const auto* attrs = call->attrs.as<AllocStorageAttrs>();
243243
args_and_result.emplace_back(host_domain_);
244244
args_and_result.emplace_back(host_domain_);
245+
args_and_result.emplace_back(host_domain_);
245246
args_and_result.emplace_back(ForVirtualDevice(call->checked_type(), attrs->virtual_device));
246247
} else if (call->op == alloc_tensor_op) {
247248
ICHECK_EQ(call->args.size(), 3U);

tests/python/relay/test_pass_dead_code_elimination.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
# under the License.
1717
import tvm
1818
import tvm.testing
19+
from tvm import relay
1920
from tvm.relay import Function, transform
2021
from tvm.relay.testing import inception_v3
22+
import numpy as np
2123
import pytest
2224

2325
cpu_scope = tvm.target.VirtualDevice(tvm.cpu(), tvm.target.Target("llvm"))
@@ -228,14 +230,19 @@ def @main() {
228230

229231

230232
def test_impure_op():
233+
shape = np.array([64, 2])
234+
metatable = {
235+
"VirtualDevice": [cpu_scope],
236+
"relay.Constant": [relay.const(shape, dtype="int64")],
237+
}
231238
"""Don't elide calls to side-effecting operators."""
232239
before_program = tvm.relay.parse(
233240
"""
234241
#[version = "0.0.5"]
235242
def @main() {
236243
let %size: int64 = cast(1024, dtype="int64");
237244
let %alignment: int64 = cast(64, dtype="int64");
238-
let %x = memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][0]);
245+
let %x = memory.alloc_storage(%size, meta[relay.Constant][0], %alignment, virtual_device=meta[VirtualDevice][0]);
239246
let %_ = memory.kill(%x);
240247
0
241248
}
@@ -250,6 +257,7 @@ def @main() {
250257
#[version = "0.0.5"]
251258
def @main() {
252259
%0 = memory.alloc_storage(cast(1024, dtype="int64"),
260+
meta[relay.Constant][0],
253261
cast(64, dtype="int64"),
254262
virtual_device=meta[VirtualDevice][0]);
255263
let %_ = memory.kill(%0);
@@ -267,14 +275,19 @@ def @main() {
267275

268276

269277
def test_impure_func():
278+
shape = np.array([64, 2])
279+
metatable = {
280+
"VirtualDevice": [cpu_scope],
281+
"relay.Constant": [relay.const(shape, dtype="int64")],
282+
}
270283
"""Don't elide calls to side-effecting functions."""
271284
before_program = tvm.relay.parse(
272285
"""
273286
#[version = "0.0.5"]
274287
def @f() -> int {
275288
let %size: int64 = cast(1024, dtype="int64");
276289
let %alignment: int64 = cast(64, dtype="int64");
277-
let %x = memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][0]);
290+
let %x = memory.alloc_storage(%size, meta[relay.Constant][0], %alignment, virtual_device=meta[VirtualDevice][0]);
278291
let %_ = memory.kill(%x);
279292
0
280293
}
@@ -293,6 +306,7 @@ def @main() -> int {
293306
#[version = "0.0.5"]
294307
def @f() -> int {
295308
%0 = memory.alloc_storage(cast(1024, dtype="int64"),
309+
meta[relay.Constant][0],
296310
cast(64, dtype="int64"),
297311
virtual_device=meta[VirtualDevice][0]);
298312
let %_ = memory.kill(%0);

tests/python/relay/test_pass_plan_devices.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -761,14 +761,18 @@ def ref(x):
761761

762762

763763
def test_alloc_storage():
764-
metatable = {"VirtualDevice": [HOST, GPU]}
764+
shape = np.array([3, 2])
765+
metatable = {
766+
"VirtualDevice": [HOST, GPU],
767+
"relay.Constant": [relay.const(shape, dtype="int64")],
768+
}
765769

766770
def input():
767771
return tvm.relay.parse(
768772
"""
769773
#[version = "0.0.5"]
770774
def @main(%size: int64, %alignment: int64) {
771-
memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][1])
775+
memory.alloc_storage(%size, meta[relay.Constant][0], %alignment, virtual_device=meta[VirtualDevice][1])
772776
}
773777
""",
774778
"from_string",
@@ -782,7 +786,8 @@ def expected():
782786
#[version = "0.0.5"]
783787
def @main(%size {virtual_device=meta[VirtualDevice][0]}: int64, %alignment {virtual_device=meta[VirtualDevice][0]}: int64,
784788
virtual_device=meta[VirtualDevice][1]) {
785-
memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][1])
789+
%0 = on_device(meta[relay.Constant][0], virtual_device=meta[VirtualDevice][0], constrain_result=True);
790+
memory.alloc_storage(%size, %0, %alignment, virtual_device=meta[VirtualDevice][1])
786791
}
787792
""",
788793
"from_string",

0 commit comments

Comments
 (0)