Skip to content

Commit

Permalink
feat: add ReceiverSenderProxy. (ray-project#168)
Browse files Browse the repository at this point in the history
* feat: add ReceiverSenderProxy.

* Fix linter.

* fix max_concurrency.

* Fix max_concurrency.

* Set default concurrency to 1.
  • Loading branch information
zhouaihui authored Aug 4, 2023
1 parent 270546d commit 60ab909
Show file tree
Hide file tree
Showing 19 changed files with 439 additions and 291 deletions.
2 changes: 1 addition & 1 deletion fed/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

KEY_OF_TLS_CONFIG = "TLS_CONFIG"

KEY_OF_CROSS_SILO_MESSAGE_CONFIG = "CROSS_SILO_MESSAGE_CONFIG"
KEY_OF_CROSS_SILO_COMM_CONFIG_DICT = "CROSS_SILO_COMM_CONFIG_DICT"

RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- %(message)s" # noqa

Expand Down
13 changes: 9 additions & 4 deletions fed/_private/serialization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def _restricted_loads(
buffers=None,
):
from sys import version_info

assert version_info.major == 3

if version_info.minor >= 8:
Expand All @@ -41,8 +42,10 @@ class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
if _pickle_whitelist is None or (
module in _pickle_whitelist
and (_pickle_whitelist[module] is None or name in _pickle_whitelist[
module])
and (
_pickle_whitelist[module] is None
or name in _pickle_whitelist[module]
)
):
return super().find_class(module, name)

Expand All @@ -63,8 +66,10 @@ def find_class(self, module, name):
def _apply_loads_function_with_whitelist():
global _pickle_whitelist

_pickle_whitelist = fed_config.get_job_config() \
.cross_silo_message_config.serializing_allowed_list
cross_silo_comm_config = fed_config.CrossSiloMessageConfig.from_dict(
fed_config.get_job_config().cross_silo_comm_config_dict
)
_pickle_whitelist = cross_silo_comm_config.serializing_allowed_list
if _pickle_whitelist is None:
return

Expand Down
102 changes: 63 additions & 39 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@
send,
_start_receiver_proxy,
_start_sender_proxy,
_start_sender_receiver_proxy,
set_receiver_proxy_actor_name,
set_sender_proxy_actor_name,
)
from fed.proxy.grpc.grpc_proxy import SenderProxy, ReceiverProxy
from fed.config import GrpcCrossSiloMessageConfig
from fed.proxy.base_proxy import SenderProxy, ReceiverProxy, SenderReceiverProxy
from fed.config import CrossSiloMessageConfig
from fed.fed_object import FedObject
from fed.utils import is_ray_object_refs, setup_logger

Expand All @@ -50,6 +53,7 @@ def init(
logging_level: str = 'info',
sender_proxy_cls: SenderProxy = None,
receiver_proxy_cls: ReceiverProxy = None,
receiver_sender_proxy_cls: SenderReceiverProxy = None,
):
"""
Initialize a RayFed client.
Expand Down Expand Up @@ -112,9 +116,7 @@ def init(
'cert' in tls_config and 'key' in tls_config
), 'Cert or key are not in tls_config.'

cross_silo_message_dict = config.get("cross_silo_message", {})
cross_silo_message_config = GrpcCrossSiloMessageConfig.from_dict(
cross_silo_message_dict)
cross_silo_comm_dict = config.get("cross_silo_comm", {})
# This accesses a private Ray API; it should be replaced with a public API.
compatible_utils._init_internal_kv()

Expand All @@ -125,11 +127,11 @@ def init(
}

job_config = {
constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG:
cross_silo_message_config,
constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT: cross_silo_comm_dict,
}
compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG,
cloudpickle.dumps(cluster_config))
compatible_utils.kv.put(
constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)
)
compatible_utils.kv.put(constants.KEY_OF_JOB_CONFIG, cloudpickle.dumps(job_config))
# Set logger.
# Note(NKcqx): This should be called after internal_kv has party value, i.e.
Expand All @@ -143,39 +145,61 @@ def init(
)

logger.info(f'Started rayfed with {cluster_config}')
cross_silo_comm_config = CrossSiloMessageConfig.from_dict(cross_silo_comm_dict)
get_global_context().get_cleanup_manager().start(
exit_when_failure_sending=cross_silo_message_config.exit_on_sending_failure) # noqa

if receiver_proxy_cls is None:
logger.debug(
"There is no receiver proxy class specified, it uses `GrpcRecvProxy` by "
"default.")
from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy
receiver_proxy_cls = GrpcReceiverProxy
_start_receiver_proxy(
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=receiver_proxy_cls,
proxy_config=cross_silo_message_config
exit_on_sending_failure=cross_silo_comm_config.exit_on_sending_failure
)
if receiver_sender_proxy_cls is not None:
proxy_actor_name = 'sender_recevier_actor'
set_sender_proxy_actor_name(proxy_actor_name)
set_receiver_proxy_actor_name(proxy_actor_name)
_start_sender_receiver_proxy(
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=receiver_sender_proxy_cls,
proxy_config=cross_silo_comm_dict,
ready_timeout_second=cross_silo_comm_config.timeout_in_ms / 1000,
)
else:
if receiver_proxy_cls is None:
logger.debug(
(
"There is no receiver proxy class specified, "
"it uses `GrpcRecvProxy` by default."
)
)
from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy

receiver_proxy_cls = GrpcReceiverProxy
_start_receiver_proxy(
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=receiver_proxy_cls,
proxy_config=cross_silo_comm_dict,
ready_timeout_second=cross_silo_comm_config.timeout_in_ms / 1000,
)

if sender_proxy_cls is None:
logger.debug(
"There is no sender proxy class specified, it uses `GrpcRecvProxy` by "
"default.")
from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy
sender_proxy_cls = GrpcSenderProxy
_start_sender_proxy(
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=sender_proxy_cls,
# TODO(qwang): proxy_config -> cross_silo_message_config
proxy_config=cross_silo_message_config
)
if sender_proxy_cls is None:
logger.debug(
"There is no sender proxy class specified, it uses `GrpcRecvProxy` by "
"default."
)
from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy

sender_proxy_cls = GrpcSenderProxy
_start_sender_proxy(
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=sender_proxy_cls,
proxy_config=cross_silo_comm_dict,
ready_timeout_second=cross_silo_comm_config.timeout_in_ms / 1000,
)

if config.get("barrier_on_initializing", False):
# TODO(zhouaihui): can be removed after we have a better retry strategy.
Expand Down
6 changes: 3 additions & 3 deletions fed/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(self) -> None:
self._check_send_thread = None
self._monitor_thread = None

def start(self, exit_when_failure_sending=False):
self._exit_when_failure_sending = exit_when_failure_sending
def start(self, exit_on_sending_failure=False):
self._exit_on_sending_failure = exit_on_sending_failure

def __check_func():
self._check_sending_objs()
Expand Down Expand Up @@ -98,7 +98,7 @@ def _signal_exit():
except Exception as e:
logger.warn(f'Failed to send {obj_ref} with error: {e}')
res = False
if not res and self._exit_when_failure_sending:
if not res and self._exit_on_sending_failure:
logger.warn('Signal self to exit.')
_signal_exit()
break
Expand Down
24 changes: 10 additions & 14 deletions fed/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


"""This module should be cached locally because all configurations
are mutable.
"""
Expand All @@ -10,11 +8,12 @@
import json

from typing import Dict, List, Optional
from dataclasses import dataclass
from dataclasses import dataclass, fields


class ClusterConfig:
"""A local cache of cluster configuration items."""

def __init__(self, raw_bytes: bytes) -> None:
self._data = cloudpickle.loads(raw_bytes)

Expand All @@ -39,10 +38,8 @@ def __init__(self, raw_bytes: bytes) -> None:
self._data = cloudpickle.loads(raw_bytes)

@property
def cross_silo_message_config(self):
return self._data.get(
fed_constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG,
CrossSiloMessageConfig())
def cross_silo_comm_config_dict(self) -> Dict:
return self._data.get(fed_constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT, {})


# A module level cache for the cluster configurations.
Expand Down Expand Up @@ -103,7 +100,9 @@ class CrossSiloMessageConfig:
http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request.
This won't override basic tcp headers, such as `user-agent`; instead, the
two sets of headers are concatenated together.
max_concurrency: the max_concurrency of the sender/receiver proxy actor.
"""

proxy_max_restarts: int = None
timeout_in_ms: int = 60000
messages_max_size_in_bytes: int = None
Expand All @@ -112,9 +111,7 @@ class CrossSiloMessageConfig:
send_resource_label: Optional[Dict[str, str]] = None
recv_resource_label: Optional[Dict[str, str]] = None
http_header: Optional[Dict[str, str]] = None
# (Optional) The address this party is going to listen on.
# If not provided, the `address` will be used.
listening_address: str = None
max_concurrency: Optional[int] = None

def __json__(self):
return json.dumps(self.__dict__)
Expand All @@ -125,7 +122,7 @@ def from_json(cls, json_str):
return cls(**data)

@classmethod
def from_dict(cls, data: Dict):
def from_dict(cls, data: Dict) -> 'CrossSiloMessageConfig':
"""Initialize CrossSiloMessageConfig from a dictionary.
Args:
Expand All @@ -135,10 +132,8 @@ def from_dict(cls, data: Dict):
CrossSiloMessageConfig: An instance of CrossSiloMessageConfig.
"""
# Get the attributes of the class

data = data or {}
all_annotations = {**cls.__annotations__, **cls.__base__.__annotations__}
attrs = {attr for attr, _ in all_annotations.items()}
attrs = [field.name for field in fields(cls)]
# Filter the dictionary to only include keys that are attributes of the class
filtered_data = {key: value for key, value in data.items() if key in attrs}
return cls(**filtered_data)
Expand Down Expand Up @@ -170,5 +165,6 @@ class GrpcCrossSiloMessageConfig(CrossSiloMessageConfig):
('grpc.max_send_message_length', 50 * 1024 * 1024)
]
"""

grpc_channel_options: List = None
grpc_retry_policy: Dict[str, str] = None
Loading

0 comments on commit 60ab909

Please sign in to comment.