Skip to content

Commit

Permalink
Refactor ray internal_kv for supporting client mode. (ray-project#77)
Browse files Browse the repository at this point in the history
This PR refactors the usage of the Ray internal kv to better support Ray client mode.
In Ray client mode, we create an actor for every party to proxy the kv operations.

## Related issues
ray-project#75

---------

Signed-off-by: Qing Wang <kingchin1218@gmail.com>
  • Loading branch information
jovany-wang authored Mar 8, 2023
1 parent 2f8b197 commit 76d96de
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 23 deletions.
103 changes: 102 additions & 1 deletion fed/_private/compatible_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import ray
import fed._private.constants as fed_constants

import ray.experimental.internal_kv as ray_internal_kv
from ray._private.gcs_utils import GcsClient


def _compare_version_strings(version1, version2):
"""
Expand Down Expand Up @@ -53,10 +57,107 @@ def init_ray(address: str = None, **kwargs):
ray.init(address=address, **kwargs)


def get_gcs_address_from_ray_worker():
def _get_gcs_address_from_ray_worker():
    """Return the gcs address stored on the Ray worker's global node.

    The address is read through ``ray._private.worker`` first; if that
    lookup raises ``AttributeError``, it is read through ``ray.worker``
    instead.
    """
    try:
        global_node = ray._private.worker._global_node
        gcs_address = global_node.gcs_address
    except AttributeError:
        # Fall back to the other module path for the worker.
        gcs_address = ray.worker._global_node.gcs_address
    return gcs_address


class AbstractInternalKv(abc.ABC):
    """Common interface for accessing the Ray internal kv.

    Concrete subclasses bridge the Ray internal kv in both Ray client mode
    and non Ray client mode.
    """

    def __init__(self) -> None:
        """Nothing to set up at the abstract level."""

    @abc.abstractmethod
    def initialize(self):
        """Prepare the kv for use."""

    @abc.abstractmethod
    def put(self, k, v):
        """Store the value ``v`` under the key ``k``."""

    @abc.abstractmethod
    def get(self, k):
        """Look up the value stored under the key ``k``."""

    @abc.abstractmethod
    def delete(self, k):
        """Remove the key ``k``."""

    @abc.abstractmethod
    def reset(self):
        """Reset the kv."""


class InternalKv(AbstractInternalKv):
    """Internal kv implementation for non Ray client mode.

    Every operation is forwarded to the private helpers of
    ``ray.experimental.internal_kv``.
    """

    def __init__(self) -> None:
        super().__init__()

    def initialize(self):
        """Create a gcs client and initialize the Ray internal kv with it."""
        address = _get_gcs_address_from_ray_worker()
        client = GcsClient(address=address, nums_reconnect_retry=10)
        return ray_internal_kv._initialize_internal_kv(client)

    def put(self, k, v):
        """Store ``v`` under ``k`` and return what the Ray helper returns."""
        result = ray_internal_kv._internal_kv_put(k, v)
        return result

    def get(self, k):
        """Return what the Ray helper yields for the key ``k``."""
        value = ray_internal_kv._internal_kv_get(k)
        return value

    def delete(self, k):
        """Delete ``k`` and return what the Ray helper returns."""
        result = ray_internal_kv._internal_kv_del(k)
        return result

    def reset(self):
        """Reset the Ray internal kv."""
        result = ray_internal_kv._internal_kv_reset()
        return result

    def _ping(self):
        """Return the fixed string ``"pong"``."""
        reply = "pong"
        return reply


class ClientModeInternalKv(AbstractInternalKv):
    """Internal kv implementation for Ray client mode.

    Every operation is forwarded to the actor named ``_INTERNAL_KV_ACTOR``
    and the remote result is fetched with ``ray.get`` before returning.
    """

    def __init__(self) -> None:
        super().__init__()
        # The proxy actor is looked up by name here, not created.
        self._internal_kv_actor = ray.get_actor("_INTERNAL_KV_ACTOR")

    def initialize(self):
        """Run ``initialize`` on the actor and return its result."""
        return ray.get(self._internal_kv_actor.initialize.remote())

    def put(self, k, v):
        """Run ``put(k, v)`` on the actor and return its result."""
        return ray.get(self._internal_kv_actor.put.remote(k, v))

    def get(self, k):
        """Run ``get(k)`` on the actor and return its result."""
        return ray.get(self._internal_kv_actor.get.remote(k))

    def delete(self, k):
        """Run ``delete(k)`` on the actor and return its result."""
        return ray.get(self._internal_kv_actor.delete.remote(k))

    def reset(self):
        """Run ``reset`` on the actor and return its result."""
        return ray.get(self._internal_kv_actor.reset.remote())


def _init_internal_kv():
    """An internal API that initializes the internal kv object.

    Outside Ray client mode a plain ``InternalKv`` is returned. In Ray client
    mode an ``InternalKv`` actor named ``_INTERNAL_KV_ACTOR`` is made
    available first, and a ``ClientModeInternalKv`` that proxies to it is
    returned.

    This function may be called more than once in the same process (it runs
    at module import and again from ``fed.init``), so an actor that already
    exists under that name is reused instead of being created a second time.

    Returns:
        The kv object matching the current Ray mode.
    """
    from ray._private.client_mode_hook import is_client_mode_enabled

    if not is_client_mode_enabled:
        return InternalKv()

    actor_name = "_INTERNAL_KV_ACTOR"
    try:
        # ray.get_actor raises ValueError when no actor has this name yet.
        kv_actor = ray.get_actor(actor_name)
    except ValueError:
        kv_actor = ray.remote(InternalKv).options(name=actor_name).remote()

    # Wait for a reply so the actor is up before ClientModeInternalKv looks
    # it up by name.
    ray.get(kv_actor._ping.remote())
    return ClientModeInternalKv()


# Module-level kv object, built once when this module is imported. Other
# modules reach it as `compatible_utils.kv`.
kv = _init_internal_kv()
4 changes: 2 additions & 2 deletions fed/_private/serialization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import io
import cloudpickle

import ray.experimental.internal_kv as internal_kv
import fed._private.compatible_utils as compatible_utils

_pickle_whitelist = None

Expand Down Expand Up @@ -63,7 +63,7 @@ def _apply_loads_function_with_whitelist():
global _pickle_whitelist

from fed._private.constants import RAYFED_CROSS_SILO_SERIALIZING_ALLOWED_LIST
serialized = internal_kv._internal_kv_get(
serialized = compatible_utils.kv.get(
RAYFED_CROSS_SILO_SERIALIZING_ALLOWED_LIST)
if serialized is None:
return
Expand Down
32 changes: 14 additions & 18 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

import cloudpickle
import ray
import ray.experimental.internal_kv as internal_kv
from ray._private.gcs_utils import GcsClient
import fed.utils as fed_utils
import fed._private.compatible_utils as compatible_utils

Expand Down Expand Up @@ -160,14 +158,12 @@ def init(
), 'Cert or key are not in tls_config.'
# A Ray private accessing, should be replaced in public API.

gcs_client = GcsClient(
address=compatible_utils.get_gcs_address_from_ray_worker(),
nums_reconnect_retry=10)
internal_kv._initialize_internal_kv(gcs_client)
internal_kv._internal_kv_put(RAYFED_CLUSTER_KEY, cloudpickle.dumps(cluster))
internal_kv._internal_kv_put(RAYFED_PARTY_KEY, cloudpickle.dumps(party))
internal_kv._internal_kv_put(RAYFED_TLS_CONFIG, cloudpickle.dumps(tls_config))
internal_kv._internal_kv_put(
compatible_utils._init_internal_kv()
compatible_utils.kv.initialize()
compatible_utils.kv.put(RAYFED_CLUSTER_KEY, cloudpickle.dumps(cluster))
compatible_utils.kv.put(RAYFED_PARTY_KEY, cloudpickle.dumps(party))
compatible_utils.kv.put(RAYFED_TLS_CONFIG, cloudpickle.dumps(tls_config))
compatible_utils.kv.put(
RAYFED_CROSS_SILO_SERIALIZING_ALLOWED_LIST,
cloudpickle.dumps(cross_silo_serializing_allowed_list),
)
Expand Down Expand Up @@ -207,11 +203,11 @@ def shutdown():
Shutdown a RayFed client.
"""
wait_sending()
internal_kv._internal_kv_del(RAYFED_CLUSTER_KEY)
internal_kv._internal_kv_del(RAYFED_PARTY_KEY)
internal_kv._internal_kv_del(RAYFED_TLS_CONFIG)
internal_kv._internal_kv_del(RAYFED_CROSS_SILO_SERIALIZING_ALLOWED_LIST)
internal_kv._internal_kv_reset()
compatible_utils.kv.delete(RAYFED_CLUSTER_KEY)
compatible_utils.kv.delete(RAYFED_PARTY_KEY)
compatible_utils.kv.delete(RAYFED_TLS_CONFIG)
compatible_utils.kv.delete(RAYFED_CROSS_SILO_SERIALIZING_ALLOWED_LIST)
compatible_utils.kv.reset()
ray.shutdown()
logger.info('Shutdowned ray.')

Expand All @@ -221,23 +217,23 @@ def get_cluster():
Get the RayFed cluster configuration.
"""
# TODO(qwang): These getter could be cached in local.
serialized = internal_kv._internal_kv_get(RAYFED_CLUSTER_KEY)
serialized = compatible_utils.kv.get(RAYFED_CLUSTER_KEY)
return cloudpickle.loads(serialized)


def get_party():
"""
Get the current party name.
"""
serialized = internal_kv._internal_kv_get(RAYFED_PARTY_KEY)
serialized = compatible_utils.kv.get(RAYFED_PARTY_KEY)
return cloudpickle.loads(serialized)


def get_tls():
"""
Get the tls configurations on this party.
"""
serialized = internal_kv._internal_kv_get(RAYFED_TLS_CONFIG)
serialized = compatible_utils.kv.get(RAYFED_TLS_CONFIG)
return cloudpickle.loads(serialized)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_transport_proxy_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import ray
import cloudpickle
import ray.experimental.internal_kv as internal_kv

import fed._private.compatible_utils as compatible_utils

from fed.barriers import RecverProxyActor, send, start_send_proxy
Expand All @@ -40,7 +40,7 @@ def test_n_to_1_transport():
"cert": os.path.join(cert_dir, "server.crt"),
"key": os.path.join(cert_dir, "server.key"),
}
internal_kv._internal_kv_put(RAYFED_TLS_CONFIG, cloudpickle.dumps(tls_config))
compatible_utils.kv.put(RAYFED_TLS_CONFIG, cloudpickle.dumps(tls_config))

NUM_DATA = 10
SERVER_ADDRESS = "127.0.0.1:65422"
Expand Down

0 comments on commit 76d96de

Please sign in to comment.