Skip to content

Commit

Permalink
Coding style (NVIDIA#171)
Browse files Browse the repository at this point in the history
* Coding style check (flake8/isort/black/pydocstyle).

* WIP

* Done with app_common, fuel/[common,sec], lighter, security

* Improve docstring on BaseContext
  • Loading branch information
IsaacYangSLA authored Feb 3, 2022
1 parent 889e8c7 commit 8e5a836
Show file tree
Hide file tree
Showing 58 changed files with 415 additions and 375 deletions.
6 changes: 3 additions & 3 deletions nvflare/apis/impl/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def cancel_task(
Args:
task (Task): the task to be cancelled
completion_status ([type], optional): the completion status for this cancellation. Defaults to TaskCompletionStatus.CANCELLED.
completion_status (str, optional): the completion status for this cancellation. Defaults to TaskCompletionStatus.CANCELLED.
fl_ctx (Optional[FLContext], optional): FLContext associated with this cancellation. Defaults to None.
"""
task.completion_status = completion_status
Expand All @@ -600,7 +600,7 @@ def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_
"""Cancel all standing tasks in this controller.
Args:
completion_status ([type], optional): the completion status for this cancellation. Defaults to TaskCompletionStatus.CANCELLED.
completion_status (str, optional): the completion status for this cancellation. Defaults to TaskCompletionStatus.CANCELLED.
fl_ctx (Optional[FLContext], optional): FLContext associated with this cancellation. Defaults to None.
"""
with self._task_lock:
Expand All @@ -611,7 +611,7 @@ def abort_task(self, task, fl_ctx: FLContext):
"""Ask all clients to abort the execution of the specified task.
Args:
task ([type]): the task to be aborted
task (str): the task to be aborted
fl_ctx (FLContext): FLContext associated with this action
"""
self.log_info(fl_ctx, "asked all clients to abort task {}".format(task.name))
Expand Down
6 changes: 2 additions & 4 deletions nvflare/app_common/abstract/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
class Aggregator(FLComponent, ABC):
@abstractmethod
def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:
"""
accept the shareable submitted by the client.
"""Accept the shareable submitted by the client.
Args:
shareable: submitted Shareable object
Expand All @@ -37,8 +36,7 @@ def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:

@abstractmethod
def aggregate(self, fl_ctx: FLContext) -> Shareable:
"""
perform the aggregation for all the received Shareable from the clients.
"""Perform the aggregation for all the received Shareable from the clients.
Args:
fl_ctx: FLContext
Expand Down
5 changes: 1 addition & 4 deletions nvflare/app_common/abstract/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,9 @@


class Formatter(FLComponent):
def __init__(self) -> None:
super(Formatter, self).__init__()

@abstractmethod
def format(self, fl_ctx: FLContext) -> str:
"""Format the data into human readable string for
"""Format the data into human readable string.
Args:
fl_ctx (FLContext): FL Context object.
Expand Down
9 changes: 2 additions & 7 deletions nvflare/app_common/abstract/learnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,8 @@


class Learnable(dict):
def __init__(self) -> None:
super().__init__()

def to_bytes(self) -> bytes:
"""
method to serialize the Learnable object into bytes.
"""Method to serialize the Learnable object into bytes.
Returns:
object serialized in bytes.
Expand All @@ -33,8 +29,7 @@ def to_bytes(self) -> bytes:

@classmethod
def from_bytes(cls, data: bytes):
"""
method to convert the object bytes into Learnable object.
"""Method to convert the object bytes into Learnable object.
Args:
data: a bytes object
Expand Down
6 changes: 2 additions & 4 deletions nvflare/app_common/abstract/learnable_persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
class LearnablePersistor(FLComponent, ABC):
@abstractmethod
def load(self, fl_ctx: FLContext) -> Learnable:
"""
load the Learnable object.
"""Load the Learnable object.
Args:
fl_ctx: FLContext
Expand All @@ -36,8 +35,7 @@ def load(self, fl_ctx: FLContext) -> Learnable:

@abstractmethod
def save(self, learnable: Learnable, fl_ctx: FLContext):
"""
persist the Learnable object
"""Persist the Learnable object.
Args:
learnable: the Learnable object to be saved
Expand Down
28 changes: 9 additions & 19 deletions nvflare/app_common/abstract/learner_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,18 @@

class Learner(FLComponent):
def initialize(self, parts: dict, fl_ctx: FLContext):
"""
Initialize the Learner object. This is called before the Learner can train or validate.
"""Initialize the Learner object. This is called before the Learner can train or validate.
This is called only once.
Args:
parts: components to be used by the Trainer
fl_ctx: FLContext of the running environment
Returns:
"""
pass

def train(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
"""
Called to perform training. Can be called many times during the lifetime of the Learner.
"""Called to perform training. Can be called many times during the lifetime of the Learner.
Args:
data: the training input data (e.g. model weights)
Expand All @@ -48,8 +44,7 @@ def train(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Sha
return make_reply(ReturnCode.TASK_UNSUPPORTED)

def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Shareable:
"""
Called to return the trained model from the Learner.
"""Called to return the trained model from the Learner.
Args:
model_name: type of the model for validation
Expand All @@ -61,8 +56,7 @@ def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Sharea
return make_reply(ReturnCode.TASK_UNSUPPORTED)

def validate(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
"""
Called to perform validation. Can be called many times during the lifetime of the Learner.
"""Called to perform validation. Can be called many times during the lifetime of the Learner.
Args:
data: the training input data (e.g. model weights)
Expand All @@ -75,28 +69,24 @@ def validate(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) ->
return make_reply(ReturnCode.TASK_UNSUPPORTED)

def abort(self, fl_ctx: FLContext):
"""
Called (from another thread) to abort the current task (validate or train)
"""Called (from another thread) to abort the current task (validate or train).
Note: this is to abort the current task only, not the Trainer. After aborting, the Learner.
may still be called to perform another task.
Args:
fl_ctx: FLContext of the running environment
Returns:
"""
pass

def finalize(self, fl_ctx: FLContext):
"""
Called to finalize the Learner (close/release resources gracefully).
"""Called to finalize the Learner (close/release resources gracefully).
After this call, the Learner will be destroyed.
Args:
fl_ctx: FLContext of the running environment
Returns:
"""
pass
19 changes: 6 additions & 13 deletions nvflare/app_common/abstract/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
"""The Learnable in the deep learning domain is usually called Model by researchers.
The Learnable in the deep learning domain is usually called Model by researchers.
This import simply lets you call the Learnable 'Model'.
Model Learnable is a dict that contains two items: weights and meta info
"""
from nvflare.apis.dxo import DXO, DataKind

from .learnable import Learnable as ModelLearnable


class ModelLearnableKey(object):

"""
Model Learnable is a dict that contains two items: weights and meta info
"""

WEIGHTS = "weights"
META = "meta"


def validate_model_learnable(model_learnable: ModelLearnable) -> str:
"""
Check whether the specified model is a valid Model Shareable
"""Check whether the specified model is a valid Model Shareable.
Args:
model: model to be validated
Returns: error text
model_learnable (ModelLearnable): model to be validated
Returns:
str: error text or empty string if no error
"""
if not isinstance(model_learnable, ModelLearnable):
return "invalid model learnable: expect Model type but got {}".format(type(model_learnable))
Expand Down
7 changes: 2 additions & 5 deletions nvflare/app_common/abstract/model_locator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@


class ModelLocator(FLComponent):
def __init__(self):
super(ModelLocator, self).__init__()

def get_model_names(self, fl_ctx: FLContext) -> List[str]:
"""List the name of the models
"""List the name of the models.
Args:
fl_ctx (FLContext): FL Context object
Expand All @@ -35,7 +32,7 @@ def get_model_names(self, fl_ctx: FLContext) -> List[str]:
pass

def locate_model(self, model_name, fl_ctx: FLContext) -> DXO:
"""Locate a single model by it's name
"""Locate a single model by it's name.
Args:
model_name (str): Name of the model.
Expand Down
19 changes: 4 additions & 15 deletions nvflare/app_common/abstract/model_persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def save(self, learnable: ModelLearnable, fl_ctx: FLContext):

@abstractmethod
def load_model(self, fl_ctx: FLContext) -> ModelLearnable:
"""
initialize and load the model.
"""Initialize and load the model.
Args:
fl_ctx: FLContext
Expand All @@ -44,8 +43,7 @@ def load_model(self, fl_ctx: FLContext) -> ModelLearnable:

@abstractmethod
def save_model(self, model: ModelLearnable, fl_ctx: FLContext):
"""
persist the model object
"""Persist the model object.
Args:
model: Model object to be saved
Expand All @@ -55,8 +53,8 @@ def save_model(self, model: ModelLearnable, fl_ctx: FLContext):
pass

def get_model_inventory(self, fl_ctx: FLContext) -> {str: ModelDescriptor}:
"""
Get the model inventory of the ModelPersister
"""Get the model inventory of the ModelPersister.
Args:
fl_ctx: FLContext
Expand All @@ -66,13 +64,4 @@ def get_model_inventory(self, fl_ctx: FLContext) -> {str: ModelDescriptor}:
pass

def get_model(self, model_file, fl_ctx: FLContext) -> object:
"""
Args:
model_file:
fl_ctx:
Returns:
"""
pass
14 changes: 6 additions & 8 deletions nvflare/app_common/abstract/model_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,29 @@
class ModelProcessor(ABC):
@abstractmethod
def extract_model(self, network, multi_processes: bool, model_vars: dict, fl_ctx: FLContext) -> dict:
"""
Call to extract the current model from the training network
"""Call to extract the current model from the training network.
Args:
network: training network
multi_processes: boolean to indicates if it's a multi-processes
model_vars: global model dict
fl_ctx: FLContext
Returns:
a dictionary representing the model
"""
pass

@abstractmethod
def apply_model(self, network, multi_processes: bool, model_params: dict, fl_ctx: FLContext, options=None):
"""
Call to apply the model parameters to the training network
"""Call to apply the model parameters to the training network.
Args:
network: training network
multi_processes: boolean to indicates if it's a multi-processes
model_params: model parameters to apply
fl_ctx: FLContext
options:
Returns:
options: optional information that can be used for this process
"""
pass
6 changes: 2 additions & 4 deletions nvflare/app_common/abstract/shareable_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
class ShareableGenerator(FLComponent, ABC):
@abstractmethod
def learnable_to_shareable(self, model: Learnable, fl_ctx: FLContext) -> Shareable:
"""
generate the initial Shareable from the Learnable object.
"""Generate the initial Shareable from the Learnable object.
Args:
model: model object
Expand All @@ -38,8 +37,7 @@ def learnable_to_shareable(self, model: Learnable, fl_ctx: FLContext) -> Shareab

@abstractmethod
def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> Learnable:
"""
construct the Learnable object from Shareable
"""Construct the Learnable object from Shareable.
Args:
shareable: shareable
Expand Down
Loading

0 comments on commit 8e5a836

Please sign in to comment.