Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify KedroContext with frozen attributes instead of frozen class #3300

Merged
merged 12 commits into from
Nov 21, 2023
2 changes: 1 addition & 1 deletion RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ We are grateful to the following for submitting PRs that contributed to this rel

## Bug fixes and other changes
* Removed fatal error from being logged when a Kedro session is created in a directory without git.
* `KedroContext` is now an `attrs`'s frozen class and `config_loader` is available as public attribute.
* `KedroContext`'s attributes are now frozen and `config_loader` is available as public attribute.
merelcht marked this conversation as resolved.
Show resolved Hide resolved
* Fixed `CONFIG_LOADER_CLASS` validation so that `TemplatedConfigLoader` can be specified in settings.py. Any `CONFIG_LOADER_CLASS` must be a subclass of `AbstractConfigLoader`.
* Added runner name to the `run_params` dictionary used in pipeline hooks.
* Updated [Databricks documentation](https://docs.kedro.org/en/0.18.1/deployment/databricks.html) to include how to get it working with IPython extension and Kedro-Viz.
Expand Down
21 changes: 13 additions & 8 deletions kedro/framework/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from urllib.parse import urlparse
from warnings import warn

from attrs import field, frozen
from attrs import define, field
from attrs.setters import frozen
from pluggy import PluginManager

from kedro.config import AbstractConfigLoader, MissingConfigException
Expand Down Expand Up @@ -158,18 +159,22 @@ def _expand_full_path(project_path: str | Path) -> Path:
return Path(project_path).expanduser().resolve()


@frozen
@define(slots=False) # Enable setting new attributes to `KedroContext`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still needed now we're not freezing anything?

Copy link
Contributor Author

@noklam noklam Nov 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was just wondering if you then still need all the tests for checking attributes?

It is needed, and for this reason I keep one of the test. If you try to set slots=True (or remove this argument since it's default), you will see the test fail. The default for standard Python library is False

https://www.attrs.org/en/stable/api.html#attrs.define - The reason for this is slotted class is more efficient and make sense if it is used as "dataclass" (think of pydantic model that get convert to JSON etc), in this case it's irrelevant

class KedroContext:
"""``KedroContext`` is the base class which holds the configuration and
Kedro's main functionality.
"""

_package_name: str
project_path: Path = field(converter=_expand_full_path)
config_loader: AbstractConfigLoader
_hook_manager: PluginManager
env: str | None = None
_extra_params: dict[str, Any] | None = field(default=None, converter=deepcopy)
_package_name: str = field(init=True, on_setattr=frozen)
project_path: Path = field(
init=True, converter=_expand_full_path, on_setattr=frozen
)
config_loader: AbstractConfigLoader = field(init=True, on_setattr=frozen)
_hook_manager: PluginManager = field(init=True, on_setattr=frozen)
env: str | None = field(init=True, on_setattr=frozen)
_extra_params: dict[str, Any] | None = field(
init=True, default=None, converter=deepcopy, on_setattr=frozen
)

"""Create a context object by providing the root of a Kedro project and
the environment configuration subfolders (see ``kedro.config.OmegaConfigLoader``)
Expand Down
28 changes: 24 additions & 4 deletions tests/framework/context/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest
import toml
import yaml
from attrs.exceptions import FrozenInstanceError
from attrs.exceptions import FrozenAttributeError
from pandas.testing import assert_frame_equal

from kedro import __version__ as kedro_version
Expand Down Expand Up @@ -209,9 +209,29 @@ def test_attributes(self, tmp_path, dummy_context):
assert isinstance(dummy_context.project_path, Path)
assert dummy_context.project_path == tmp_path.resolve()

def test_immutable_instance(self, dummy_context):
with pytest.raises(FrozenInstanceError):
dummy_context.catalog = 1
@pytest.mark.parametrize(
"attr",
(
"project_path",
"config_loader",
"env",
),
)
def test_public_attributes(self, dummy_context, attr):
getattr(dummy_context, attr)

@pytest.mark.parametrize(
"internal_attr", ("_package_name", "_hook_manager", "_extra_params")
)
def test_internal_attributes(self, dummy_context, internal_attr):
getattr(dummy_context, internal_attr)

def test_immutable_class_attribute(self, dummy_context):
with pytest.raises(FrozenAttributeError):
dummy_context.project_path = "dummy"

def test_set_new_attribute(self, dummy_context):
dummy_context.mlflow = 1
merelcht marked this conversation as resolved.
Show resolved Hide resolved

def test_get_catalog_always_using_absolute_path(self, dummy_context):
config_loader = dummy_context.config_loader
Expand Down
47 changes: 44 additions & 3 deletions tests/framework/session/test_session.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import logging
import re
import subprocess
import sys
import textwrap
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Type
from unittest.mock import create_autospec

import pytest
import toml
Expand All @@ -13,7 +16,6 @@
from kedro import __version__ as kedro_version
from kedro.config import AbstractConfigLoader, OmegaConfigLoader
from kedro.framework.cli.utils import _split_params
from kedro.framework.context import KedroContext
from kedro.framework.project import (
LOGGING,
ValidationError,
Expand All @@ -23,7 +25,7 @@
_ProjectSettings,
)
from kedro.framework.session import KedroSession
from kedro.framework.session.session import KedroSessionError
from kedro.framework.session.session import KedroContext, KedroSessionError
from kedro.framework.session.shelvestore import ShelveStore
from kedro.framework.session.store import BaseSessionStore

Expand All @@ -43,6 +45,36 @@ class BadConfigLoader:
"""


ATTRS_ATTRIBUTE = "__attrs_attrs__"

NEW_TYPING = sys.version_info[:3] >= (3, 7, 0) # PEP 560


def create_attrs_autospec(spec: Type, spec_set: bool = True) -> Any:
"""Creates a mock of an attr class (creates mocks recursively on all attributes).
https://github.com/python-attrs/attrs/issues/462#issuecomment-1134656377

:param spec: the spec to mock
:param spec_set: if True, AttributeError will be raised if an attribute that is not in the spec is set.
"""

if not hasattr(spec, ATTRS_ATTRIBUTE):
raise TypeError(f"{spec!r} is not an attrs class")
mock = create_autospec(spec, spec_set=spec_set)
for attribute in getattr(spec, ATTRS_ATTRIBUTE):
attribute_type = attribute.type
if NEW_TYPING:
# A[T] does not get a copy of __dict__ from A(Generic[T]) anymore, use __origin__ to get it
while hasattr(attribute_type, "__origin__"):
attribute_type = attribute_type.__origin__
if hasattr(attribute_type, ATTRS_ATTRIBUTE):
mock_attribute = create_attrs_autospec(attribute_type, spec_set)
else:
mock_attribute = create_autospec(attribute_type, spec_set=spec_set)
object.__setattr__(mock, attribute.name, mock_attribute)
return mock


@pytest.fixture
def mock_runner(mocker):
mock_runner = mocker.patch(
Expand All @@ -55,7 +87,12 @@ def mock_runner(mocker):

@pytest.fixture
def mock_context_class(mocker):
return mocker.patch("kedro.framework.session.session.KedroContext", autospec=True)
mock_cls = create_attrs_autospec(KedroContext)
return mocker.patch(
"kedro.framework.session.session.KedroContext",
autospec=True,
return_value=mock_cls,
)


def _mock_imported_settings_paths(mocker, mock_settings):
Expand All @@ -74,7 +111,9 @@ def mock_settings(mocker):

@pytest.fixture
def mock_settings_context_class(mocker, mock_context_class):
# mocker.patch("dynaconf.base.LazySettings.unset")
merelcht marked this conversation as resolved.
Show resolved Hide resolved
class MockSettings(_ProjectSettings):
# dynaconf automatically deleted some attribute when the class is MagicMock
_CONTEXT_CLASS = Validator(
"CONTEXT_CLASS", default=lambda *_: mock_context_class
)
Expand Down Expand Up @@ -601,7 +640,9 @@ def test_run(
"__default__": mocker.Mock(),
},
)
print(f"{mock_context_class=}")
mock_context = mock_context_class.return_value
print(f"{mock_context=}")
merelcht marked this conversation as resolved.
Show resolved Hide resolved
mock_catalog = mock_context._get_catalog.return_value
mock_runner.__name__ = "SequentialRunner"
mock_pipeline = mock_pipelines.__getitem__.return_value.filter.return_value
Expand Down