[Ehn] Enhance config module #899

Merged
2 changes: 1 addition & 1 deletion ppsci/arch/__init__.py
@@ -81,7 +81,7 @@ def build_model(cfg):
"""Build model

Args:
cfg (AttrDict): Arch config.
cfg (DictConfig): Arch config.

Returns:
nn.Layer: Model.
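Note: the docstring updates in this file (and the analogous ones in the files below) reflect that the builder helpers now receive an `omegaconf.DictConfig`, the object Hydra produces, instead of the removed `AttrDict`. A minimal sketch of that type in use, with hypothetical keys:

```python
from omegaconf import DictConfig, OmegaConf

# DictConfig offers the same attribute-style access that the removed AttrDict
# provided, and it is exactly what Hydra hands to the entry function.
cfg: DictConfig = OmegaConf.create({"arch": {"num_layers": 3, "hidden_size": 32}})
print(cfg.arch.hidden_size)          # attribute access, like the old AttrDict: 32
print(OmegaConf.to_container(cfg))   # plain dict when a builder needs one
```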
2 changes: 1 addition & 1 deletion ppsci/constraint/__init__.py
@@ -42,7 +42,7 @@ def build_constraint(cfg, equation_dict, geom_dict):
"""Build constraint(s).

Args:
cfg (List[AttrDict]): Constraint config list.
cfg (List[DictConfig]): Constraint config list.
equation_dict (Dict[str, Equation]): Equation(s) in dict.
geom_dict (Dict[str, Geometry]): Geometry(ies) in dict.

2 changes: 1 addition & 1 deletion ppsci/data/dataset/__init__.py
@@ -78,7 +78,7 @@ def build_dataset(cfg) -> "io.Dataset":
"""Build dataset

Args:
cfg (List[AttrDict]): dataset config list.
cfg (List[DictConfig]): dataset config list.

Returns:
Dict[str, io.Dataset]: dataset.
2 changes: 1 addition & 1 deletion ppsci/equation/__init__.py
@@ -54,7 +54,7 @@ def build_equation(cfg):
"""Build equation(s)

Args:
cfg (List[AttrDict]): Equation(s) config list.
cfg (List[DictConfig]): Equation(s) config list.

Returns:
Dict[str, Equation]: Equation(s) in dict.
2 changes: 1 addition & 1 deletion ppsci/geometry/__init__.py
@@ -54,7 +54,7 @@ def build_geometry(cfg):
"""Build geometry(ies)

Args:
cfg (List[AttrDict]): Geometry config list.
cfg (List[DictConfig]): Geometry config list.

Returns:
Dict[str, Geometry]: Geometry(ies) in dict.
2 changes: 1 addition & 1 deletion ppsci/loss/__init__.py
@@ -53,7 +53,7 @@ def build_loss(cfg):
"""Build loss.

Args:
cfg (AttrDict): Loss config.
cfg (DictConfig): Loss config.
Returns:
Loss: Callable loss object.
"""
2 changes: 1 addition & 1 deletion ppsci/loss/mtl/__init__.py
@@ -35,7 +35,7 @@ def build_mtl_aggregator(cfg):
"""Build loss aggregator with multi-task learning method.
Args:
cfg (AttrDict): Aggregator config.
cfg (DictConfig): Aggregator config.
Returns:
Loss: Callable loss aggregator object.
"""
2 changes: 1 addition & 1 deletion ppsci/metric/__init__.py
@@ -43,7 +43,7 @@ def build_metric(cfg):
"""Build metric.

Args:
cfg (List[AttrDict]): List of metric config.
cfg (List[DictConfig]): List of metric config.

Returns:
Dict[str, Metric]: Dict of callable metric object.
4 changes: 2 additions & 2 deletions ppsci/optimizer/__init__.py
@@ -39,7 +39,7 @@ def build_lr_scheduler(cfg, epochs, iters_per_epoch):
"""Build learning rate scheduler.

Args:
cfg (AttrDict): Learning rate scheduler config.
cfg (DictConfig): Learning rate scheduler config.
epochs (int): Total epochs.
iters_per_epoch (int): Number of iterations of one epoch.

@@ -57,7 +57,7 @@ def build_optimizer(cfg, model_list, epochs, iters_per_epoch):
"""Build optimizer and learning rate scheduler

Args:
cfg (AttrDict): Learning rate scheduler config.
cfg (DictConfig): Learning rate scheduler config.
model_list (Tuple[nn.Layer, ...]): Tuple of model(s).
epochs (int): Total epochs.
iters_per_epoch (int): Number of iterations of one epoch.
141 changes: 98 additions & 43 deletions ppsci/solver/solver.py
@@ -158,12 +158,18 @@ def __init__(
cfg: Optional[DictConfig] = None,
):
self.cfg = cfg
if isinstance(cfg, DictConfig):
# (Recommended) Params can be passed within cfg
# rather than passed to 'Solver.__init__' one-by-one.
self._parse_params_from_cfg(cfg)

# set model
self.model = model
# set constraint
self.constraint = constraint
# set output directory
self.output_dir = output_dir
if not cfg:
self.output_dir = output_dir

# set optimizer
self.optimizer = optimizer
@@ -192,19 +198,20 @@ def __init__(
)

# set training hyper-parameter
self.epochs = epochs
self.iters_per_epoch = iters_per_epoch
# set update_freq for gradient accumulation
self.update_freq = update_freq
# set checkpoint saving frequency
self.save_freq = save_freq
# set logging frequency
self.log_freq = log_freq

# set evaluation hyper-parameter
self.eval_during_train = eval_during_train
self.start_eval_epoch = start_eval_epoch
self.eval_freq = eval_freq
if not cfg:
self.epochs = epochs
self.iters_per_epoch = iters_per_epoch
# set update_freq for gradient accumulation
self.update_freq = update_freq
# set checkpoint saving frequency
self.save_freq = save_freq
# set logging frequency
self.log_freq = log_freq

# set evaluation hyper-parameter
self.eval_during_train = eval_during_train
self.start_eval_epoch = start_eval_epoch
self.eval_freq = eval_freq

# initialize training log(training loss, time cost, etc.) recorder during one epoch
self.train_output_info: Dict[str, misc.AverageMeter] = {}
@@ -221,46 +228,45 @@ def __init__(
"reader_cost": misc.AverageMeter("reader_cost", ".5f", postfix="s"),
}

# fix seed for reproducibility
self.seed = seed

# set running device
if device != "cpu" and paddle.device.get_device() == "cpu":
if not cfg:
self.device = device
if self.device != "cpu" and paddle.device.get_device() == "cpu":
logger.warning(f"Set device({device}) to 'cpu' for only cpu available.")
device = "cpu"
self.device = paddle.set_device(device)
self.device = "cpu"
self.device = paddle.set_device(self.device)

# set equations for physics-driven or data-physics hybrid driven task, such as PINN
self.equation = equation

# set geometry for generating data
self.geom = {} if geom is None else geom

# set validator
self.validator = validator

# set visualizer
self.visualizer = visualizer

# set automatic mixed precision(AMP) configuration
self.use_amp = use_amp
self.amp_level = amp_level
if not cfg:
self.use_amp = use_amp
self.amp_level = amp_level
self.scaler = amp.GradScaler(True) if self.use_amp else None

# whether calculate metrics by each batch during evaluation, mainly for memory efficiency
self.compute_metric_by_batch = compute_metric_by_batch
if not cfg:
self.compute_metric_by_batch = compute_metric_by_batch
if validator is not None:
for metric in itertools.chain(
*[_v.metric.values() for _v in self.validator.values()]
):
if metric.keep_batch ^ compute_metric_by_batch:
if metric.keep_batch ^ self.compute_metric_by_batch:
raise ValueError(
f"{misc.typename(metric)}.keep_batch should be "
f"{compute_metric_by_batch} when compute_metric_by_batch="
f"{compute_metric_by_batch}."
f"{self.compute_metric_by_batch} when compute_metric_by_batch="
f"{self.compute_metric_by_batch}."
)
# whether set `stop_gradient=True` for every Tensor if no differentiation involved during evaluation
self.eval_with_no_grad = eval_with_no_grad
if not cfg:
self.eval_with_no_grad = eval_with_no_grad

self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
@@ -278,34 +284,37 @@ def __init__(
# set moving average model(optional)
self.ema_model = None
if self.cfg and any(key in self.cfg.TRAIN for key in ["ema", "swa"]):
if "ema" in self.cfg.TRAIN:
self.avg_freq = self.cfg.TRAIN.ema.avg_freq
if "ema" in self.cfg.TRAIN and cfg.TRAIN.ema.get("use_ema", False):
self.ema_model = ema.ExponentialMovingAverage(
self.model, self.cfg.TRAIN.ema.decay
)
elif "swa" in self.cfg.TRAIN:
self.avg_freq = self.cfg.TRAIN.swa.avg_freq
elif "swa" in self.cfg.TRAIN and cfg.TRAIN.swa.get("use_swa", False):
self.ema_model = ema.StochasticWeightAverage(self.model)

# load pretrained model, usually used for transfer learning
self.pretrained_model_path = pretrained_model_path
if pretrained_model_path is not None:
save_load.load_pretrain(self.model, pretrained_model_path, self.equation)
if not cfg:
self.pretrained_model_path = pretrained_model_path
if self.pretrained_model_path is not None:
save_load.load_pretrain(
self.model, self.pretrained_model_path, self.equation
)

# initialize a dict for tracking the best metric during training
self.best_metric = {
"metric": float("inf"),
"epoch": 0,
}
# load model checkpoint, usually used for resume training
if checkpoint_path is not None:
if pretrained_model_path is not None:
if not cfg:
self.checkpoint_path = checkpoint_path
if self.checkpoint_path is not None:
if self.pretrained_model_path is not None:
logger.warning(
"Detected 'pretrained_model_path' is given, weights in which might be"
"overridden by weights loaded from given 'checkpoint_path'."
)
loaded_metric = save_load.load_checkpoint(
checkpoint_path,
self.checkpoint_path,
self.model,
self.optimizer,
self.scaler,
@@ -366,7 +375,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:

# set VisualDL tool
self.vdl_writer = None
if use_vdl:
if not cfg:
self.use_vdl = use_vdl
if self.use_vdl:
with misc.RankZeroOnly(self.rank) as is_master:
if is_master:
self.vdl_writer = vdl.LogWriter(osp.join(output_dir, "vdl"))
@@ -377,7 +388,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:

# set WandB tool
self.wandb_writer = None
if use_wandb:
if not cfg:
self.use_wandb = use_wandb
if self.use_wandb:
try:
import wandb
except ModuleNotFoundError:
@@ -390,7 +403,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:

# set TensorBoardX tool
self.tbd_writer = None
if use_tbd:
if not cfg:
self.use_tbd = use_tbd
if self.use_tbd:
try:
import tensorboardX
except ModuleNotFoundError:
@@ -984,3 +999,43 @@ def plot_loss_history(
smooth_step=smooth_step,
use_semilogy=use_semilogy,
)

def _parse_params_from_cfg(self, cfg: DictConfig):
"""
Parse hyper-parameters from DictConfig.
"""
self.output_dir = cfg.output_dir
self.log_freq = cfg.log_freq
self.use_tbd = cfg.use_tbd
self.use_vdl = cfg.use_vdl
self.wandb_config = cfg.wandb_config
self.use_wandb = cfg.use_wandb
self.device = cfg.device
self.to_static = cfg.to_static

self.use_amp = cfg.use_amp
self.amp_level = cfg.amp_level

self.epochs = cfg.TRAIN.epochs
self.iters_per_epoch = cfg.TRAIN.iters_per_epoch
self.update_freq = cfg.TRAIN.update_freq
self.save_freq = cfg.TRAIN.save_freq
self.eval_during_train = cfg.TRAIN.eval_during_train
self.start_eval_epoch = cfg.TRAIN.start_eval_epoch
self.eval_freq = cfg.TRAIN.eval_freq
self.checkpoint_path = cfg.TRAIN.checkpoint_path

if "ema" in cfg.TRAIN and cfg.TRAIN.ema.get("use_ema", False):
self.avg_freq = cfg.TRAIN.ema.avg_freq
elif "swa" in cfg.TRAIN and cfg.TRAIN.swa.get("use_swa", False):
self.avg_freq = cfg.TRAIN.swa.avg_freq

self.compute_metric_by_batch = cfg.EVAL.compute_metric_by_batch
self.eval_with_no_grad = cfg.EVAL.eval_with_no_grad

if cfg.mode == "train":
self.pretrained_model_path = cfg.TRAIN.pretrained_model_path
elif cfg.mode == "eval":
self.pretrained_model_path = cfg.EVAL.pretrained_model_path
elif cfg.mode in ["export", "infer"]:
self.pretrained_model_path = cfg.INFER.pretrained_model_path
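
Note: with the new `_parse_params_from_cfg`, the recommended path is to pass one `DictConfig` to `Solver` instead of each hyper-parameter separately. A rough sketch of that usage with a hand-built config (values are illustrative; in practice the config normally comes from a Hydra YAML file and is validated by `ppsci.utils.callbacks.InitCallback`):

```python
from omegaconf import OmegaConf

import ppsci

# Only the fields read by _parse_params_from_cfg are shown; values are examples.
cfg = OmegaConf.create(
    {
        "mode": "train",
        "output_dir": "./output_demo",
        "log_freq": 20,
        "device": "gpu",
        "to_static": False,
        "use_amp": False,
        "amp_level": "O1",
        "use_vdl": False,
        "use_tbd": False,
        "use_wandb": False,
        "wandb_config": {},
        "TRAIN": {
            "epochs": 10,
            "iters_per_epoch": 100,
            "update_freq": 1,
            "save_freq": 0,
            "eval_during_train": False,
            "start_eval_epoch": 1,
            "eval_freq": 1,
            "checkpoint_path": None,
            "pretrained_model_path": None,
        },
        "EVAL": {"compute_metric_by_batch": False, "eval_with_no_grad": True},
    }
)

model = ppsci.arch.MLP(("x",), ("u",), 3, 32)  # any ppsci model works here
# constraints, optimizer, validator, etc. are still passed as objects;
# the hyper-parameters above now travel inside cfg.
solver = ppsci.solver.Solver(model, cfg=cfg)
```

Also note that the EMA/SWA branches are now additionally gated on `TRAIN.ema.use_ema` / `TRAIN.swa.use_swa`, so merely having an `ema` or `swa` section no longer enables weight averaging.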
5 changes: 3 additions & 2 deletions ppsci/utils/__init__.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: Put the config module import at the top level to register default config(s) in
# ConfigStore at the beginning of ppsci
from ppsci.utils import config # isort:skip # noqa: F401
from ppsci.utils import ema
from ppsci.utils import initializer
from ppsci.utils import logger
@@ -22,7 +25,6 @@
from ppsci.utils.checker import dynamic_import_to_globals
from ppsci.utils.checker import run_check
from ppsci.utils.checker import run_check_mesh
from ppsci.utils.config import AttrDict
from ppsci.utils.expression import ExpressionSolver
from ppsci.utils.misc import AverageMeter
from ppsci.utils.misc import set_random_seed
@@ -39,7 +41,6 @@
from ppsci.utils.writer import save_tecplot_file

__all__ = [
"AttrDict",
"AverageMeter",
"ExpressionSolver",
"initializer",
9 changes: 3 additions & 6 deletions ppsci/utils/callbacks.py
@@ -31,9 +31,10 @@

class InitCallback(Callback):
"""Callback class for:
1. Parse config dict from given yaml file and check its validity, complete missing items by its' default values.
1. Parse config dict from given yaml file and check its validity.
2. Fixing random seed to 'config.seed'.
3. Initialize logger while creating output directory (if it does not exist).
4. Enable prim mode if specified.

NOTE: This callback is mainly for reducing unnecessary duplicate code in each
example when running with hydra.
@@ -60,8 +61,6 @@ class InitCallback(Callback):
"""

def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
# check given cfg using pre-defined pydantic schema in 'SolverConfig', error(s) will be raised
# if any checking failed at this step
if importlib.util.find_spec("pydantic") is not None:
from pydantic import ValidationError
else:
@@ -76,8 +75,6 @@ def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
# error(s) will be printed and exit program if any checking failed at this step
try:
_model_pydantic = config_module.SolverConfig(**dict(config))
# complete missing items with default values pre-defined in pydantic schema in
# 'SolverConfig'
full_cfg = DictConfig(_model_pydantic.model_dump())
except ValidationError as e:
print(e)
@@ -100,7 +97,7 @@ def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:

# enable prim if specified
if "prim" in full_cfg and bool(full_cfg.prim):
# Mostly for dy2st running, will be removed in the future
# Mostly for compiler running with dy2st.
from paddle.framework import core

core.set_prim_eager_enabled(True)
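
Note: a sketch of how an example script typically hooks `InitCallback` in through Hydra, so the config is checked against the pydantic `SolverConfig` schema, the seed fixed, and the logger initialized before the task function runs. The config file, its name, and the placeholder model below are hypothetical:

```python
# conf/demo.yaml (hypothetical):
#
#   hydra:
#     callbacks:
#       init_callback:
#         _target_: ppsci.utils.callbacks.InitCallback
#   mode: train
#   seed: 42
#   output_dir: ./output_demo
import hydra
from omegaconf import DictConfig

import ppsci


@hydra.main(version_base=None, config_path="./conf", config_name="demo.yaml")
def main(cfg: DictConfig):
    # on_job_start has already run here: cfg was validated against SolverConfig,
    # the random seed was fixed, the logger was initialized, and prim mode was
    # enabled if cfg.prim was set.
    model = ppsci.arch.MLP(("x",), ("u",), 3, 32)  # placeholder model
    solver = ppsci.solver.Solver(model, cfg=cfg)
    # build constraints/optimizer and call solver.train() as in the full examples.


if __name__ == "__main__":
    main()
```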