add more registers & refactor splitter (#372)
* add more registerable components

* update splitter interface

* remove graph_type splitter and update __repr__ for base splitter

* add README for register mechanism

Co-authored-by: Jones Wong <joneswong@users.noreply.github.com>
rayrayraykk and joneswong authored Sep 16, 2022
1 parent ccafa14 commit 6095dd9
Showing 37 changed files with 348 additions and 152 deletions.
15 changes: 15 additions & 0 deletions federatedscope/contrib/README.md
@@ -0,0 +1,15 @@
# Register

In addition to its rich collection of datasets, models, evaluation metrics, etc., FederatedScope (FS) also allows users to create their own ingredients or introduce customized modules. Inspired by GraphGym, we provide a `register` mechanism to help integrate your own components into the FS-based federated learning workflow (a minimal sketch of the builder contract follows the list), including:

* Configurations [`federatedscope/contrib/config`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/config)
* Data [`federatedscope/contrib/data`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/data)
* Loss [`federatedscope/contrib/loss`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/loss)
* Metrics [`federatedscope/contrib/metrics`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/metrics)
* Model [`federatedscope/contrib/model`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/model)
* Optimizer [`federatedscope/contrib/optimizer`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/optimizer)
* Scheduler [`federatedscope/contrib/scheduler`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/scheduler)
* Splitter [`federatedscope/contrib/splitter`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/splitter)
* Trainer [`federatedscope/contrib/trainer`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/trainer)
* Worker [`federatedscope/contrib/worker`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/worker)
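
Every builder registered through this mechanism follows the same contract, which the contrib examples in this commit illustrate: it receives the configured type (plus any construction arguments), returns an instance when the type matches its registration key, and returns None otherwise so that remaining registered builders and the built-in fallbacks can be tried. A minimal sketch of that contract; `call_my_loss`, the key `'myloss'`, and the choice of `nn.L1Loss` are illustrative, not part of this commit:

from federatedscope.register import register_criterion


def call_my_loss(type, device):
    # Return an instance only when the configured type matches our key;
    # returning None lets other registered builders handle the request.
    if type == 'myloss':
        import torch.nn as nn  # delayed import: torch may be unavailable
        return nn.L1Loss().to(device)
    return None


register_criterion('myloss', call_my_loss)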

3 changes: 0 additions & 3 deletions federatedscope/contrib/__init__.py
@@ -1,3 +0,0 @@
-from __future__ import absolute_import
-from __future__ import print_function
-from __future__ import division
8 changes: 8 additions & 0 deletions federatedscope/contrib/loss/__init__.py
@@ -0,0 +1,8 @@
from os.path import dirname, basename, isfile, join
import glob

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]
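
The glob-built `__all__` above is what makes registration automatic: when a builder executes `from federatedscope.contrib.loss import *`, every module in the package is imported and any top-level `register_*` call runs as a side effect. A sanity-check sketch, assuming `register_criterion` stores entries under their names, as the builders' `.values()` lookups suggest:

import federatedscope.register as register
from federatedscope.contrib.loss import *  # imports example.py, which registers itself

assert 'mycriterion' in register.criterion_dict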
17 changes: 17 additions & 0 deletions federatedscope/contrib/loss/example.py
@@ -0,0 +1,17 @@
from federatedscope.register import register_criterion


def call_my_criterion(type, device):
    try:
        import torch.nn as nn
    except ImportError:
        nn = None
    criterion = None

    if type == 'mycriterion':
        if nn is not None:
            criterion = nn.CrossEntropyLoss().to(device)
    return criterion


register_criterion('mycriterion', call_my_criterion)
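
Once the module above is importable, the standard entry point resolves the custom name like a built-in one. A usage sketch, assuming torch is installed:

from federatedscope.core.auxiliaries.criterion_builder import get_criterion

# The registered builder answers before any built-in criterion is consulted.
criterion = get_criterion('mycriterion', device='cpu')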
8 changes: 8 additions & 0 deletions federatedscope/contrib/optimizer/__init__.py
@@ -0,0 +1,8 @@
from os.path import dirname, basename, isfile, join
import glob

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]
17 changes: 17 additions & 0 deletions federatedscope/contrib/optimizer/example.py
@@ -0,0 +1,17 @@
from federatedscope.register import register_optimizer


def call_my_optimizer(model, type, lr, **kwargs):
    try:
        import torch.optim as optim
    except ImportError:
        optim = None
    optimizer = None

    if type == 'myoptimizer':
        if optim is not None:
            optimizer = optim.Adam(model.parameters(), lr=lr, **kwargs)
    return optimizer


register_optimizer('myoptimizer', call_my_optimizer)
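
A corresponding usage sketch for the optimizer entry point, assuming torch is installed:

import torch

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer

model = torch.nn.Linear(4, 2)
# 'myoptimizer' is claimed by call_my_optimizer above and resolves to optim.Adam.
optimizer = get_optimizer(model, 'myoptimizer', lr=1e-2)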
8 changes: 8 additions & 0 deletions federatedscope/contrib/scheduler/__init__.py
@@ -0,0 +1,8 @@
from os.path import dirname, basename, isfile, join
import glob

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]
18 changes: 18 additions & 0 deletions federatedscope/contrib/scheduler/example.py
@@ -0,0 +1,18 @@
from federatedscope.register import register_scheduler


def call_my_scheduler(optimizer, type):
    try:
        import torch.optim as optim
    except ImportError:
        optim = None
    scheduler = None

    if type == 'myscheduler':
        if optim is not None:
            lr_lambda = [lambda epoch: epoch // 30]
            scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return scheduler


register_scheduler('myscheduler', call_my_scheduler)
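
A usage sketch for the scheduler path, assuming torch is installed; `LambdaLR` accepts a list with one lambda here because the optimizer has a single parameter group:

import torch

from federatedscope.core.auxiliaries.scheduler_builder import get_scheduler

optimizer = torch.optim.SGD(torch.nn.Linear(4, 2).parameters(), lr=0.1)
scheduler = get_scheduler(optimizer, 'myscheduler')  # wrapped by call_my_scheduler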
8 changes: 8 additions & 0 deletions federatedscope/contrib/splitter/__init__.py
@@ -0,0 +1,8 @@
from os.path import dirname, basename, isfile, join
import glob

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]
26 changes: 26 additions & 0 deletions federatedscope/contrib/splitter/example.py
@@ -0,0 +1,26 @@
from federatedscope.register import register_splitter
from federatedscope.core.splitters import BaseSplitter


class MySplitter(BaseSplitter):
    def __init__(self, client_num, **kwargs):
        super(MySplitter, self).__init__(client_num, **kwargs)

    def __call__(self, dataset, *args, **kwargs):
        # Dummy splitter, only for demonstration
        per_samples = len(dataset) // self.client_num
        data_list, cur_index = [], 0
        for i in range(self.client_num):
            data_list.append(
                [x for x in range(cur_index, cur_index + per_samples)])
            cur_index += per_samples
        return data_list


def call_my_splitter(client_num, type, **kwargs):
    if type == 'mysplitter':
        splitter = MySplitter(client_num, **kwargs)
        return splitter


register_splitter('mysplitter', call_my_splitter)
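
The dummy splitter can be exercised directly; with 10 samples and 5 clients, each client receives a chunk of two sample indices (illustrative; assumes `BaseSplitter` only records `client_num`):

from federatedscope.contrib.splitter.example import MySplitter

splitter = MySplitter(client_num=5)
print(splitter(list(range(10))))
# [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]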
8 changes: 8 additions & 0 deletions federatedscope/contrib/worker/__init__.py
@@ -0,0 +1,8 @@
from os.path import dirname, basename, isfile, join
import glob

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]
20 changes: 20 additions & 0 deletions federatedscope/contrib/worker/example.py
@@ -0,0 +1,20 @@
from federatedscope.register import register_worker
from federatedscope.core.workers import Server, Client


# Build your worker here.
class MyClient(Client):
    pass


class MyServer(Server):
    pass


def call_my_worker(method):
    if method == 'mymethod':
        worker_builder = {'client': MyClient, 'server': MyServer}
        return worker_builder


register_worker('mymethod', call_my_worker)
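
A quick check of the registered worker builder (illustrative):

from federatedscope.contrib.worker.example import MyClient, MyServer, call_my_worker

assert call_my_worker('mymethod') == {'client': MyClient, 'server': MyServer}
assert call_my_worker('other') is None  # unknown methods fall through to built-ins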
10 changes: 10 additions & 0 deletions federatedscope/core/auxiliaries/criterion_builder.py
@@ -1,11 +1,21 @@
import logging
import federatedscope.register as register

logger = logging.getLogger(__name__)

try:
    from torch import nn
    from federatedscope.nlp.loss import *
except ImportError:
    nn = None

try:
    from federatedscope.contrib.loss import *
except ImportError as error:
    logger.warning(
        f'{error} in `federatedscope.contrib.loss`, some modules are not '
        f'available.')


def get_criterion(type, device):
    for func in register.criterion_dict.values():
19 changes: 18 additions & 1 deletion federatedscope/core/auxiliaries/optimizer_builder.py
@@ -1,9 +1,20 @@
import copy
import logging
import federatedscope.register as register

logger = logging.getLogger(__name__)

try:
    import torch
except ImportError:
    torch = None

-import copy
try:
    from federatedscope.contrib.optimizer import *
except ImportError as error:
    logger.warning(
        f'{error} in `federatedscope.contrib.optimizer`, some modules are not '
        f'available.')


def get_optimizer(model, type, lr, **kwargs):
@@ -17,6 +28,12 @@ def get_optimizer(model, type, lr, **kwargs):
        del tmp_kwargs['__cfg_check_funcs__']
    if 'is_ready_for_run' in tmp_kwargs:
        del tmp_kwargs['is_ready_for_run']

    for func in register.optimizer_dict.values():
        optimizer = func(model, type, lr, **tmp_kwargs)
        if optimizer is not None:
            return optimizer

    if isinstance(type, str):
        if hasattr(torch.optim, type):
            if isinstance(model, torch.nn.Module):
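
When no registered builder claims the type, the same entry point falls through to the `hasattr(torch.optim, type)` lookup above; a sketch, assuming torch is installed and that the remaining keyword arguments are forwarded to the optimizer constructor, as the `**kwargs` signature suggests:

import torch

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer

model = torch.nn.Linear(4, 2)
optimizer = get_optimizer(model, 'SGD', lr=0.1, momentum=0.9)  # -> torch.optim.SGD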
17 changes: 17 additions & 0 deletions federatedscope/core/auxiliaries/scheduler_builder.py
@@ -1,10 +1,27 @@
import logging
import federatedscope.register as register

logger = logging.getLogger(__name__)

try:
    import torch
except ImportError:
    torch = None

try:
    from federatedscope.contrib.scheduler import *
except ImportError as error:
    logger.warning(
        f'{error} in `federatedscope.contrib.scheduler`, some modules are not '
        f'available.')


def get_scheduler(optimizer, type, **kwargs):
    for func in register.scheduler_dict.values():
        scheduler = func(optimizer, type)
        if scheduler is not None:
            return scheduler

    if torch is None or type == '':
        return None
    if isinstance(type, str):
25 changes: 11 additions & 14 deletions federatedscope/core/auxiliaries/splitter_builder.py
@@ -7,42 +7,39 @@
def get_splitter(config):
    client_num = config.federate.client_num
    if config.data.splitter_args:
-        args = config.data.splitter_args[0]
+        kwargs = config.data.splitter_args[0]
    else:
-        args = {}
+        kwargs = {}

    for func in register.splitter_dict.values():
-        splitter = func(config)
+        splitter = func(client_num, config.data.splitter, **kwargs)
        if splitter is not None:
            return splitter
    # Delay import
    # generic splitter
    if config.data.splitter == 'lda':
        from federatedscope.core.splitters.generic import LDASplitter
-        splitter = LDASplitter(client_num, **args)
+        splitter = LDASplitter(client_num, **kwargs)
    # graph splitter
    elif config.data.splitter == 'louvain':
        from federatedscope.core.splitters.graph import LouvainSplitter
-        splitter = LouvainSplitter(client_num, **args)
+        splitter = LouvainSplitter(client_num, **kwargs)
    elif config.data.splitter == 'random':
        from federatedscope.core.splitters.graph import RandomSplitter
-        splitter = RandomSplitter(client_num, **args)
+        splitter = RandomSplitter(client_num, **kwargs)
    elif config.data.splitter == 'rel_type':
        from federatedscope.core.splitters.graph import RelTypeSplitter
-        splitter = RelTypeSplitter(client_num, **args)
-    elif config.data.splitter == 'graph_type':
-        from federatedscope.core.splitters.graph import GraphTypeSplitter
-        splitter = GraphTypeSplitter(client_num, **args)
+        splitter = RelTypeSplitter(client_num, **kwargs)
    elif config.data.splitter == 'scaffold':
        from federatedscope.core.splitters.graph import ScaffoldSplitter
-        splitter = ScaffoldSplitter(client_num, **args)
+        splitter = ScaffoldSplitter(client_num, **kwargs)
    elif config.data.splitter == 'scaffold_lda':
        from federatedscope.core.splitters.graph import ScaffoldLdaSplitter
-        splitter = ScaffoldLdaSplitter(client_num, **args)
+        splitter = ScaffoldLdaSplitter(client_num, **kwargs)
    elif config.data.splitter == 'rand_chunk':
        from federatedscope.core.splitters.graph import RandChunkSplitter
-        splitter = RandChunkSplitter(client_num, **args)
+        splitter = RandChunkSplitter(client_num, **kwargs)
    else:
-        logger.warning('Splitter is none or not found.')
+        logger.warning(f'Splitter {config.data.splitter} not found.')
        splitter = None
    return splitter
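
With the refactored interface, registered splitters receive the client count and the configured splitter type. A sketch with a stand-in config object; the real project uses a yacs-style CfgNode, and `SimpleNamespace` is only for illustration:

from types import SimpleNamespace

import federatedscope.contrib.splitter.example  # noqa: F401, triggers registration
from federatedscope.core.auxiliaries.splitter_builder import get_splitter

config = SimpleNamespace(
    federate=SimpleNamespace(client_num=5),
    data=SimpleNamespace(splitter='mysplitter', splitter_args=[{}]),
)
splitter = get_splitter(config)  # resolved by the registered call_my_splitter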
34 changes: 27 additions & 7 deletions federatedscope/core/auxiliaries/worker_builder.py
@@ -2,11 +2,24 @@

from federatedscope.core.configs import constants
from federatedscope.core.workers import Server, Client
import federatedscope.register as register

logger = logging.getLogger(__name__)

try:
    from federatedscope.contrib.worker import *
except ImportError as error:
    logger.warning(
        f'{error} in `federatedscope.contrib.worker`, some modules are not '
        f'available.')


def get_client_cls(cfg):
    for func in register.worker_dict.values():
        worker_class = func(cfg.federate.method.lower())
        if worker_class is not None:
            return worker_class['client']

    if cfg.hpo.fedex.use:
        from federatedscope.autotune.fedex import FedExClient
        return FedExClient
@@ -49,6 +62,11 @@ def get_client_cls(cfg):


def get_server_cls(cfg):
    for func in register.worker_dict.values():
        worker_class = func(cfg.federate.method.lower())
        if worker_class is not None:
            return worker_class['server']

    if cfg.hpo.fedex.use:
        from federatedscope.autotune.fedex import FedExServer
        return FedExServer
@@ -72,18 +90,20 @@ def get_server_cls(cfg):
        return vFLServer

    if cfg.federate.method.lower() in constants.SERVER_TYPE:
-        client_type = constants.SERVER_TYPE[cfg.federate.method.lower()]
+        server_type = constants.SERVER_TYPE[cfg.federate.method.lower()]
    else:
-        client_type = "normal"
+        server_type = "normal"
        logger.warning(
            'Server for method {} is not implemented. Will use default one'.
            format(cfg.federate.method))

-    if client_type == 'fedsageplus':
+    if server_type == 'fedsageplus':
        from federatedscope.gfl.fedsageplus.worker import FedSagePlusServer
-        return FedSagePlusServer
+        server_class = FedSagePlusServer
-    elif client_type == 'gcflplus':
+    elif server_type == 'gcflplus':
        from federatedscope.gfl.gcflplus.worker import GCFLPlusServer
-        return GCFLPlusServer
+        server_class = GCFLPlusServer
    else:
-        return Server
+        server_class = Server
+
+    return server_class