diff --git a/federatedscope/contrib/README.md b/federatedscope/contrib/README.md new file mode 100644 index 000000000..e5f272d85 --- /dev/null +++ b/federatedscope/contrib/README.md @@ -0,0 +1,15 @@ +# Register + +In addition to the rich collection of datasets, models, evaluation metrics, etc., FederatedScope (FS) also allows users to create their own ingredients or introduce more customized modules to FS. Inspired by GraphGym, we provide `register` mechanism to help integrating your own components into the FS-based federated learning workflow, including: + +* Configurations [`federatedscope/contrib/config`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/config) +* Data [`federatedscope/contrib/data`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/data) +* Loss [`federatedscope/contrib/loss`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/loss) +* Metrics [`federatedscope/contrib/metrics`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/metrics) +* Model [`federatedscope/contrib/model`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/model) +* Optimizer [`federatedscope/contrib/optimizer`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/optimizer) +* Scheduler [`federatedscope/contrib/scheduler`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/scheduler) +* Splitter [`federatedscope/contrib/splitter`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/splitter) +* Trainer [`federatedscope/contrib/trainer`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/trainer) +* Worker [`federatedscope/contrib/worker`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/worker) + diff --git a/federatedscope/contrib/__init__.py b/federatedscope/contrib/__init__.py index f8e91f237..e69de29bb 100644 --- a/federatedscope/contrib/__init__.py +++ b/federatedscope/contrib/__init__.py @@ -1,3 +0,0 @@ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division diff --git a/federatedscope/contrib/loss/__init__.py b/federatedscope/contrib/loss/__init__.py new file mode 100644 index 000000000..c0b31382d --- /dev/null +++ b/federatedscope/contrib/loss/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/federatedscope/contrib/loss/example.py b/federatedscope/contrib/loss/example.py new file mode 100644 index 000000000..012eca28a --- /dev/null +++ b/federatedscope/contrib/loss/example.py @@ -0,0 +1,17 @@ +from federatedscope.register import register_criterion + + +def call_my_criterion(type, device): + try: + import torch.nn as nn + except ImportError: + nn = None + criterion = None + + if type == 'mycriterion': + if nn is not None: + criterion = nn.CrossEntropyLoss().to(device) + return criterion + + +register_criterion('mycriterion', call_my_criterion) diff --git a/federatedscope/contrib/optimizer/__init__.py b/federatedscope/contrib/optimizer/__init__.py new file mode 100644 index 000000000..c0b31382d --- /dev/null +++ b/federatedscope/contrib/optimizer/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/federatedscope/contrib/optimizer/example.py b/federatedscope/contrib/optimizer/example.py new file mode 100644 index 000000000..41c083bda --- /dev/null +++ b/federatedscope/contrib/optimizer/example.py @@ -0,0 +1,17 @@ +from federatedscope.register import register_optimizer + + +def call_my_optimizer(model, type, lr, **kwargs): + try: + import torch.optim as optim + except ImportError: + optim = None + optimizer = None + + if type == 'myoptimizer': + if optim is not None: + optimizer = optim.Adam(model.parameters(), lr=lr, **kwargs) + return optimizer + + +register_optimizer('myoptimizer', call_my_optimizer) diff --git a/federatedscope/contrib/scheduler/__init__.py b/federatedscope/contrib/scheduler/__init__.py new file mode 100644 index 000000000..c0b31382d --- /dev/null +++ b/federatedscope/contrib/scheduler/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/federatedscope/contrib/scheduler/example.py b/federatedscope/contrib/scheduler/example.py new file mode 100644 index 000000000..642d6fab3 --- /dev/null +++ b/federatedscope/contrib/scheduler/example.py @@ -0,0 +1,18 @@ +from federatedscope.register import register_scheduler + + +def call_my_scheduler(optimizer, type): + try: + import torch.optim as optim + except ImportError: + optim = None + scheduler = None + + if type == 'myscheduler': + if optim is not None: + lr_lambda = [lambda epoch: epoch // 30] + scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + return scheduler + + +register_scheduler('myscheduler', call_my_scheduler) diff --git a/federatedscope/contrib/splitter/__init__.py b/federatedscope/contrib/splitter/__init__.py new file mode 100644 index 000000000..c0b31382d --- /dev/null +++ b/federatedscope/contrib/splitter/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/federatedscope/contrib/splitter/example.py b/federatedscope/contrib/splitter/example.py new file mode 100644 index 000000000..f82c07204 --- /dev/null +++ b/federatedscope/contrib/splitter/example.py @@ -0,0 +1,26 @@ +from federatedscope.register import register_splitter +from federatedscope.core.splitters import BaseSplitter + + +class MySplitter(BaseSplitter): + def __init__(self, client_num, **kwargs): + super(MySplitter, self).__init__(client_num, **kwargs) + + def __call__(self, dataset, *args, **kwargs): + # Dummy splitter, only for demonstration + per_samples = len(dataset) // self.client_num + data_list, cur_index = [], 0 + for i in range(self.client_num): + data_list.append( + [x for x in range(cur_index, cur_index + per_samples)]) + cur_index += per_samples + return data_list + + +def call_my_splitter(client_num, **kwargs): + if type == 'mysplitter': + splitter = MySplitter(client_num, **kwargs) + return splitter + + +register_splitter('mysplitter', call_my_splitter) diff --git a/federatedscope/contrib/worker/__init__.py b/federatedscope/contrib/worker/__init__.py new file mode 100644 index 000000000..c0b31382d --- /dev/null +++ b/federatedscope/contrib/worker/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/federatedscope/contrib/worker/example.py b/federatedscope/contrib/worker/example.py new file mode 100644 index 000000000..7f0ed2e7e --- /dev/null +++ b/federatedscope/contrib/worker/example.py @@ -0,0 +1,20 @@ +from federatedscope.register import register_worker +from federatedscope.core.workers import Server, Client + + +# Build your worker here. +class MyClient(Client): + pass + + +class MyServer(Server): + pass + + +def call_my_worker(method): + if method == 'mymethod': + worker_builder = {'client': MyClient, 'server': MyServer} + return worker_builder + + +register_worker('mymethod', call_my_worker) diff --git a/federatedscope/core/auxiliaries/criterion_builder.py b/federatedscope/core/auxiliaries/criterion_builder.py index 1502fc0d3..4192a9b02 100644 --- a/federatedscope/core/auxiliaries/criterion_builder.py +++ b/federatedscope/core/auxiliaries/criterion_builder.py @@ -1,11 +1,21 @@ +import logging import federatedscope.register as register +logger = logging.getLogger(__name__) + try: from torch import nn from federatedscope.nlp.loss import * except ImportError: nn = None +try: + from federatedscope.contrib.loss import * +except ImportError as error: + logger.warning( + f'{error} in `federatedscope.contrib.loss`, some modules are not ' + f'available.') + def get_criterion(type, device): for func in register.criterion_dict.values(): diff --git a/federatedscope/core/auxiliaries/optimizer_builder.py b/federatedscope/core/auxiliaries/optimizer_builder.py index 5219b1116..bd6d1bd13 100644 --- a/federatedscope/core/auxiliaries/optimizer_builder.py +++ b/federatedscope/core/auxiliaries/optimizer_builder.py @@ -1,9 +1,20 @@ +import copy +import logging +import federatedscope.register as register + +logger = logging.getLogger(__name__) + try: import torch except ImportError: torch = None -import copy +try: + from federatedscope.contrib.optimizer import * +except ImportError as error: + logger.warning( + f'{error} in `federatedscope.contrib.optimizer`, some modules are not ' + f'available.') def get_optimizer(model, type, lr, **kwargs): @@ -17,6 +28,12 @@ def get_optimizer(model, type, lr, **kwargs): del tmp_kwargs['__cfg_check_funcs__'] if 'is_ready_for_run' in tmp_kwargs: del tmp_kwargs['is_ready_for_run'] + + for func in register.optimizer_dict.values(): + optimizer = func(model, type, lr, **tmp_kwargs) + if optimizer is not None: + return optimizer + if isinstance(type, str): if hasattr(torch.optim, type): if isinstance(model, torch.nn.Module): diff --git a/federatedscope/core/auxiliaries/scheduler_builder.py b/federatedscope/core/auxiliaries/scheduler_builder.py index ef5e0a0c8..afc7a5f99 100644 --- a/federatedscope/core/auxiliaries/scheduler_builder.py +++ b/federatedscope/core/auxiliaries/scheduler_builder.py @@ -1,10 +1,27 @@ +import logging +import federatedscope.register as register + +logger = logging.getLogger(__name__) + try: import torch except ImportError: torch = None +try: + from federatedscope.contrib.scheduler import * +except ImportError as error: + logger.warning( + f'{error} in `federatedscope.contrib.scheduler`, some modules are not ' + f'available.') + def get_scheduler(optimizer, type, **kwargs): + for func in register.scheduler_dict.values(): + scheduler = func(optimizer, type) + if scheduler is not None: + return scheduler + if torch is None or type == '': return None if isinstance(type, str): diff --git a/federatedscope/core/auxiliaries/splitter_builder.py b/federatedscope/core/auxiliaries/splitter_builder.py index c09ca73ce..31ad607ba 100644 --- a/federatedscope/core/auxiliaries/splitter_builder.py +++ b/federatedscope/core/auxiliaries/splitter_builder.py @@ -7,42 +7,39 @@ def get_splitter(config): client_num = config.federate.client_num if config.data.splitter_args: - args = config.data.splitter_args[0] + kwargs = config.data.splitter_args[0] else: - args = {} + kwargs = {} for func in register.splitter_dict.values(): - splitter = func(config) + splitter = func(client_num, **kwargs) if splitter is not None: return splitter # Delay import # generic splitter if config.data.splitter == 'lda': from federatedscope.core.splitters.generic import LDASplitter - splitter = LDASplitter(client_num, **args) + splitter = LDASplitter(client_num, **kwargs) # graph splitter elif config.data.splitter == 'louvain': from federatedscope.core.splitters.graph import LouvainSplitter - splitter = LouvainSplitter(client_num, **args) + splitter = LouvainSplitter(client_num, **kwargs) elif config.data.splitter == 'random': from federatedscope.core.splitters.graph import RandomSplitter - splitter = RandomSplitter(client_num, **args) + splitter = RandomSplitter(client_num, **kwargs) elif config.data.splitter == 'rel_type': from federatedscope.core.splitters.graph import RelTypeSplitter - splitter = RelTypeSplitter(client_num, **args) - elif config.data.splitter == 'graph_type': - from federatedscope.core.splitters.graph import GraphTypeSplitter - splitter = GraphTypeSplitter(client_num, **args) + splitter = RelTypeSplitter(client_num, **kwargs) elif config.data.splitter == 'scaffold': from federatedscope.core.splitters.graph import ScaffoldSplitter - splitter = ScaffoldSplitter(client_num, **args) + splitter = ScaffoldSplitter(client_num, **kwargs) elif config.data.splitter == 'scaffold_lda': from federatedscope.core.splitters.graph import ScaffoldLdaSplitter - splitter = ScaffoldLdaSplitter(client_num, **args) + splitter = ScaffoldLdaSplitter(client_num, **kwargs) elif config.data.splitter == 'rand_chunk': from federatedscope.core.splitters.graph import RandChunkSplitter - splitter = RandChunkSplitter(client_num, **args) + splitter = RandChunkSplitter(client_num, **kwargs) else: - logger.warning('Splitter is none or not found.') + logger.warning(f'Splitter {config.data.splitter} not found.') splitter = None return splitter diff --git a/federatedscope/core/auxiliaries/worker_builder.py b/federatedscope/core/auxiliaries/worker_builder.py index bff5c6183..f0f94e375 100644 --- a/federatedscope/core/auxiliaries/worker_builder.py +++ b/federatedscope/core/auxiliaries/worker_builder.py @@ -2,11 +2,24 @@ from federatedscope.core.configs import constants from federatedscope.core.workers import Server, Client +import federatedscope.register as register logger = logging.getLogger(__name__) +try: + from federatedscope.contrib.worker import * +except ImportError as error: + logger.warning( + f'{error} in `federatedscope.contrib.worker`, some modules are not ' + f'available.') + def get_client_cls(cfg): + for func in register.worker_dict.values(): + worker_class = func(cfg.federate.method.lower()) + if worker_class is not None: + return worker_class['client'] + if cfg.hpo.fedex.use: from federatedscope.autotune.fedex import FedExClient return FedExClient @@ -49,6 +62,11 @@ def get_client_cls(cfg): def get_server_cls(cfg): + for func in register.worker_dict.values(): + worker_class = func(cfg.federate.method.lower()) + if worker_class is not None: + return worker_class['server'] + if cfg.hpo.fedex.use: from federatedscope.autotune.fedex import FedExServer return FedExServer @@ -72,18 +90,20 @@ def get_server_cls(cfg): return vFLServer if cfg.federate.method.lower() in constants.SERVER_TYPE: - client_type = constants.SERVER_TYPE[cfg.federate.method.lower()] + server_type = constants.SERVER_TYPE[cfg.federate.method.lower()] else: - client_type = "normal" + server_type = "normal" logger.warning( 'Server for method {} is not implemented. Will use default one'. format(cfg.federate.method)) - if client_type == 'fedsageplus': + if server_type == 'fedsageplus': from federatedscope.gfl.fedsageplus.worker import FedSagePlusServer - return FedSagePlusServer - elif client_type == 'gcflplus': + server_class = FedSagePlusServer + elif server_type == 'gcflplus': from federatedscope.gfl.gcflplus.worker import GCFLPlusServer - return GCFLPlusServer + server_class = GCFLPlusServer else: - return Server + server_class = Server + + return server_class diff --git a/federatedscope/core/configs/README.md b/federatedscope/core/configs/README.md index 984010c02..2c55ec7d1 100644 --- a/federatedscope/core/configs/README.md +++ b/federatedscope/core/configs/README.md @@ -14,33 +14,33 @@ We summarize all the customizable configurations: ### Data The configurations related to the data/dataset are defined in `cfg_data.py`. -| Name | (Type) Default Value | Description | Note | -|:----:|:-----:|:---------- |:---- | -| `data.root` | (string) 'data' | The folder where the data file located. `data.root` would be used together with `data.type` to load the dataset. | - | +| Name | (Type) Default Value | Description | Note | +|:----:|:-----:|:---------- |:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `data.root` | (string) 'data' | The folder where the data file located. `data.root` would be used together with `data.type` to load the dataset. | - | | `data.type` | (string) 'toy' | Dataset name | CV: 'femnist', 'celeba' ; NLP: 'shakespeare', 'subreddit', 'twitter'; Graph: 'cora', 'citeseer', 'pubmed', 'dblp_conf', 'dblp_org', 'csbm', 'epinions', 'ciao', 'fb15k-237', 'wn18', 'fb15k' , 'MUTAG', 'BZR', 'COX2', 'DHFR', 'PTC_MR', 'AIDS', 'NCI1', 'ENZYMES', 'DD', 'PROTEINS', 'COLLAB', 'IMDB-BINARY', 'IMDB-MULTI', 'REDDIT-BINARY', 'IMDB-BINARY', 'IMDB-MULTI', 'HIV', 'ESOL', 'FREESOLV', 'LIPO', 'PCBA', 'MUV', 'BACE', 'BBBP', 'TOX21', 'TOXCAST', 'SIDER', 'CLINTOX', 'graph_multi_domain_mol', 'graph_multi_domain_small', 'graph_multi_domain_mix', 'graph_multi_domain_biochem'; MF: 'vflmovielens1m', 'vflmovielens10m', 'hflmovielens1m', 'hflmovielens10m', 'vflnetflix', 'hflnetflix'; Tabular: 'toy', 'synthetic'; External dataset: 'DNAME@torchvision', 'DNAME@torchtext', 'DNAME@huggingface_datasets', 'DNAME@openml'. | -| `data.args` | (list) [] | Args for the external dataset | Used for external dataset, eg. `[{'download': False}]` | -| `data.save_data` | (bool) False | Whether to save the generated toy data | - | -| `data.splitter` | (string) '' | Splitter name for standalone dataset | Generic splitter: 'lda'; Graph splitter: 'louvain', 'random', 'rel_type', 'graph_type', 'scaffold', 'scaffold_lda', 'rand_chunk' | -| `data.splitter_args` | (list) [] | Args for splitter. | Used for splitter, eg. `[{'alpha': 0.5}]` | -| `data.transform` | (list) [] | Transform for x of data | Used in `get_item` in torch.dataset, eg. `[['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]]` | -| `data.target_transform` | (list) [] | Transform for y of data | Use as `data.transform` | -| `data.pre_transform` | (list) [] | Pre_transform for `torch_geometric` dataset | Use as `data.transform` | -| `data.batch_size` | (int) 64 | batch_size for DataLoader | - | -| `data.drop_last` | (bool) False | Whether drop last batch (if the number of last batch is smaller than batch_size) in DataLoader | - | -| `data.sizes` | (list) [10, 5] | Sample size for graph DataLoader | The length of `data.sizes` must meet the layer of GNN models. | -| `data.shuffle` | (bool) True | Shuffle train DataLoader | - | -| `data.server_holds_all` | (bool) False | Only use in global mode, whether the server (workers with idx 0) holds all data, useful in global training/evaluation case | - | -| `data.subsample` | (float) 1.0 |  Only used in LEAF datasets, subsample clients from all clients | - | -| `data.splits` | (list) [0.8, 0.1, 0.1] | Train, valid, test splits | - | -| `data.`
`consistent_label_distribution` | (bool) False | Make label distribution of train/val/test set over clients keep consistent during splitting | - | -| `data.cSBM_phi` | (list) [0.5, 0.5, 0.5] | Phi for cSBM graph dataset | - | -| `data.loader` | (string) '' | Graph sample name, used in minibatch trainer | 'graphsaint-rw': use `GraphSAINTRandomWalkSampler` as DataLoader; 'neighbor': use `NeighborSampler` as DataLoader. | -| `data.num_workers` | (int) 0 | num_workers in DataLoader | - | -| `data.graphsaint.walk_length` | (int) 2 | The length of each random walk in graphsaint. | - | -| `data.graphsaint.num_steps` | (int) 30 | The number of iterations per epoch in graphsaint. | - | -| `data.quadratic.dim` | (int) 1 | Dim of synthetic quadratic  dataset | - | -| `data.quadratic.min_curv` | (float) 0.02 | Min_curve of synthetic quadratic dataset | - | -| `data.quadratic.max_curv` | (float) 12.5 | Max_cur of synthetic quadratic dataset | - | +| `data.args` | (list) [] | Args for the external dataset | Used for external dataset, eg. `[{'download': False}]` | +| `data.save_data` | (bool) False | Whether to save the generated toy data | - | +| `data.splitter` | (string) '' | Splitter name for standalone dataset | Generic splitter: 'lda'; Graph splitter: 'louvain', 'random', 'rel_type', 'scaffold', 'scaffold_lda', 'rand_chunk' | +| `data.splitter_args` | (list) [] | Args for splitter. | Used for splitter, eg. `[{'alpha': 0.5}]` | +| `data.transform` | (list) [] | Transform for x of data | Used in `get_item` in torch.dataset, eg. `[['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]]` | +| `data.target_transform` | (list) [] | Transform for y of data | Use as `data.transform` | +| `data.pre_transform` | (list) [] | Pre_transform for `torch_geometric` dataset | Use as `data.transform` | +| `data.batch_size` | (int) 64 | batch_size for DataLoader | - | +| `data.drop_last` | (bool) False | Whether drop last batch (if the number of last batch is smaller than batch_size) in DataLoader | - | +| `data.sizes` | (list) [10, 5] | Sample size for graph DataLoader | The length of `data.sizes` must meet the layer of GNN models. | +| `data.shuffle` | (bool) True | Shuffle train DataLoader | - | +| `data.server_holds_all` | (bool) False | Only use in global mode, whether the server (workers with idx 0) holds all data, useful in global training/evaluation case | - | +| `data.subsample` | (float) 1.0 |  Only used in LEAF datasets, subsample clients from all clients | - | +| `data.splits` | (list) [0.8, 0.1, 0.1] | Train, valid, test splits | - | +| `data.`
`consistent_label_distribution` | (bool) False | Make label distribution of train/val/test set over clients keep consistent during splitting | - | +| `data.cSBM_phi` | (list) [0.5, 0.5, 0.5] | Phi for cSBM graph dataset | - | +| `data.loader` | (string) '' | Graph sample name, used in minibatch trainer | 'graphsaint-rw': use `GraphSAINTRandomWalkSampler` as DataLoader; 'neighbor': use `NeighborSampler` as DataLoader. | +| `data.num_workers` | (int) 0 | num_workers in DataLoader | - | +| `data.graphsaint.walk_length` | (int) 2 | The length of each random walk in graphsaint. | - | +| `data.graphsaint.num_steps` | (int) 30 | The number of iterations per epoch in graphsaint. | - | +| `data.quadratic.dim` | (int) 1 | Dim of synthetic quadratic  dataset | - | +| `data.quadratic.min_curv` | (float) 0.02 | Min_curve of synthetic quadratic dataset | - | +| `data.quadratic.max_curv` | (float) 12.5 | Max_cur of synthetic quadratic dataset | - | ### Model diff --git a/federatedscope/core/monitors/monitor.py b/federatedscope/core/monitors/monitor.py index 4c85ea804..8862b855a 100644 --- a/federatedscope/core/monitors/monitor.py +++ b/federatedscope/core/monitors/monitor.py @@ -250,7 +250,8 @@ def finish_fed_runner(self, fl_mode=None): "cfg.wandb.use=True but not install the wandb package") exit() - from federatedscope.core.auxiliaries.logging import logfile_2_wandb_dict + from federatedscope.core.auxiliaries.logging import \ + logfile_2_wandb_dict with open(os.path.join(self.outdir, "eval_results.log"), "r") as exp_log_f: # track the prediction related performance diff --git a/federatedscope/core/splitters/__init__.py b/federatedscope/core/splitters/__init__.py index f8e91f237..d7af0c191 100644 --- a/federatedscope/core/splitters/__init__.py +++ b/federatedscope/core/splitters/__init__.py @@ -1,3 +1,3 @@ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division +from federatedscope.core.splitters.base_splitter import BaseSplitter + +__all__ = ['BaseSplitter'] diff --git a/federatedscope/core/splitters/base_splitter.py b/federatedscope/core/splitters/base_splitter.py new file mode 100644 index 000000000..1e80bb986 --- /dev/null +++ b/federatedscope/core/splitters/base_splitter.py @@ -0,0 +1,28 @@ +import abc +import inspect + + +class BaseSplitter(abc.ABC): + def __init__(self, client_num): + """ + This is an abstract base class for all splitter. + + Args: + client_num: Divide the dataset into `client_num` pieces. + """ + self.client_num = client_num + + @abc.abstractmethod + def __call__(self, dataset, *args, **kwargs): + raise NotImplementedError + + def __repr__(self): + """ + + Returns: Meta information for `Splitter`. + + """ + sign = inspect.signature(self.__init__).parameters.values() + meta_info = tuple([(val.name, getattr(self, val.name)) + for val in sign]) + return f'{self.__class__.__name__}{meta_info}' diff --git a/federatedscope/core/splitters/generic/__init__.py b/federatedscope/core/splitters/generic/__init__.py index 2a6d96f58..5f994563d 100644 --- a/federatedscope/core/splitters/generic/__init__.py +++ b/federatedscope/core/splitters/generic/__init__.py @@ -1,7 +1,3 @@ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division - from federatedscope.core.splitters.generic.lda_splitter import LDASplitter __all__ = ['LDASplitter'] diff --git a/federatedscope/core/splitters/generic/lda_splitter.py b/federatedscope/core/splitters/generic/lda_splitter.py index 436b4339f..c8d4789c7 100644 --- a/federatedscope/core/splitters/generic/lda_splitter.py +++ b/federatedscope/core/splitters/generic/lda_splitter.py @@ -1,12 +1,13 @@ import numpy as np +from federatedscope.core.splitters import BaseSplitter from federatedscope.core.splitters.utils import \ dirichlet_distribution_noniid_slice -class LDASplitter(object): +class LDASplitter(BaseSplitter): def __init__(self, client_num, alpha=0.5): - self.client_num = client_num self.alpha = alpha + super(LDASplitter, self).__init__(client_num) def __call__(self, dataset, prior=None): dataset = [ds for ds in dataset] @@ -17,7 +18,3 @@ def __call__(self, dataset, prior=None): prior=prior) data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice] return data_list - - def __repr__(self): - return f'{self.__class__.__name__}(client_num={self.client_num}, ' \ - f'alpha={self.alpha})' diff --git a/federatedscope/core/splitters/graph/__init__.py b/federatedscope/core/splitters/graph/__init__.py index 81cb2f404..0ba8c6e2c 100644 --- a/federatedscope/core/splitters/graph/__init__.py +++ b/federatedscope/core/splitters/graph/__init__.py @@ -1,18 +1,10 @@ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division - from federatedscope.core.splitters.graph.louvain_splitter import \ LouvainSplitter from federatedscope.core.splitters.graph.random_splitter import RandomSplitter - from federatedscope.core.splitters.graph.reltype_splitter import \ RelTypeSplitter - from federatedscope.core.splitters.graph.scaffold_splitter import \ ScaffoldSplitter -from federatedscope.core.splitters.graph.graphtype_splitter import \ - GraphTypeSplitter from federatedscope.core.splitters.graph.randchunk_splitter import \ RandChunkSplitter @@ -22,5 +14,5 @@ __all__ = [ 'LouvainSplitter', 'RandomSplitter', 'RelTypeSplitter', 'ScaffoldSplitter', - 'GraphTypeSplitter', 'RandChunkSplitter', 'Analyzer', 'ScaffoldLdaSplitter' + 'RandChunkSplitter', 'Analyzer', 'ScaffoldLdaSplitter' ] diff --git a/federatedscope/core/splitters/graph/graphtype_splitter.py b/federatedscope/core/splitters/graph/graphtype_splitter.py deleted file mode 100644 index 86be7e8e7..000000000 --- a/federatedscope/core/splitters/graph/graphtype_splitter.py +++ /dev/null @@ -1,29 +0,0 @@ -import numpy as np -from federatedscope.core.splitters.utils import \ - dirichlet_distribution_noniid_slice - - -class GraphTypeSplitter: - def __init__(self, client_num, alpha=0.5): - self.client_num = client_num - self.alpha = alpha - - def __call__(self, dataset): - r"""Split dataset via dirichlet distribution to generate non-i.i.d - data split. - - Arguments: - dataset (List or PyG.dataset): The datasets. - - Returns: - data_list (List(List(PyG.data))): Splited dataset via dirichlet. - """ - dataset = [ds for ds in dataset] - label = np.array([ds.y.item() for ds in dataset]) - idx_slice = dirichlet_distribution_noniid_slice( - label, self.client_num, self.alpha) - data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice] - return data_list - - def __repr__(self): - return f'{self.__class__.__name__}()' diff --git a/federatedscope/core/splitters/graph/louvain_splitter.py b/federatedscope/core/splitters/graph/louvain_splitter.py index b25b5dec6..9ed0e5f4c 100644 --- a/federatedscope/core/splitters/graph/louvain_splitter.py +++ b/federatedscope/core/splitters/graph/louvain_splitter.py @@ -6,8 +6,10 @@ import networkx as nx import community as community_louvain +from federatedscope.core.splitters import BaseSplitter -class LouvainSplitter(BaseTransform): + +class LouvainSplitter(BaseTransform, BaseSplitter): r""" Split Data into small data via louvain algorithm. @@ -17,11 +19,10 @@ class LouvainSplitter(BaseTransform): """ def __init__(self, client_num, delta=20): - self.client_num = client_num self.delta = delta + BaseSplitter.__init__(self, client_num) def __call__(self, data): - data.index_orig = torch.arange(data.num_nodes) G = to_networkx( data, @@ -71,6 +72,3 @@ def __call__(self, data): graphs.append(from_networkx(nx.subgraph(G, nodes))) return graphs - - def __repr__(self): - return f'{self.__class__.__name__}({self.client_num})' diff --git a/federatedscope/core/splitters/graph/randchunk_splitter.py b/federatedscope/core/splitters/graph/randchunk_splitter.py index 79db41adc..6a8f5ac66 100644 --- a/federatedscope/core/splitters/graph/randchunk_splitter.py +++ b/federatedscope/core/splitters/graph/randchunk_splitter.py @@ -1,9 +1,11 @@ import numpy as np +from torch_geometric.transforms import BaseTransform -class RandChunkSplitter: + +class RandChunkSplitter(BaseTransform): def __init__(self, client_num): - self.client_num = client_num + super(RandChunkSplitter, self).__init__(client_num) def __call__(self, dataset): r"""Split dataset via random chunk. @@ -31,6 +33,3 @@ def __call__(self, dataset): data_list[client_idx].append(graph) return data_list - - def __repr__(self): - return f'{self.__class__.__name__}()' diff --git a/federatedscope/core/splitters/graph/random_splitter.py b/federatedscope/core/splitters/graph/random_splitter.py index dca21b680..10faebcb8 100644 --- a/federatedscope/core/splitters/graph/random_splitter.py +++ b/federatedscope/core/splitters/graph/random_splitter.py @@ -6,10 +6,12 @@ import numpy as np import networkx as nx +from federatedscope.core.splitters import BaseSplitter + EPSILON = 1e-5 -class RandomSplitter(BaseTransform): +class RandomSplitter(BaseTransform, BaseSplitter): r""" Split Data into small data via random sampling. @@ -28,9 +30,8 @@ def __init__(self, sampling_rate=None, overlapping_rate=0, drop_edge=0): - + BaseSplitter.__init__(self, client_num) self.ovlap = overlapping_rate - if sampling_rate is not None: self.sampling_rate = np.array( [float(val) for val in sampling_rate.split(',')]) @@ -49,11 +50,9 @@ def __init__(self, f'The sum of sampling_rate:{self.sampling_rate} and ' f'overlapping_rate({self.ovlap}) should be 1.') - self.client_num = client_num self.drop_edge = drop_edge def __call__(self, data): - data.index_orig = torch.arange(data.num_nodes) G = to_networkx( data, @@ -104,6 +103,3 @@ def __call__(self, data): graphs.append(from_networkx(sub_g)) return graphs - - def __repr__(self): - return f'{self.__class__.__name__}({self.client_num})' diff --git a/federatedscope/core/splitters/graph/reltype_splitter.py b/federatedscope/core/splitters/graph/reltype_splitter.py index 50fd1325f..abf0e011c 100644 --- a/federatedscope/core/splitters/graph/reltype_splitter.py +++ b/federatedscope/core/splitters/graph/reltype_splitter.py @@ -6,9 +6,10 @@ from federatedscope.core.splitters.utils import \ dirichlet_distribution_noniid_slice +from federatedscope.core.splitters import BaseSplitter -class RelTypeSplitter(BaseTransform): +class RelTypeSplitter(BaseTransform, BaseSplitter): r""" Split Data into small data via dirichlet distribution to generate non-i.i.d data split. @@ -19,7 +20,7 @@ class RelTypeSplitter(BaseTransform): """ def __init__(self, client_num, alpha=0.5, realloc_mask=False): - self.client_num = client_num + BaseSplitter.__init__(self, client_num) self.alpha = alpha self.realloc_mask = realloc_mask @@ -62,6 +63,3 @@ def __call__(self, data): data_list.append(sub_g) return data_list - - def __repr__(self): - return f'{self.__class__.__name__}({self.client_num})' diff --git a/federatedscope/core/splitters/graph/scaffold_lda_splitter.py b/federatedscope/core/splitters/graph/scaffold_lda_splitter.py index 3c6db2013..08c520d31 100644 --- a/federatedscope/core/splitters/graph/scaffold_lda_splitter.py +++ b/federatedscope/core/splitters/graph/scaffold_lda_splitter.py @@ -4,11 +4,11 @@ from rdkit import Chem from rdkit import RDLogger -from rdkit.Chem.Scaffolds import MurckoScaffold from federatedscope.core.splitters.utils import \ dirichlet_distribution_noniid_slice from federatedscope.core.splitters.graph.scaffold_splitter import \ generate_scaffold +from federatedscope.core.splitters import BaseSplitter logger = logging.getLogger(__name__) @@ -150,7 +150,7 @@ def gen_scaffold_lda_split(dataset, client_num=5, alpha=0.1): return idx_slice -class ScaffoldLdaSplitter: +class ScaffoldLdaSplitter(BaseSplitter): r"""First adopt scaffold splitting and then assign the samples to clients according to Latent Dirichlet Allocation. @@ -164,7 +164,7 @@ class ScaffoldLdaSplitter: """ def __init__(self, client_num, alpha): - self.client_num = client_num + super(ScaffoldLdaSplitter, self).__init__(client_num) self.alpha = alpha def __call__(self, dataset): @@ -178,6 +178,3 @@ def __call__(self, dataset): self.alpha) data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice] return data_list - - def __repr__(self): - return f'{self.__class__.__name__}()' diff --git a/federatedscope/core/splitters/graph/scaffold_splitter.py b/federatedscope/core/splitters/graph/scaffold_splitter.py index 7afa2c81d..77dfc60df 100644 --- a/federatedscope/core/splitters/graph/scaffold_splitter.py +++ b/federatedscope/core/splitters/graph/scaffold_splitter.py @@ -5,6 +5,8 @@ from rdkit import RDLogger from rdkit.Chem.Scaffolds import MurckoScaffold +from federatedscope.core.splitters import BaseSplitter + logger = logging.getLogger(__name__) RDLogger.DisableLog('rdApp.*') @@ -47,9 +49,9 @@ def gen_scaffold_split(dataset, client_num=5): return [splits[ID] for ID in range(client_num)] -class ScaffoldSplitter: +class ScaffoldSplitter(BaseSplitter): def __init__(self, client_num): - self.client_num = client_num + super(ScaffoldSplitter, self).__init__(client_num) def __call__(self, dataset): r"""Split dataset with smiles string into scaffold split @@ -65,6 +67,3 @@ def __call__(self, dataset): idx_slice = gen_scaffold_split(dataset) data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice] return data_list - - def __repr__(self): - return f'{self.__class__.__name__}()' diff --git a/federatedscope/gfl/baseline/repro_exp/graph_level/run_graph_level.sh b/federatedscope/gfl/baseline/repro_exp/graph_level/run_graph_level.sh index 87b3f7b73..29fe3b864 100755 --- a/federatedscope/gfl/baseline/repro_exp/graph_level/run_graph_level.sh +++ b/federatedscope/gfl/baseline/repro_exp/graph_level/run_graph_level.sh @@ -20,7 +20,7 @@ elif [[ $dataset = 'proteins' ]]; then elif [[ $dataset = 'imdb-binary' ]]; then out_channels=2 hidden=64 - splitter='graph_type' + splitter='lda' else out_channels=4 hidden=1024 diff --git a/federatedscope/gfl/baseline/repro_exp/graph_level/run_graph_level_opt.sh b/federatedscope/gfl/baseline/repro_exp/graph_level/run_graph_level_opt.sh index b42480e18..09bc2a0a6 100644 --- a/federatedscope/gfl/baseline/repro_exp/graph_level/run_graph_level_opt.sh +++ b/federatedscope/gfl/baseline/repro_exp/graph_level/run_graph_level_opt.sh @@ -23,7 +23,7 @@ elif [[ $dataset = 'proteins' ]]; then elif [[ $dataset = 'imdb-binary' ]]; then out_channels=2 hidden=64 - splitter='graph_type' + splitter='lda' else out_channels=4 hidden=1024 diff --git a/federatedscope/gfl/baseline/repro_exp/graph_level/run_graph_level_prox.sh b/federatedscope/gfl/baseline/repro_exp/graph_level/run_graph_level_prox.sh index 01d2ded26..4a17bea42 100644 --- a/federatedscope/gfl/baseline/repro_exp/graph_level/run_graph_level_prox.sh +++ b/federatedscope/gfl/baseline/repro_exp/graph_level/run_graph_level_prox.sh @@ -23,7 +23,7 @@ elif [[ $dataset = 'proteins' ]]; then elif [[ $dataset = 'imdb-binary' ]]; then out_channels=2 hidden=64 - splitter='graph_type' + splitter='lda' else out_channels=4 hidden=1024 diff --git a/federatedscope/nlp/loss/__init__.py b/federatedscope/nlp/loss/__init__.py index 1eb5c1b4e..1a871bbf5 100644 --- a/federatedscope/nlp/loss/__init__.py +++ b/federatedscope/nlp/loss/__init__.py @@ -1,5 +1 @@ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division - from federatedscope.nlp.loss.character_loss import * diff --git a/federatedscope/nlp/model/__init__.py b/federatedscope/nlp/model/__init__.py index d8e21f338..941335213 100644 --- a/federatedscope/nlp/model/__init__.py +++ b/federatedscope/nlp/model/__init__.py @@ -1,7 +1,3 @@ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division - from federatedscope.nlp.model.rnn import LSTM from federatedscope.nlp.model.model_builder import get_rnn, get_transformer diff --git a/federatedscope/register.py b/federatedscope/register.py index 88fd68eb8..c5ca8285f 100644 --- a/federatedscope/register.py +++ b/federatedscope/register.py @@ -70,6 +70,13 @@ def register_auxiliary_data_loader_PIA(key, module): register(key, module, auxiliary_data_loader_PIA_dict) +transform_dict = {} + + +def register_transform(key, module): + register(key, module, transform_dict) + + splitter_dict = {} @@ -77,8 +84,22 @@ def register_splitter(key, module): register(key, module, splitter_dict) -transform_dict = {} +scheduler_dict = {} -def register_transform(key, module): - register(key, module, transform_dict) +def register_scheduler(key, module): + register(key, module, scheduler_dict) + + +optimizer_dict = {} + + +def register_optimizer(key, module): + register(key, module, optimizer_dict) + + +worker_dict = {} + + +def register_worker(key, module): + register(key, module, worker_dict)