This repository has been archived by the owner on Nov 29, 2023. It is now read-only.
forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmonitor_test.py
116 lines (96 loc) · 3.59 KB
/
monitor_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import os
import pytest
import subprocess
import time
import ray
from ray.test.test_utils import run_and_get_output
def _test_cleanup_on_driver_exit(num_redis_shards):
stdout = run_and_get_output([
"ray",
"start",
"--head",
"--num-redis-shards",
str(num_redis_shards),
])
lines = [m.strip() for m in stdout.split("\n")]
init_cmd = [m for m in lines if m.startswith("ray.init")]
assert 1 == len(init_cmd)
redis_address = init_cmd[0].split("redis_address=\"")[-1][:-2]
def StateSummary():
obj_tbl_len = len(ray.global_state.object_table())
task_tbl_len = len(ray.global_state.task_table())
func_tbl_len = len(ray.global_state.function_table())
return obj_tbl_len, task_tbl_len, func_tbl_len
def Driver(success):
success.value = True
# Start driver.
ray.init(redis_address=redis_address)
summary_start = StateSummary()
if (0, 1) != summary_start[:2]:
success.value = False
max_attempts_before_failing = 100
# Two new objects.
ray.get(ray.put(1111))
ray.get(ray.put(1111))
attempts = 0
while (2, 1, summary_start[2]) != StateSummary():
time.sleep(0.1)
attempts += 1
if attempts == max_attempts_before_failing:
success.value = False
break
@ray.remote
def f():
ray.put(1111) # Yet another object.
return 1111 # A returned object as well.
# 1 new function.
attempts = 0
while (2, 1, summary_start[2] + 1) != StateSummary():
time.sleep(0.1)
attempts += 1
if attempts == max_attempts_before_failing:
success.value = False
break
ray.get(f.remote())
attempts = 0
while (4, 2, summary_start[2] + 1) != StateSummary():
time.sleep(0.1)
attempts += 1
if attempts == max_attempts_before_failing:
success.value = False
break
ray.shutdown()
success = multiprocessing.Value('b', False)
driver = multiprocessing.Process(target=Driver, args=(success, ))
driver.start()
# Wait for client to exit.
driver.join()
time.sleep(3)
# Just make sure Driver() is run and succeeded. Note(rkn), if the below
# assertion starts failing, then the issue may be that the summary
# values computed in the Driver function are being updated slowly and
# so the call to StateSummary() is getting outdated values. This could
# be fixed by looping until StateSummary() returns the desired values.
assert success.value
# Check that objects, tasks, and functions are cleaned up.
ray.init(redis_address=redis_address)
# The assertion below can fail if the monitor is too slow to clean up
# the global state.
assert (0, 1) == StateSummary()[:2]
ray.shutdown()
subprocess.Popen(["ray", "stop"]).wait()
@pytest.mark.skipif(
os.environ.get("RAY_USE_NEW_GCS") == "on",
reason="Hanging with the new GCS API.")
def test_cleanup_on_driver_exit_single_redis_shard():
_test_cleanup_on_driver_exit(num_redis_shards=1)
@pytest.mark.skipif(
os.environ.get("RAY_USE_NEW_GCS") == "on",
reason="Hanging with the new GCS API.")
def test_cleanup_on_driver_exit_many_redis_shards():
_test_cleanup_on_driver_exit(num_redis_shards=5)
_test_cleanup_on_driver_exit(num_redis_shards=31)