-
Notifications
You must be signed in to change notification settings - Fork 219
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add more registers & refactor splitter (#372)
* add more registerable components * update splitter interface * remove graph_type splitter and update __repr__ for base splitter * add README for register mechanism Co-authored-by: Jones Wong <joneswong@users.noreply.github.com>
- Loading branch information
1 parent
ccafa14
commit 6095dd9
Showing
37 changed files
with
348 additions
and
152 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Register | ||
|
||
In addition to the rich collection of datasets, models, evaluation metrics, etc., FederatedScope (FS) also allows users to create their own ingredients or introduce more customized modules to FS. Inspired by GraphGym, we provide `register` mechanism to help integrating your own components into the FS-based federated learning workflow, including: | ||
|
||
* Configurations [`federatedscope/contrib/config`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/config) | ||
* Data [`federatedscope/contrib/data`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/data) | ||
* Loss [`federatedscope/contrib/loss`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/loss) | ||
* Metrics [`federatedscope/contrib/metrics`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/metrics) | ||
* Model [`federatedscope/contrib/model`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/model) | ||
* Optimizer [`federatedscope/contrib/optimizer`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/optimizer) | ||
* Scheduler [`federatedscope/contrib/scheduler`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/scheduler) | ||
* Splitter [`federatedscope/contrib/splitter`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/splitter) | ||
* Trainer [`federatedscope/contrib/trainer`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/trainer) | ||
* Worker [`federatedscope/contrib/worker`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/worker) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +0,0 @@ | ||
from __future__ import absolute_import | ||
from __future__ import print_function | ||
from __future__ import division | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from os.path import dirname, basename, isfile, join | ||
import glob | ||
|
||
modules = glob.glob(join(dirname(__file__), "*.py")) | ||
__all__ = [ | ||
basename(f)[:-3] for f in modules | ||
if isfile(f) and not f.endswith('__init__.py') | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from federatedscope.register import register_criterion | ||
|
||
|
||
def call_my_criterion(type, device): | ||
try: | ||
import torch.nn as nn | ||
except ImportError: | ||
nn = None | ||
criterion = None | ||
|
||
if type == 'mycriterion': | ||
if nn is not None: | ||
criterion = nn.CrossEntropyLoss().to(device) | ||
return criterion | ||
|
||
|
||
register_criterion('mycriterion', call_my_criterion) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from os.path import dirname, basename, isfile, join | ||
import glob | ||
|
||
modules = glob.glob(join(dirname(__file__), "*.py")) | ||
__all__ = [ | ||
basename(f)[:-3] for f in modules | ||
if isfile(f) and not f.endswith('__init__.py') | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from federatedscope.register import register_optimizer | ||
|
||
|
||
def call_my_optimizer(model, type, lr, **kwargs): | ||
try: | ||
import torch.optim as optim | ||
except ImportError: | ||
optim = None | ||
optimizer = None | ||
|
||
if type == 'myoptimizer': | ||
if optim is not None: | ||
optimizer = optim.Adam(model.parameters(), lr=lr, **kwargs) | ||
return optimizer | ||
|
||
|
||
register_optimizer('myoptimizer', call_my_optimizer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from os.path import dirname, basename, isfile, join | ||
import glob | ||
|
||
modules = glob.glob(join(dirname(__file__), "*.py")) | ||
__all__ = [ | ||
basename(f)[:-3] for f in modules | ||
if isfile(f) and not f.endswith('__init__.py') | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from federatedscope.register import register_scheduler | ||
|
||
|
||
def call_my_scheduler(optimizer, type): | ||
try: | ||
import torch.optim as optim | ||
except ImportError: | ||
optim = None | ||
scheduler = None | ||
|
||
if type == 'myscheduler': | ||
if optim is not None: | ||
lr_lambda = [lambda epoch: epoch // 30] | ||
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) | ||
return scheduler | ||
|
||
|
||
register_scheduler('myscheduler', call_my_scheduler) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from os.path import dirname, basename, isfile, join | ||
import glob | ||
|
||
modules = glob.glob(join(dirname(__file__), "*.py")) | ||
__all__ = [ | ||
basename(f)[:-3] for f in modules | ||
if isfile(f) and not f.endswith('__init__.py') | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from federatedscope.register import register_splitter | ||
from federatedscope.core.splitters import BaseSplitter | ||
|
||
|
||
class MySplitter(BaseSplitter): | ||
def __init__(self, client_num, **kwargs): | ||
super(MySplitter, self).__init__(client_num, **kwargs) | ||
|
||
def __call__(self, dataset, *args, **kwargs): | ||
# Dummy splitter, only for demonstration | ||
per_samples = len(dataset) // self.client_num | ||
data_list, cur_index = [], 0 | ||
for i in range(self.client_num): | ||
data_list.append( | ||
[x for x in range(cur_index, cur_index + per_samples)]) | ||
cur_index += per_samples | ||
return data_list | ||
|
||
|
||
def call_my_splitter(client_num, **kwargs): | ||
if type == 'mysplitter': | ||
splitter = MySplitter(client_num, **kwargs) | ||
return splitter | ||
|
||
|
||
register_splitter('mysplitter', call_my_splitter) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from os.path import dirname, basename, isfile, join | ||
import glob | ||
|
||
modules = glob.glob(join(dirname(__file__), "*.py")) | ||
__all__ = [ | ||
basename(f)[:-3] for f in modules | ||
if isfile(f) and not f.endswith('__init__.py') | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from federatedscope.register import register_worker | ||
from federatedscope.core.workers import Server, Client | ||
|
||
|
||
# Build your worker here. | ||
class MyClient(Client): | ||
pass | ||
|
||
|
||
class MyServer(Server): | ||
pass | ||
|
||
|
||
def call_my_worker(method): | ||
if method == 'mymethod': | ||
worker_builder = {'client': MyClient, 'server': MyServer} | ||
return worker_builder | ||
|
||
|
||
register_worker('mymethod', call_my_worker) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.