Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add runtime_context to get some runtime fields in worker #4065

Merged
merged 13 commits into from
Feb 19, 2019
Merged
2 changes: 2 additions & 0 deletions python/ray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
# some functions in the worker.
import ray.actor # noqa: F401
from ray.actor import method # noqa: E402
from ray.runtime_context import _get_runtime_context # noqa: E402

# Ray version string.
__version__ = "0.7.0.dev0"
Expand All @@ -103,6 +104,7 @@
"WORKER_MODE",
"__version__",
"_config",
"_get_runtime_context",
"actor",
"connect",
"disconnect",
Expand Down
34 changes: 34 additions & 0 deletions python/ray/runtime_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray.worker


class RuntimeContext(object):
"""A class used for getting runtime context."""

def __init__(self, worker=None):
self.worker = worker

@property
def current_driver_id(self):
"""Get current driver ID for this worker or driver.

Returns:
If called by a driver, this returns the driver ID. If called in
a task, return the driver ID of the associated driver.
"""
assert self.worker is not None
return self.worker.task_driver_id


_runtime_context = None


def _get_runtime_context():
global _runtime_context
if _runtime_context is None:
_runtime_context = RuntimeContext(ray.worker.get_global_worker())

return _runtime_context
1 change: 1 addition & 0 deletions python/ray/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from ray import import_thread
from ray import profiling

from ray.core.generated.ErrorType import ErrorType
from ray.exceptions import (
RayActorError,
Expand Down
13 changes: 7 additions & 6 deletions test/runtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2590,16 +2590,17 @@ def f():

def test_specific_driver_id():
dummy_driver_id = ray.DriverID(b"00112233445566778899")
ray.init(driver_id=dummy_driver_id)
ray.init(num_cpus=1, driver_id=dummy_driver_id)

# in driver
assert dummy_driver_id == ray._get_runtime_context().current_driver_id

# in worker
@ray.remote
def f():
return ray.worker.global_worker.task_driver_id.binary()

assert dummy_driver_id.binary() == ray.worker.global_worker.worker_id
return ray._get_runtime_context().current_driver_id

task_driver_id = ray.get(f.remote())
assert dummy_driver_id.binary() == task_driver_id
assert dummy_driver_id == ray.get(f.remote())

ray.shutdown()

Expand Down