1616# under the License.
1717import tvm
1818import tvm .testing
19+ from tvm import relay
1920from tvm .relay import Function , transform
2021from tvm .relay .testing import inception_v3
22+ import numpy as np
2123import pytest
2224
2325cpu_scope = tvm .target .VirtualDevice (tvm .cpu (), tvm .target .Target ("llvm" ))
@@ -228,14 +230,19 @@ def @main() {
228230
229231
230232def test_impure_op ():
233+ shape = np .array ([64 , 2 ])
234+ metatable = {
235+ "VirtualDevice" : [cpu_scope ],
236+ "relay.Constant" : [relay .const (shape , dtype = "int64" )],
237+ }
231238 """Don't elide calls to side-effecting operators."""
232239 before_program = tvm .relay .parse (
233240 """
234241 #[version = "0.0.5"]
235242 def @main() {
236243 let %size: int64 = cast(1024, dtype="int64");
237244 let %alignment: int64 = cast(64, dtype="int64");
238- let %x = memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][0]);
245+ let %x = memory.alloc_storage(%size, meta[relay.Constant][0], %alignment, virtual_device=meta[VirtualDevice][0]);
239246 let %_ = memory.kill(%x);
240247 0
241248 }
@@ -250,6 +257,7 @@ def @main() {
250257 #[version = "0.0.5"]
251258 def @main() {
252259 %0 = memory.alloc_storage(cast(1024, dtype="int64"),
260+ meta[relay.Constant][0],
253261 cast(64, dtype="int64"),
254262 virtual_device=meta[VirtualDevice][0]);
255263 let %_ = memory.kill(%0);
@@ -267,14 +275,19 @@ def @main() {
267275
268276
269277def test_impure_func ():
278+ shape = np .array ([64 , 2 ])
279+ metatable = {
280+ "VirtualDevice" : [cpu_scope ],
281+ "relay.Constant" : [relay .const (shape , dtype = "int64" )],
282+ }
270283 """Don't elide calls to side-effecting functions."""
271284 before_program = tvm .relay .parse (
272285 """
273286 #[version = "0.0.5"]
274287 def @f() -> int {
275288 let %size: int64 = cast(1024, dtype="int64");
276289 let %alignment: int64 = cast(64, dtype="int64");
277- let %x = memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][0]);
290+ let %x = memory.alloc_storage(%size, meta[relay.Constant][0], %alignment, virtual_device=meta[VirtualDevice][0]);
278291 let %_ = memory.kill(%x);
279292 0
280293 }
@@ -293,6 +306,7 @@ def @main() -> int {
293306 #[version = "0.0.5"]
294307 def @f() -> int {
295308 %0 = memory.alloc_storage(cast(1024, dtype="int64"),
309+ meta[relay.Constant][0],
296310 cast(64, dtype="int64"),
297311 virtual_device=meta[VirtualDevice][0]);
298312 let %_ = memory.kill(%0);
0 commit comments