forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_distributed_rpc.pyi
182 lines (171 loc) · 6.04 KB
/
_distributed_rpc.pyi
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from typing import Any, Dict, List, Optional, Tuple, Union, overload
from datetime import timedelta
import enum
import torch
from torch.types import Device
from . import Future
from ._autograd import ProfilerEvent
from ._distributed_c10d import ProcessGroup, Store
from ._profiler import ActiveProfilerType, ProfilerConfig, ProfilerState
# The bindings described by this stub live in
# torch/csrc/distributed/rpc/init.cpp; only their Python-visible types are
# declared here (no values are visible in this file).

# Initialization / threading defaults exported by the extension.
_DEFAULT_INIT_METHOD: str
_DEFAULT_NUM_WORKER_THREADS: int
# Timeout-related constants; used below as parameter defaults.
_UNSET_RPC_TIMEOUT: float
_DEFAULT_RPC_TIMEOUT_SEC: float
class RpcBackendOptions:
    """Options object common to RPC backends.

    ``_TensorPipeRpcBackendOptionsBase`` extends it with backend-specific
    settings.
    """

    # RPC timeout; constructor default is ``_DEFAULT_RPC_TIMEOUT_SEC``.
    rpc_timeout: float
    # Initialization method string; constructor default is ``_DEFAULT_INIT_METHOD``.
    init_method: str

    def __init__(
        self,
        rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = _DEFAULT_INIT_METHOD,
    ): ...
class WorkerInfo:
    """Identity of an RPC worker: a string name and an integer id."""

    def __init__(self, name: str, worker_id: int): ...

    # Read-only accessors (declared as properties, no setters).
    @property
    def name(self) -> str: ...

    @property
    def id(self) -> int: ...

    # Equality accepts any object; repr is a string.
    def __eq__(self, other: object) -> bool: ...

    def __repr__(self) -> str: ...
class RpcAgent:
    """Python-visible interface of an RPC agent (bound from the C++ extension)."""

    # -- lifecycle --
    def join(self, shutdown: bool = False, timeout: float = 0): ...
    def sync(self): ...
    def shutdown(self): ...

    # -- worker lookup: with no argument, or by worker name --
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    def get_worker_infos(self) -> List[WorkerInfo]: ...
    def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...

    # -- diagnostics: both return str -> str mappings --
    def get_debug_info(self) -> Dict[str, str]: ...
    def get_metrics(self) -> Dict[str, str]: ...
class PyRRef:
    """Python binding of an RRef (see the ``_rref_context_*`` helpers in this module).

    Timeouts on the value-access / RPC helpers default to ``_UNSET_RPC_TIMEOUT``
    (by its name, a "not specified" marker -- confirm against the C++ binding).
    """

    def __init__(self, value: Any, type_hint: Any = None): ...

    # Ownership queries.
    def is_owner(self) -> bool: ...
    def confirmed_by_owner(self) -> bool: ...
    def owner(self) -> WorkerInfo: ...
    def owner_name(self) -> str: ...

    # Value access.
    def to_here(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
    def local_value(self) -> Any: ...

    # Per-RRef RPC helpers; all are typed as returning ``Any`` here.
    def rpc_sync(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
    def rpc_async(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
    def remote(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...

    # Serialization: ``_deserialize`` builds a PyRRef from a tuple, presumably
    # the one produced by ``_serialize`` (confirm).
    def _serialize(self) -> Tuple: ...
    @staticmethod
    def _deserialize(tp: Tuple) -> 'PyRRef': ...

    # Internals: stored type and associated futures.
    def _get_type(self) -> Any: ...
    def _get_future(self) -> Future: ...
    def _get_profiling_future(self) -> Future: ...
    def _set_profiling_future(self, profilingFuture: Future): ...

    def __repr__(self) -> str: ...
class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
    """TensorPipe-specific extension of ``RpcBackendOptions``.

    Adds a worker-thread count and device-placement settings on top of the
    inherited timeout / init-method options.
    """

    num_worker_threads: int
    # Keyed by destination worker name (cf. ``_set_device_map(to, ...)``); each
    # value maps torch.device -> torch.device, presumably local -> remote
    # (NOTE(review): confirm against the C++ binding).
    device_maps: Dict[str, Dict[torch.device, torch.device]]
    devices: List[torch.device]

    def __init__(
        self,
        num_worker_threads: int,
        _transports: Optional[List],
        _channels: Optional[List],
        rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = _DEFAULT_INIT_METHOD,
        # Defaults are an empty mapping and an empty list. They are written as
        # ``...`` per stub convention rather than as mutable literals / calls
        # in the signature.
        device_maps: Dict[str, Dict[torch.device, torch.device]] = ...,
        devices: List[torch.device] = ...,
    ): ...

    # Registers ``device_map`` for destination ``to`` -- presumably updating
    # ``device_maps[to]``; confirm.
    def _set_device_map(self, to: str, device_map: Dict[torch.device, torch.device]): ...
class TensorPipeAgent(RpcAgent):
    """TensorPipe-backed ``RpcAgent``."""

    def __init__(
        self,
        store: Store,
        name: str,
        worker_id: int,
        world_size: Optional[int],
        opts: _TensorPipeRpcBackendOptionsBase,
        reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
        devices: List[torch.device],
    ): ...

    # Lifecycle (same signatures as on RpcAgent).
    def join(self, shutdown: bool = False, timeout: float = 0): ...
    def shutdown(self): ...

    # Worker lookup. Besides the RpcAgent overloads (no arg / by name), this
    # stub also declares a lookup by integer id.
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, id: int) -> WorkerInfo: ...
    def get_worker_infos(self) -> List[WorkerInfo]: ...
    def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...

    # Group-membership update; ``is_join`` presumably distinguishes a worker
    # joining from one leaving (confirm).
    def _update_group_membership(
        self,
        worker_info: WorkerInfo,
        my_devices: List[torch.device],
        reverse_device_map: Dict[str, Dict[torch.device, torch.device]],
        is_join: bool,
    ): ...

    def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...

    # Read-only properties.
    @property
    def is_static_group(self) -> bool: ...

    @property
    def store(self) -> Store: ...
def _is_current_rpc_agent_set() -> bool:
    """Report whether a current RPC agent is set."""
    ...


def _get_current_rpc_agent() -> RpcAgent:
    """Return the current RPC agent."""
    ...


def _set_and_start_rpc_agent(agent: RpcAgent):
    """Install ``agent`` as the current RPC agent and start it (per the name)."""
    ...


def _reset_current_rpc_agent():
    """Reset the current RPC agent."""
    ...
def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...):
    """Delete user RRefs and unforked owner RRefs; accepts an optional ``timedelta``."""
    ...


def _destroy_rref_context(ignoreRRefLeak: bool):
    """Destroy the RRef context; ``ignoreRRefLeak`` toggles leak tolerance (per the name)."""
    ...


def _rref_context_get_debug_info() -> Dict[str, str]:
    """Return RRef-context debug info as a str -> str mapping."""
    ...


def _cleanup_python_rpc_handler():
    """Clean up the Python RPC handler."""
    ...
def _invoke_rpc_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,
    *args: Any,
    **kwargs: Any,
):
    """Issue an RPC of the builtin operator ``opName`` to worker ``dst``.

    Extra positional / keyword arguments are presumably the operator's own
    arguments. No return type is declared in this stub.
    """
    ...
def _invoke_rpc_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: Union[str, bytes],
    tensors: List[torch.Tensor],
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool
):
    """Issue an RPC of a pickled Python UDF (plus its tensors) to worker ``dst``.

    ``pickledPythonUDF`` is widened from ``str`` to ``Union[str, bytes]``
    (backward compatible). NOTE(review): the Python-side RPC pickler
    serializes into an ``io.BytesIO`` and hands over ``bytes``. pybind11
    ``std::string`` parameters accept both ``str`` and ``bytes``; confirm the
    C++ parameter type in init.cpp.
    """
    ...
def _invoke_rpc_torchscript(
    dstWorkerName: str,
    qualifiedNameStr: str,
    argsTuple: Tuple,
    kwargsDict: Dict,
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
):
    """Issue an RPC of the TorchScript function ``qualifiedNameStr``.

    The destination is given by worker *name* (``str``). Arguments are passed
    explicitly as a tuple and a dict rather than as ``*args`` / ``**kwargs``.
    """
    ...
def _invoke_remote_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,
    *args: Any,
    **kwargs: Any,
):
    """``remote`` counterpart of ``_invoke_rpc_builtin``: same parameter list.

    Runs builtin operator ``opName`` on worker ``dst``. No return type is
    declared in this stub.
    """
    ...
def _invoke_remote_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: Union[str, bytes],
    tensors: List[torch.Tensor],
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
):
    """``remote`` counterpart of ``_invoke_rpc_python_udf`` for worker ``dst``.

    ``pickledPythonUDF`` is widened from ``str`` to ``Union[str, bytes]``
    (backward compatible). NOTE(review): the Python-side RPC pickler produces
    ``bytes``, and pybind11 ``std::string`` parameters accept both; confirm
    the C++ parameter type in init.cpp.
    """
    ...
def _invoke_remote_torchscript(
    dstWorkerName: str,
    qualifiedNameStr: str,
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
    *args: Any,
    **kwargs: Any,
):
    """``remote`` counterpart of ``_invoke_rpc_torchscript``.

    ``dstWorkerName`` is the destination worker's *name*, so it is annotated
    ``str``. This matches both the parameter name and the
    ``dstWorkerName: str`` parameter of ``_invoke_rpc_torchscript``, rather
    than ``WorkerInfo``. Unlike the RPC variant, operator arguments arrive via
    ``*args`` / ``**kwargs``.
    """
    ...
def get_rpc_timeout() -> float:
    """Return the configured RPC timeout."""
    ...


def enable_gil_profiling(flag: bool):
    """Toggle GIL profiling according to ``flag`` (per the name)."""
    ...


def _set_rpc_timeout(rpcTimeoutSeconds: float):
    """Set the RPC timeout; the parameter name indicates seconds."""
    ...
class RemoteProfilerManager:
    """Holder for remote-profiling hooks; only a static setter is declared."""

    @staticmethod
    def set_current_profiling_key(key: str):
        # Sets the current profiling key (per the name); no instance needed.
        ...
def _enable_server_process_global_profiler(new_config: ProfilerConfig):
    """Enable the server-process global profiler using ``new_config``."""
    ...


def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]:
    """Disable the server-process global profiler.

    Returns a triply nested list of ``ProfilerEvent``; what each nesting level
    represents is not stated in this stub.
    """
    ...


def _set_profiler_node_id(default_node_id: int):
    """Set the profiler node id."""
    ...


# JIT RRef pickling toggles.
def _enable_jit_rref_pickle():
    ...


def _disable_jit_rref_pickle():
    ...