import os
import asyncio
import logging
import threading
from concurrent.futures import Future
from queue import Queue
from typing import Optional
import grpc
try:
from grpc import aio as aiogrpc
except ImportError:
from grpc.experimental import aio as aiogrpc
import ray.experimental.internal_kv as internal_kv
import ray._private.utils
from ray._private.gcs_utils import GcsClient
import ray._private.services
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
from ray import ray_constants
from ray._private.gcs_pubsub import (
GcsAioErrorSubscriber,
GcsAioLogSubscriber,
)
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.dashboard.datacenter import DataOrganizer
from ray.dashboard.utils import async_loop_forever

logger = logging.getLogger(__name__)

aiogrpc.init_grpc_aio()
GRPC_CHANNEL_OPTIONS = (
("grpc.enable_http_proxy", 0),
("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
)
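
# These options are shared by every gRPC channel this module opens to the GCS:
# the HTTP proxy is disabled so local connections are not routed through an
# environment proxy, and the send/receive message caps are raised to
# GRPC_CPP_MAX_MESSAGE_SIZE so large GCS replies are not rejected.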


class GCSHealthCheckThread(threading.Thread):
def __init__(self, gcs_address: str):
self.grpc_gcs_channel = ray._private.utils.init_grpc_channel(
gcs_address, options=GRPC_CHANNEL_OPTIONS
)
self.gcs_heartbeat_info_stub = gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
self.grpc_gcs_channel
)
self.work_queue = Queue()
super().__init__(daemon=True)
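
    # The health check runs on this dedicated thread because the CheckAlive RPC
    # below uses a synchronous gRPC stub; keeping that blocking call off the
    # asyncio event loop means a slow or unresponsive GCS cannot stall the
    # dashboard's other tasks.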
def run(self) -> None:
while True:
future = self.work_queue.get()
            check_result = self._check_once_synchronously()
future.set_result(check_result)

    def _check_once_synchronously(self) -> bool:
request = gcs_service_pb2.CheckAliveRequest()
try:
reply = self.gcs_heartbeat_info_stub.CheckAlive(
request, timeout=dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT
)
if reply.status.code != 0:
logger.exception(f"Failed to CheckAlive: {reply.status.message}")
return False
except grpc.RpcError: # Deadline Exceeded
logger.exception("Got RpcError when checking GCS is alive")
return False
return True

    async def check_once(self) -> bool:
"""Ask the thread to perform a healthcheck."""
        assert (
            threading.current_thread() is not self
        ), "check_once() must not be called from the GCSHealthCheckThread itself."
future = Future()
self.work_queue.put(future)
return await asyncio.wrap_future(future)
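
# Hand-off sketch for the class above (illustrative, not executed here): the
# event loop side queues a concurrent.futures.Future, the worker thread
# fulfills it, and the caller awaits it via asyncio.wrap_future, e.g.
#
#     health_check_thread = GCSHealthCheckThread(gcs_address)
#     health_check_thread.start()
#     is_alive = await health_check_thread.check_once()
#
# This is the pattern DashboardHead._gcs_check_alive uses below.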


class DashboardHead:
def __init__(
self,
http_host,
http_port,
http_port_retries,
gcs_address,
log_dir,
temp_dir,
session_dir,
minimal,
):
self.minimal = minimal
        self.health_check_thread: Optional[GCSHealthCheckThread] = None
self._gcs_rpc_error_counter = 0
        # Public attributes are accessible to all head modules.
        # Workaround for issue: https://github.com/ray-project/ray/issues/7084
self.http_host = "127.0.0.1" if http_host == "localhost" else http_host
self.http_port = http_port
self.http_port_retries = http_port_retries
        assert gcs_address is not None
        self.gcs_address = gcs_address
self.log_dir = log_dir
self.temp_dir = temp_dir
self.session_dir = session_dir
self.aiogrpc_gcs_channel = None
self.gcs_error_subscriber = None
self.gcs_log_subscriber = None
self.ip = ray.util.get_node_ip_address()
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0),))
grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
self.server, f"{grpc_ip}:0"
)
logger.info("Dashboard head grpc address: %s:%s", grpc_ip, self.grpc_port)
        # If the dashboard is started in non-minimal mode, the HTTP server
        # should be configured to expose its APIs.
self.http_server = None

    async def _configure_http_server(self, modules):
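        # Imported lazily (rather than at module scope), presumably so this
        # module can still be imported in a minimal install where the HTTP
        # server's dependencies may not be available.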
from ray.dashboard.http_server_head import HttpServerDashboardHead
http_server = HttpServerDashboardHead(
self.ip, self.http_host, self.http_port, self.http_port_retries
)
await http_server.run(modules)
return http_server

    @property
def http_session(self):
        assert self.http_server, "Accessing an unsupported API in minimal Ray."
return self.http_server.http_session

    @async_loop_forever(dashboard_consts.GCS_CHECK_ALIVE_INTERVAL_SECONDS)
async def _gcs_check_alive(self):
check_future = self.health_check_thread.check_once()
        # NOTE(simon): make sure the check procedure itself does not time out.
        # Otherwise, the dashboard would always think that the GCS is alive.
try:
is_alive = await asyncio.wait_for(
check_future, dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT + 1
)
except asyncio.TimeoutError:
logger.error("Failed to check gcs health, client timed out.")
is_alive = False
if is_alive:
self._gcs_rpc_error_counter = 0
else:
self._gcs_rpc_error_counter += 1
if (
self._gcs_rpc_error_counter
> dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR
):
logger.error(
"Dashboard exiting because it received too many GCS RPC "
"errors count: %s, threshold is %s.",
self._gcs_rpc_error_counter,
dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR,
)
# TODO(fyrestone): Do not use ray.state in
# PrometheusServiceDiscoveryWriter.
# Currently, we use os._exit() here to avoid hanging at the ray
# shutdown(). Please refer to:
# https://github.com/ray-project/ray/issues/16328
os._exit(-1)
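
    # Combined with the decorator above, this loop tolerates up to
    # GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR consecutive failed checks, spaced
    # GCS_CHECK_ALIVE_INTERVAL_SECONDS apart, before exiting the process.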

    def _load_modules(self):
"""Load dashboard head modules."""
modules = []
head_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardHeadModule
)
for cls in head_cls_list:
logger.info(
"Loading %s: %s", dashboard_utils.DashboardHeadModule.__name__, cls
)
c = cls(self)
modules.append(c)
logger.info("Loaded %d modules.", len(modules))
return modules
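
    # Module discovery above goes through dashboard_utils.get_all_modules,
    # which (as used here) yields the available DashboardHeadModule subclasses;
    # each is constructed with this DashboardHead so it can reach the shared
    # state exposed through the public attributes set in __init__.
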
async def run(self):
gcs_address = self.gcs_address
        # The dashboard handles GCS connection failures itself (see
        # _gcs_check_alive), so the client is created without reconnect retries.
self.gcs_client = GcsClient(address=gcs_address, nums_reconnect_retry=0)
internal_kv._initialize_internal_kv(self.gcs_client)
self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True
)
self.gcs_error_subscriber = GcsAioErrorSubscriber(address=gcs_address)
self.gcs_log_subscriber = GcsAioLogSubscriber(address=gcs_address)
await self.gcs_error_subscriber.subscribe()
await self.gcs_log_subscriber.subscribe()
self.health_check_thread = GCSHealthCheckThread(gcs_address)
self.health_check_thread.start()
# Start a grpc asyncio server.
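        # Head modules may register additional gRPC servicers on this server
        # from their run(server) coroutines, which are awaited further below.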
await self.server.start()
async def _async_notify():
"""Notify signals from queue."""
while True:
co = await dashboard_utils.NotifyQueue.get()
try:
await co
except Exception:
logger.exception(f"Error notifying coroutine {co}")
modules = self._load_modules()
http_host, http_port = self.http_host, self.http_port
if not self.minimal:
self.http_server = await self._configure_http_server(modules)
http_host, http_port = self.http_server.get_address()
internal_kv._internal_kv_put(
ray_constants.DASHBOARD_ADDRESS,
f"{http_host}:{http_port}",
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
# TODO: Use async version if performance is an issue
# Write the dashboard head port to gcs kv.
internal_kv._internal_kv_put(
dashboard_consts.DASHBOARD_RPC_ADDRESS,
f"{self.ip}:{self.grpc_port}",
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
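
        # Both endpoints are now discoverable through the GCS KV store:
        # DASHBOARD_ADDRESS holds the HTTP address and DASHBOARD_RPC_ADDRESS
        # holds this head's gRPC address, for other Ray components to look up.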
        # Freeze the signals after all modules have been loaded.
dashboard_utils.SignalManager.freeze()
concurrent_tasks = [
self._gcs_check_alive(),
_async_notify(),
DataOrganizer.purge(),
DataOrganizer.organize(),
]
await asyncio.gather(*concurrent_tasks, *(m.run(self.server) for m in modules))
await self.server.wait_for_termination()
if self.http_server:
await self.http_server.cleanup()
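
# Lifecycle note (sketch only): DashboardHead is not started from this module.
# In upstream Ray the dashboard entry point (ray/dashboard/dashboard.py)
# normally constructs it with the head node's GCS address and directories and
# then drives run() on an asyncio event loop, roughly:
#
#     head = DashboardHead(
#         http_host="127.0.0.1",
#         http_port=8265,
#         http_port_retries=0,
#         gcs_address="127.0.0.1:6379",  # illustrative values only
#         log_dir="/tmp/ray/session_latest/logs",
#         temp_dir="/tmp/ray",
#         session_dir="/tmp/ray/session_latest",
#         minimal=False,
#     )
#     asyncio.get_event_loop().run_until_complete(head.run())
#
# Because _gcs_check_alive, _async_notify, and the DataOrganizer loops run
# forever, the gather() above does not return in normal operation; the
# dashboard is effectively torn down by process exit.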