Skip to content

Commit

Permalink
Change implementation to use add_instantiator.
Browse files Browse the repository at this point in the history
  • Loading branch information
mauvilsa committed Aug 22, 2023
1 parent 8388f88 commit f007077
Showing 1 changed file with 35 additions and 32 deletions.
67 changes: 35 additions & 32 deletions src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import sys
from functools import partial, update_wrapper
Expand Down Expand Up @@ -51,6 +52,8 @@
locals()["ArgumentParser"] = object
locals()["Namespace"] = object

ModuleType = TypeVar("ModuleType")


class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -198,30 +201,6 @@ def add_lr_scheduler_args(
self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs)
self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)

def class_instantiator(self, class_type, *args, **kwargs):
for key, (base_type, hparams) in getattr(self, "_hparam_context", {}).items():
if issubclass(class_type, base_type):
with given_hyperparameters_context(hparams):
return super().class_instantiator(class_type, *args, **kwargs)
return super().class_instantiator(class_type, *args, **kwargs)

def instantiate_classes(
self,
cfg: Namespace,
instantiate_groups: bool = True,
hparam_context: Optional[Dict[str, type]] = None,
) -> Namespace:
if hparam_context:
cfg_dict = yaml.safe_load(self.dump(cfg)) # TODO: do not remove link targets!
self._hparam_context = {}
for key, base_type in hparam_context.items():
hparams = cfg_dict.get(key, {})
self._hparam_context[key] = (base_type, hparams)
init = super().instantiate_classes(cfg, instantiate_groups=instantiate_groups)
if hparam_context:
delattr(self, "_hparam_context")
return init


class SaveConfigCallback(Callback):
"""Saves a LightningCLI config to the log_dir when training starts.
Expand Down Expand Up @@ -405,6 +384,7 @@ def __init__(

self._set_seed()

self._add_instantiators()
self.before_instantiate_classes()
self.instantiate_classes()

Expand Down Expand Up @@ -551,18 +531,28 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
else:
self.config = parser.parse_args(args)

def _add_instantiators(self) -> None:
self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False))
if "subcommand" in self.config:
self.config_dump = self.config_dump[self.config.subcommand]

self.parser.add_instantiator(
_InstantiatorFn(cli=self, key="model"),
_get_module_type(self._model_class),
subclasses=self.subclass_mode_model,
)
self.parser.add_instantiator(
_InstantiatorFn(cli=self, key="data"),
_get_module_type(self._datamodule_class),
subclasses=self.subclass_mode_data,
)

def before_instantiate_classes(self) -> None:
"""Implement to run some code before instantiating the classes."""

def instantiate_classes(self) -> None:
"""Instantiates the classes and sets their attributes."""
hparam_prefix = ""
if "subcommand" in self.config:
hparam_prefix = self.config["subcommand"] + "."
hparam_context = {hparam_prefix + "model": self._model_class}
if self.datamodule_class is not None:
hparam_context[hparam_prefix + "data"] = self._datamodule_class
self.config_init = self.parser.instantiate_classes(self.config, hparam_context=hparam_context)
self.config_init = self.parser.instantiate_classes(self.config)
self.datamodule = self._get(self.config_init, "data")
self.model = self._get(self.config_init, "model")
self._add_configure_optimizers_method_to_model(self.subcommand)
Expand Down Expand Up @@ -788,7 +778,20 @@ def _get_short_description(component: object) -> Optional[str]:
rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")


ModuleType = TypeVar("ModuleType")
def _get_module_type(value: Union[Callable, type]) -> type:
if callable(value) and not isinstance(value, type):
return inspect.signature(value).return_annotation
return value


class _InstantiatorFn:
def __init__(self, cli: LightningCLI, key: str) -> None:
self.cli = cli
self.key = key

def __call__(self, class_type: Type[ModuleType], *args, **kwargs) -> ModuleType:
with given_hyperparameters_context(self.cli.config_dump.get(self.key, {})):
return class_type(*args, **kwargs)


def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType:
Expand Down

0 comments on commit f007077

Please sign in to comment.