Skip to content

Commit

Permalink
standalone full paralllel
Browse files Browse the repository at this point in the history
  • Loading branch information
pan-x-c committed Oct 7, 2022
1 parent 0249202 commit 7962674
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 27 deletions.
64 changes: 58 additions & 6 deletions federatedscope/core/communication.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
import grpc
from concurrent import futures
import logging

from federatedscope.core.configs.config import global_cfg
from federatedscope.core.proto import gRPC_comm_manager_pb2, \
gRPC_comm_manager_pb2_grpc
from federatedscope.core.gRPC_server import gRPCComServeFunc
from federatedscope.core.message import Message

logger = logging.getLogger(__name__)

class StandaloneCommManager(object):
class StandaloneClientCommManager(object):
"""
The communicator used for standalone mode
"""
def __init__(self, comm_queue, monitor=None):
self.comm_queue = comm_queue
def __init__(self, receive_channel, send_channel, monitor=None):
self.receive_channel = receive_channel
self.send_channel = send_channel
self.neighbors = dict()
self.monitor = monitor # used to track the communication related
# metrics

def receive(self):
# we don't need receive() in standalone
pass
message = self.receive_channel.get()
logger.info(f"client {message.receiver} receive message {message.msg_type}")
return message

def add_neighbors(self, neighbor_id, address=None):
self.neighbors[neighbor_id] = address
Expand All @@ -39,11 +43,59 @@ def get_neighbors(self, neighbor_id=None):
return self.neighbors

def send(self, message):
self.comm_queue.append(message)
logger.info(f"client send message {message.msg_type}")
self.send_channel.put(message)
download_bytes, upload_bytes = message.count_bytes()
self.monitor.track_upload_bytes(upload_bytes)


class StandaloneServerCommManager(object):
"""
The communicator used for standalone mode
"""
def __init__(self, channels, monitor=None):
self.send_channel = channels
self.neighbors = dict()
self.monitor = monitor # used to track the communication related
# metrics

def receive(self):
pass

def add_neighbors(self, neighbor_id, address=None):
self.neighbors[neighbor_id] = address

def get_neighbors(self, neighbor_id=None):
address = dict()
if neighbor_id:
if isinstance(neighbor_id, list):
for each_neighbor in neighbor_id:
address[each_neighbor] = self.get_neighbors(each_neighbor)
return address
else:
return self.neighbors[neighbor_id]
else:
# Get all neighbors
return self.neighbors

def send(self, message):
download_bytes, upload_bytes = message.count_bytes()
receiver = message.receiver
if receiver is not None:
if not isinstance(receiver, list):
receiver = [receiver]
for each_receiver in receiver:
if each_receiver in self.neighbors:
logger.info(f"server send message to {each_receiver}")
channel = self.send_channel[each_receiver]
channel.put(message)
self.monitor.track_upload_bytes(upload_bytes)
else:
for channel in self.send_channel:
channel.put(message)
self.monitor.track_upload_bytes(upload_bytes)


class gRPCCommManager(object):
"""
The implementation of gRPCCommManager is referred to the tutorial on
Expand Down
45 changes: 30 additions & 15 deletions federatedscope/core/fed_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from collections import deque
import heapq
import multiprocessing
import time

import numpy as np

Expand Down Expand Up @@ -65,6 +67,7 @@ def __init__(self,

if self.mode == 'standalone':
self.shared_comm_queue = deque()
self.manager = multiprocessing.Manager()
self._setup_for_standalone()
# in standalone mode, by default, we print the trainer info only
# once for better logs readability
Expand Down Expand Up @@ -131,11 +134,16 @@ def _setup_for_standalone(self):
server_resource_info = None
client_resource_info = None

self.server2client_channels = dict()
for client_id in range(1, self.cfg.federate.client_num + 1):
self.server2client_channels[client_id] = self.manager.Queue()

self.server = self._setup_server(
resource_info=server_resource_info,
client_resource_info=client_resource_info)

self.client = dict()
self.client2server_channel = self.manager.Queue()

# assume the client-wise data are consistent in their input&output
# shape
Expand All @@ -144,6 +152,7 @@ def _setup_for_standalone(self):
) if self.cfg.federate.share_local_model else None

for client_id in range(1, self.cfg.federate.client_num + 1):

self.client[client_id] = self._setup_client(
client_id=client_id,
client_model=self._shared_client_model,
Expand Down Expand Up @@ -186,8 +195,10 @@ def run(self):
"""
if self.mode == 'standalone':
# trigger the FL course
self.pool = multiprocessing.Pool(processes=self.cfg.federate.client_num)
for each_client in self.client:
self.client[each_client].join_in()
self.pool.apply_async(self.client[each_client].run_standalone)
self.pool.close()

if self.cfg.federate.online_aggr:
# any broadcast operation would be executed client-by-client
Expand All @@ -199,6 +210,7 @@ def run(self):
self._run_simulation()

self.server._monitor.finish_fed_runner(fl_mode=self.mode)
self.pool.join()

return self.server.best_results

Expand Down Expand Up @@ -242,18 +254,18 @@ def is_broadcast(msg):
break

def _run_simulation(self):

server_msg_cache = list()
while self.client2server_channel.empty():
continue
while True:
if len(self.shared_comm_queue) > 0:
msg = self.shared_comm_queue.popleft()
if msg.receiver == [self.server_id]:
# For the server, move the received message to a
# cache for reordering the messages according to
# the timestamps
heapq.heappush(server_msg_cache, msg)
else:
self._handle_msg(msg)
if not self.client2server_channel.empty():

msg = self.client2server_channel.get()
logger.info(f"server receive message from {msg.sender}")
# For the server, move the received message to a
# cache for reordering the messages according to
# the timestamps
heapq.heappush(server_msg_cache, msg)
elif len(server_msg_cache) > 0:
msg = heapq.heappop(server_msg_cache)
if self.cfg.asyn.use and self.cfg.asyn.aggregator \
Expand All @@ -272,13 +284,14 @@ def _run_simulation(self):
if self.cfg.asyn.use and self.cfg.asyn.aggregator \
== 'time_up':
self.server.trigger_for_time_up()
if len(self.shared_comm_queue) == 0 and \
if self.client2server_channel.empty() and \
len(server_msg_cache) == 0:
break
else:
# terminate when shared_comm_queue and
# server_msg_cache are all empty
break
time.sleep(1)
# break

def _setup_server(self, resource_info=None, client_resource_info=None):
"""
Expand Down Expand Up @@ -309,7 +322,7 @@ def _setup_server(self, resource_info=None, client_resource_info=None):
) # get the model according to client's data if the server
# does not own data
kw = {
'shared_comm_queue': self.shared_comm_queue,
'channels': self.server2client_channels,
'resource_info': resource_info,
'client_resource_info': client_resource_info
}
Expand Down Expand Up @@ -360,7 +373,8 @@ def _setup_client(self,
if self.mode == 'standalone':
client_data = self.data[client_id]
kw = {
'shared_comm_queue': self.shared_comm_queue,
'client2server_channel': self.client2server_channel,
'server2client_channel': self.server2client_channels[client_id],
'resource_info': resource_info
}
elif self.mode == 'distributed':
Expand Down Expand Up @@ -424,6 +438,7 @@ def _handle_msg(self, msg, rcv=-1):
self.server.msg_handlers[msg.msg_type](msg)
self.server._monitor.track_download_bytes(download_bytes)
else:
logger.error("wrong place to go")
self.client[each_receiver].msg_handlers[msg.msg_type](msg)
self.client[each_receiver]._monitor.track_download_bytes(
download_bytes)
15 changes: 12 additions & 3 deletions federatedscope/core/workers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pickle

from federatedscope.core.message import Message
from federatedscope.core.communication import StandaloneCommManager, \
from federatedscope.core.communication import StandaloneClientCommManager, \
gRPCCommManager
from federatedscope.core.monitors.early_stopper import EarlyStopper
from federatedscope.core.workers import Worker
Expand Down Expand Up @@ -117,8 +117,10 @@ def __init__(self,
# Initialize communication manager
self.server_id = server_id
if self.mode == 'standalone':
comm_queue = kwargs['shared_comm_queue']
self.comm_manager = StandaloneCommManager(comm_queue=comm_queue,
client2server_channel = kwargs['client2server_channel']
server2client_channel = kwargs['server2client_channel']
self.comm_manager = StandaloneClientCommManager(receive_channel=server2client_channel,
send_channel=client2server_channel,
monitor=self._monitor)
self.local_address = None
elif self.mode == 'distributed':
Expand Down Expand Up @@ -217,6 +219,13 @@ def run(self):
if msg.msg_type == 'finish':
break

def run_standalone(self):
"""
Run in standalone mode
"""
self.join_in()
self.run()

def callback_funcs_for_model_para(self, message: Message):
"""
The handling function for receiving model parameters,
Expand Down
6 changes: 3 additions & 3 deletions federatedscope/core/workers/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from federatedscope.core.monitors.early_stopper import EarlyStopper
from federatedscope.core.message import Message
from federatedscope.core.communication import StandaloneCommManager, \
from federatedscope.core.communication import StandaloneServerCommManager, \
gRPCCommManager
from federatedscope.core.workers import Worker
from federatedscope.core.auxiliaries.aggregator_builder import get_aggregator
Expand Down Expand Up @@ -173,8 +173,8 @@ def __init__(self,
self.msg_buffer = {'train': dict(), 'eval': dict()}
self.staled_msg_buffer = list()
if self.mode == 'standalone':
comm_queue = kwargs['shared_comm_queue']
self.comm_manager = StandaloneCommManager(comm_queue=comm_queue,
channels = kwargs['channels']
self.comm_manager = StandaloneServerCommManager(channels=channels,
monitor=self._monitor)
elif self.mode == 'distributed':
host = kwargs['host']
Expand Down

0 comments on commit 7962674

Please sign in to comment.