Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@
)

from . import executor
from . import disco
24 changes: 23 additions & 1 deletion python/tvm/runtime/disco/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import numpy as np

from ..._ffi import register_object
from ..._ffi import register_object, register_func
from ..._ffi.runtime_ctypes import Device
from ..container import ShapeTuple
from ..ndarray import NDArray
Expand Down Expand Up @@ -153,6 +153,23 @@ def get_global_func(self, name: str) -> DRef:
"""
return DPackedFunc(_ffi_api.SessionGetGlobalFunc(self, name), self) # type: ignore # pylint: disable=no-member

def import_python_module(self, module_name: str) -> None:
"""Import a python module in each worker

This may be required before call

Parameters
----------
module_name: str

The python module name, as it would be used in a python
`import` statement.
"""
if not hasattr(self, "_import_python_module"):
self._import_python_module = self.get_global_func("runtime.disco._import_python_module")

self._import_python_module(module_name)

def call_packed(self, func: DRef, *args) -> DRef:
"""Call a PackedFunc on workers providing variadic arguments.

Expand Down Expand Up @@ -369,6 +386,11 @@ def __init__(self, num_workers: int, entrypoint: str) -> None:
)


@register_func("runtime.disco._import_python_module")
def _import_python_module(module_name: str) -> None:
__import__(module_name)


REDUCE_OPS = {
"sum": 0,
"prod": 1,
Expand Down