Skip to content

Commit

Permalink
[Jobs] Make sure JobConfig specified in the init script is not being …
Browse files Browse the repository at this point in the history
…overridden by the SupervisorActor (ray-project#38676)

NOTE: This is a revert of the revert, addressing broken tests

Currently, when a job submitted via the Ray Job Submission API is executed, the SupervisorActor overrides the entire JobConfig, making it impossible to specify additional parameters such as JobConfig.code_search_path in the driver script.

---------

Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
  • Loading branch information
alexeykudinkin authored Aug 23, 2023
1 parent f3988eb commit e063bd7
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 31 deletions.
35 changes: 35 additions & 0 deletions dashboard/modules/job/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os

import ray
from ray._private.gcs_utils import GcsAioClient
from ray.dashboard.modules.job.job_manager import (
JobManager,
)


TEST_NAMESPACE = "jobs_test_namespace"


def create_ray_cluster(_tracing_startup_hook=None):
    """Start a Ray instance configured for the JobManager tests.

    ``ray.init`` is called with 16 CPUs, 1 GPU, one unit of the ``"Custom"``
    resource, the ``TEST_NAMESPACE`` namespace and ``log_to_driver`` enabled.

    Args:
        _tracing_startup_hook: Forwarded unchanged to ``ray.init``; defaults
            to ``None``.

    Returns:
        The value returned by ``ray.init``.
    """
    init_kwargs = {
        "num_cpus": 16,
        "num_gpus": 1,
        "resources": {"Custom": 1},
        "namespace": TEST_NAMESPACE,
        "log_to_driver": True,
        "_tracing_startup_hook": _tracing_startup_hook,
    }
    return ray.init(**init_kwargs)


def create_job_manager(ray_cluster, tmp_path):
    """Build a ``JobManager`` bound to the given Ray cluster.

    Args:
        ray_cluster: Object supporting ``ray_cluster["gcs_address"]`` lookup;
            the address found there is used to create the GCS client.
        tmp_path: Passed through as the second argument of ``JobManager``.

    Returns:
        A new ``JobManager`` instance.
    """
    # The GCS client is created with nums_reconnect_retry=0.
    client = GcsAioClient(
        address=ray_cluster["gcs_address"],
        nums_reconnect_retry=0,
    )
    return JobManager(client, tmp_path)


def _driver_script_path(file_name: str) -> str:
    """Return the path of ``file_name`` inside ``subprocess_driver_scripts``.

    The ``subprocess_driver_scripts`` directory is resolved relative to the
    directory that contains this module.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, "subprocess_driver_scripts", file_name)
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
A dummy ray driver script that executes in subprocess.
Prints global worker's `load_code_from_local` property that ought to be set
whenever `JobConfig.code_search_path` is specified
"""


def run():
    """Initialize Ray with a custom ``code_search_path`` and report the outcome.

    Prints whether a remote task observes ``load_code_from_local`` set on its
    global worker, followed by the runtime_env of the current runtime context.
    """
    import ray
    from ray.job_config import JobConfig

    driver_job_config = JobConfig(code_search_path=["/home/code/"])
    ray.init(job_config=driver_job_config)

    @ray.remote
    def foo() -> bool:
        # Evaluated inside a worker process, not in the driver.
        return ray._private.worker.global_worker.load_code_from_local

    worker_loads_local_code = ray.get(foo.remote())

    if worker_loads_local_code:
        statement = "propagated"
    else:
        statement = "NOT propagated"

    # Step 1: Print the statement indicating whether the code_search_path has
    #         been properly respected
    print(f"Code search path is {statement}")
    # Step 2: Print the whole runtime_env to validate that it has been passed
    #         appropriately from the submit_job API
    print(ray.get_runtime_context().runtime_env)


if __name__ == "__main__":
    run()
25 changes: 7 additions & 18 deletions dashboard/modules/job/tests/test_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR,
RAY_JOB_START_TIMEOUT_SECONDS_ENV_VAR,
)
from ray.dashboard.modules.job.tests.conftest import (
create_ray_cluster,
create_job_manager,
_driver_script_path,
)
from ray.job_submission import JobStatus
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy # noqa: F401
from ray.tests.conftest import call_ray_start # noqa: F401
Expand Down Expand Up @@ -256,13 +261,7 @@ def shared_ray_instance():
# submissions.
old_ray_address = os.environ.pop(RAY_ADDRESS_ENVIRONMENT_VARIABLE, None)

yield ray.init(
num_cpus=16,
num_gpus=1,
resources={"Custom": 1},
namespace=TEST_NAMESPACE,
log_to_driver=True,
)
yield create_ray_cluster()

if old_ray_address is not None:
os.environ[RAY_ADDRESS_ENVIRONMENT_VARIABLE] = old_ray_address
Expand All @@ -271,17 +270,7 @@ def shared_ray_instance():
@pytest.mark.asyncio
@pytest.fixture
async def job_manager(shared_ray_instance, tmp_path):
address_info = shared_ray_instance
gcs_aio_client = GcsAioClient(
address=address_info["gcs_address"], nums_reconnect_retry=0
)
yield JobManager(gcs_aio_client, tmp_path)


def _driver_script_path(file_name: str) -> str:
return os.path.join(
os.path.dirname(__file__), "subprocess_driver_scripts", file_name
)
yield create_job_manager(shared_ray_instance, tmp_path)


async def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
Expand Down
74 changes: 74 additions & 0 deletions dashboard/modules/job/tests/test_job_manager_standalone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest
import sys

from ray._private.test_utils import (
async_wait_for_condition_async_predicate,
)
from ray.dashboard.modules.job.tests.conftest import (
create_ray_cluster,
create_job_manager,
_driver_script_path,
)
from ray.dashboard.modules.job.tests.test_job_manager import check_job_succeeded


@pytest.mark.asyncio
class TestRuntimeEnvStandalone:
    """NOTE: PLEASE READ CAREFULLY BEFORE MODIFYING

    This test lives in its own module so that it can bootstrap a dedicated
    (standalone) Ray cluster without affecting the shared cluster used by the
    other JobManager tests.
    """

    @pytest.mark.parametrize(
        "tracing_enabled",
        [
            False,
            # TODO(issues/38633): local code loading is broken when tracing is enabled
            # True,
        ],
    )
    async def test_user_provided_job_config_honored_by_worker(
        self, tracing_enabled, tmp_path
    ):
        """Check that a JobConfig passed to ray.init in the driver script is
        honored even when the job is submitted via the JobManager.submit_job
        API (which involves RAY_JOB_CONFIG_JSON_ENV_VAR being set in the child
        process env).
        """
        tracing_startup_hook = (
            "ray.util.tracing.setup_local_tmp_tracing:setup_tracing"
            if tracing_enabled
            else None
        )
        driver_script_path = _driver_script_path(
            "check_code_search_path_is_propagated.py"
        )
        # NOTE: A runtime_env is injected here while the driver script also
        #       specifies its own JobConfig: JobConfig settings (other than the
        #       runtime_env) passed via ray.init(...) have to be respected
        #       along with the runtime_env passed from the submit_job API
        submitted_runtime_env = {
            "env_vars": {"TEST_SUBPROCESS_RANDOM_VAR": "0xDEEDDEED"}
        }

        with create_ray_cluster(_tracing_startup_hook=tracing_startup_hook) as cluster:
            manager = create_job_manager(cluster, tmp_path)

            job_id = await manager.submit_job(
                entrypoint=f"python {driver_script_path}",
                runtime_env=submitted_runtime_env,
            )

            await async_wait_for_condition_async_predicate(
                check_job_succeeded, job_manager=manager, job_id=job_id
            )

            logs = manager.get_job_logs(job_id)

            # Both the driver-side JobConfig and the submitted runtime_env
            # must be visible in the job output.
            assert "Code search path is propagated" in logs, logs
            assert "0xDEEDDEED" in logs, logs


if __name__ == "__main__":
    # Allow running this module directly; exit with pytest's return code.
    sys.exit(pytest.main(["-v", __file__]))
29 changes: 20 additions & 9 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1408,6 +1408,9 @@ def init(
logger.debug("Could not import resource module (on Windows)")
pass

if job_config is None:
job_config = ray.job_config.JobConfig()

if RAY_JOB_CONFIG_JSON_ENV_VAR in os.environ:
if runtime_env:
logger.warning(
Expand All @@ -1416,16 +1419,26 @@ def init(
"job_config. Please ensure no runtime_env is used in driver "
"script's ray.init() when using job submission API."
)
# Set runtime_env in job_config if passed as env variable, such as
# ray job submission with driver script executed in subprocess
job_config_json = json.loads(os.environ.get(RAY_JOB_CONFIG_JSON_ENV_VAR))
job_config = ray.job_config.JobConfig.from_json(job_config_json)
injected_job_config_json = json.loads(
os.environ.get(RAY_JOB_CONFIG_JSON_ENV_VAR)
)
injected_job_config: ray.job_config.JobConfig = (
ray.job_config.JobConfig.from_json(injected_job_config_json)
)
# NOTE: We always prefer runtime_env injected via RAY_JOB_CONFIG_JSON_ENV_VAR,
# as compared to via ray.init(runtime_env=...) to make sure runtime_env
# specified via job submission API takes precedence
runtime_env = injected_job_config.runtime_env

if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ and not _skip_env_hook:
runtime_env = _load_class(os.environ[ray_constants.RAY_RUNTIME_ENV_HOOK])(
job_config.runtime_env
runtime_env
)
job_config.set_runtime_env(runtime_env)

job_config.set_runtime_env(runtime_env)
# Similarly, we prefer metadata provided via job submission API
for key, value in injected_job_config.metadata.items():
job_config.set_metadata(key, value)

# RAY_JOB_CONFIG_JSON_ENV_VAR is only set at ray job manager level and has
# higher priority in case user also provided runtime_env for ray.init()
Expand All @@ -1437,8 +1450,6 @@ def init(

if runtime_env:
# Set runtime_env in job_config if passed in as part of ray.init()
if job_config is None:
job_config = ray.job_config.JobConfig()
job_config.set_runtime_env(runtime_env)

redis_address, gcs_address = None, None
Expand Down Expand Up @@ -2333,7 +2344,7 @@ def connect(
b"tracing_startup_hook", ray_constants.KV_NAMESPACE_TRACING
)
if tracing_hook_val is not None:
ray.util.tracing.tracing_helper._enbale_tracing()
ray.util.tracing.tracing_helper._enable_tracing()
if not getattr(ray, "__traced__", False):
_setup_tracing = _import_from_string(tracing_hook_val.decode("utf-8"))
_setup_tracing()
Expand Down
11 changes: 8 additions & 3 deletions python/ray/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,11 +589,16 @@ def call_ray_start_context(request):
except Exception as e:
print(type(e), e)
raise

# Get the redis address from the output.
redis_substring_prefix = "--address='"
address_location = out.find(redis_substring_prefix) + len(redis_substring_prefix)
address = out[address_location:]
address = address.split("'")[0]
idx = out.find(redis_substring_prefix)
if idx >= 0:
address_location = idx + len(redis_substring_prefix)
address = out[address_location:]
address = address.split("'")[0]
else:
address = None

yield address

Expand Down
7 changes: 6 additions & 1 deletion python/ray/util/tracing/tracing_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _is_tracing_enabled() -> bool:
return _global_is_tracing_enabled


def _enbale_tracing():
def _enable_tracing():
global _global_is_tracing_enabled, _opentelemetry
_global_is_tracing_enabled = True
_opentelemetry = _OpenTelemetryProxy()
Expand Down Expand Up @@ -340,6 +340,11 @@ def _inject_tracing_into_function(function):
),
)

# Skip wrapping if tracing is disabled (still add _ray_trace_ctx however to make
# sure _ray_trace_ctx could be passed)
if not _is_tracing_enabled():
return function

@wraps(function)
def _function_with_tracing(
*args: Any,
Expand Down

0 comments on commit e063bd7

Please sign in to comment.