Skip to content

Commit f16f3a5

Browse files
YuchenJinyongwww
authored andcommitted
Update vm build. (apache#55)
1 parent 93cb308 commit f16f3a5

File tree

2 files changed

+40
-40
lines changed

2 files changed

+40
-40
lines changed

python/tvm/relax/vm.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,26 +138,22 @@ def __getitem__(self, key: str) -> PackedFunc:
138138
return self.module[key]
139139

140140

141-
def build(mod: tvm.IRModule,
142-
target: tvm.target.Target,
143-
target_host: tvm.target.Target) -> Tuple[Executable, Module]:
141+
def build(mod: tvm.IRModule, target: tvm.target.Target) -> Tuple[Executable, Module]:
144142
"""
145143
Build an IRModule to VM executable.
146144
147145
Parameters
148146
----------
149147
mod: IRModule
150-
The IR module.
148+
The input IRModule to be built.
151149
152150
target : tvm.target.Target
153-
A build target.
151+
A build target which can have optional host side compilation target.
154152
155-
target_host : tvm.target.Target
156-
Host compilation target, if target is device.
157153
When TVM compiles device specific program such as CUDA,
158154
we also need host(CPU) side code to interact with the driver
159155
to setup the dimensions and parameters correctly.
160-
target_host is used to specify the host side codegen target.
156+
host is used to specify the host side codegen target.
161157
By default, llvm is used if it is enabled,
162158
otherwise a stackvm intepreter is used.
163159
@@ -167,6 +163,20 @@ def build(mod: tvm.IRModule,
167163
An executable that can be loaded by virtual machine.
168164
lib: tvm.runtime.Module
169165
A runtime module that contains generated code.
166+
167+
Example
168+
-------
169+
170+
.. code-block:: python
171+
class InputModule:
172+
@R.function
173+
def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]):
174+
z = R.add(x, y)
175+
return z
176+
177+
mod = InputModule
178+
target = tvm.target.Target("llvm", host="llvm")
179+
ex, lib = relax.vm.build(mod, target)
170180
"""
171181
passes = [relax.transform.ToNonDataflow()]
172182
passes.append(relax.transform.CallDPSRewrite())
@@ -178,10 +188,11 @@ def build(mod: tvm.IRModule,
178188
# split primfunc and relax function
179189
rx_mod, tir_mod = _split_tir_relax(new_mod)
180190

181-
lib = tvm.build(tir_mod, target, target_host)
191+
lib = tvm.build(tir_mod, target)
182192
ex = _ffi_api.VMCodeGen(rx_mod)
183193
return ex, lib
184194

195+
185196
def _split_tir_relax(mod: tvm.IRModule) -> Tuple[tvm.IRModule, tvm.IRModule]:
186197
rx_mod = IRModule({})
187198
tir_mod = IRModule({})

tests/python/relax/test_vm.py

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,8 @@ def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]):
238238
return y
239239

240240
mod = TestVMCompileStage0
241-
target = tvm.target.Target("llvm")
242-
target_host = tvm.target.Target("llvm")
243-
ex, lib = relax.vm.build(mod, target, target_host)
241+
target = tvm.target.Target("llvm", host="llvm")
242+
ex, lib = relax.vm.build(mod, target)
244243
inp1 = tvm.nd.array(np.random.rand(3,4).astype(np.float32))
245244
inp2 = tvm.nd.array(np.random.rand(3,4).astype(np.float32))
246245
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
@@ -283,10 +282,8 @@ def foo(x: Tensor[_, "float32"]) -> Shape:
283282
return gv3
284283

285284
mod = TestVMCompileStage1
286-
code = R.parser.astext(mod)
287-
target = tvm.target.Target("llvm")
288-
target_host = tvm.target.Target("llvm")
289-
ex, lib = relax.vm.build(mod, target, target_host)
285+
target = tvm.target.Target("llvm", host="llvm")
286+
ex, lib = relax.vm.build(mod, target)
290287
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
291288

292289
shape = (32, 16)
@@ -305,9 +302,8 @@ def foo(x: Tensor[_, "float32"]) -> Shape:
305302
return (n * 2, m * 3)
306303

307304
mod = TestVMCompileStage2
308-
target = tvm.target.Target("llvm")
309-
target_host = tvm.target.Target("llvm")
310-
ex, lib = relax.vm.build(mod, target, target_host)
305+
target = tvm.target.Target("llvm", host="llvm")
306+
ex, lib = relax.vm.build(mod, target)
311307
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
312308

313309
shape = (32, 16)
@@ -328,9 +324,8 @@ def foo(x: Tensor[(32, 16), "float32"]) -> Tensor:
328324
return y
329325

330326
mod = TestVMCompileStage3
331-
target = tvm.target.Target("llvm")
332-
target_host = tvm.target.Target("llvm")
333-
ex, lib = relax.vm.build(mod, target, target_host)
327+
target = tvm.target.Target("llvm", host="llvm")
328+
ex, lib = relax.vm.build(mod, target)
334329
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
335330

336331
shape = (32, 16)
@@ -352,9 +347,8 @@ def foo(x: Tensor[_, "float32"]) -> Tensor:
352347

353348
mod = TestVMCompileE2E
354349

355-
target = tvm.target.Target("llvm")
356-
target_host = tvm.target.Target("llvm")
357-
ex, lib = relax.vm.build(mod, target, target_host)
350+
target = tvm.target.Target("llvm", host="llvm")
351+
ex, lib = relax.vm.build(mod, target)
358352
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
359353

360354
shape = (32, 16)
@@ -390,9 +384,8 @@ def func(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor:
390384

391385
mod = TestVMCompileE2E2
392386

393-
target = tvm.target.Target("llvm")
394-
target_host = tvm.target.Target("llvm")
395-
ex, lib = relax.vm.build(mod, target, target_host)
387+
target = tvm.target.Target("llvm", host="llvm")
388+
ex, lib = relax.vm.build(mod, target)
396389
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
397390

398391
data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))
@@ -415,9 +408,8 @@ def test_vm_emit_te_extern():
415408

416409
mod = bb.get()
417410

418-
target = tvm.target.Target("llvm")
419-
target_host = tvm.target.Target("llvm")
420-
ex, lib = relax.vm.build(mod, target, target_host)
411+
target = tvm.target.Target("llvm", host="llvm")
412+
ex, lib = relax.vm.build(mod, target)
421413
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
422414

423415
data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))
@@ -444,9 +436,8 @@ def te_func(A, B):
444436

445437
mod = bb.get()
446438

447-
target = tvm.target.Target("llvm")
448-
target_host = tvm.target.Target("llvm")
449-
ex, lib = relax.vm.build(mod, target, target_host)
439+
target = tvm.target.Target("llvm", host="llvm")
440+
ex, lib = relax.vm.build(mod, target)
450441

451442
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
452443
inp = tvm.nd.array(np.random.rand(1, ).astype(np.float32))
@@ -471,9 +462,8 @@ def te_func(A):
471462

472463
mod = bb.get()
473464

474-
target = tvm.target.Target("llvm")
475-
target_host = tvm.target.Target("llvm")
476-
ex, lib = relax.vm.build(mod, target, target_host)
465+
target = tvm.target.Target("llvm", host="llvm")
466+
ex, lib = relax.vm.build(mod, target)
477467

478468
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
479469
shape = (9, )
@@ -503,9 +493,8 @@ def te_func(A, B):
503493

504494
mod = bb.get()
505495

506-
target = tvm.target.Target("llvm")
507-
target_host = tvm.target.Target("llvm")
508-
ex, lib = relax.vm.build(mod, target, target_host)
496+
target = tvm.target.Target("llvm", host="llvm")
497+
ex, lib = relax.vm.build(mod, target)
509498

510499
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
511500
shape1 = (5, )

0 commit comments

Comments
 (0)