Skip to content

Commit

Permalink
[structured config] Enable nesting of non-structured resources in structured resources
Browse files Browse the repository at this point in the history
  • Loading branch information
benpankow committed Jan 19, 2023
1 parent 04934d1 commit 746be30
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 42 deletions.
103 changes: 63 additions & 40 deletions python_modules/dagster/dagster/_config/structured_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import inspect
from typing import Generic, Mapping, TypeVar, Union

from typing_extensions import TypeAlias, dataclass_transform, get_args, get_origin
from typing_extensions import TypeAlias, dataclass_transform, get_origin

from dagster._config.config_type import ConfigType
from dagster._config.source import BoolSource, IntSource, StringSource
from dagster._core.definitions.definition_config_schema import IDefinitionConfigSchema
from dagster._core.definitions.definition_config_schema import (
ConfiguredDefinitionConfigSchema,
IDefinitionConfigSchema,
convert_user_facing_definition_config_schema,
)
from dagster._core.execution.context.init import InitResourceContext

try:
Expand All @@ -30,7 +34,11 @@ class cached_property: # type: ignore[no-redef]
config_dictionary_from_values,
convert_potential_field,
)
from dagster._core.definitions.resource_definition import ResourceDefinition, ResourceFunction
from dagster._core.definitions.resource_definition import (
ResourceDefinition,
ResourceFunction,
is_context_provided,
)
from dagster._core.storage.io_manager import IOManager, IOManagerDefinition

Self = Any
Expand Down Expand Up @@ -139,8 +147,8 @@ def __new__(self, name, bases, namespaces, **kwargs):
for field in annotations:
if not field.startswith("__"):
if get_origin(annotations[field]) == _ResourceDep:
arg = get_args(annotations[field])[0]
annotations[field] = Union[_PartialResource[arg], _Resource[arg]]
# arg = get_args(annotations[field])[0]
annotations[field] = Any
elif _safe_is_subclass(annotations[field], _Resource):
base = annotations[field]
annotations[field] = Union[_PartialResource[base], base]
Expand All @@ -150,28 +158,13 @@ def __new__(self, name, bases, namespaces, **kwargs):


class AllowDelayedDependencies:
_top_level_key: Optional[str] = None
_resource_pointers: Mapping[str, "AllowDelayedDependencies"] = {}

def set_top_level_key(self, key: str):
"""
Sets the top-level resource key for this resource, when passed
into the Definitions object.
"""
self._top_level_key = key

def get_top_level_key(self) -> Optional[str]:
"""
Gets the top-level resource key for this resource which was associated with it
in the Definitions object.
"""
return self._top_level_key
_resource_pointers: Mapping[str, ResourceDefinition] = {}

def _resolve_required_resource_keys(self) -> AbstractSet[str]:
# All dependent resources which are not fully configured
# must be specified to the Definitions object so that the
# resource can be configured at runtime by the user
pointer_keys = {k: v.get_top_level_key() for k, v in self._resource_pointers.items()}
pointer_keys = {k: v.top_level_key for k, v in self._resource_pointers.items()}
check.invariant(
all(pointer_key is not None for pointer_key in pointer_keys.values()),
(
Expand All @@ -183,12 +176,11 @@ def _resolve_required_resource_keys(self) -> AbstractSet[str]:
# Recursively get all nested resource keys
nested_pointer_keys: Set[str] = set()
for v in self._resource_pointers.values():
nested_pointer_keys.update(v._resolve_required_resource_keys())
nested_pointer_keys.update(v.required_resource_keys)

resources, _ = _separate_resource_params(self.__dict__)
resources = {k: v for k, v in resources.items() if isinstance(v, Resource)}
for v in resources.values():
nested_pointer_keys.update(v._resolve_required_resource_keys())
nested_pointer_keys.update(v.required_resource_keys)

out = set(cast(Set[str], pointer_keys.values())).union(nested_pointer_keys)
return out
Expand Down Expand Up @@ -231,8 +223,8 @@ def __init__(self, **data: Any):

# We keep track of any resources we depend on which are not fully configured
# so that we can retrieve them at runtime
self._resource_pointers: Mapping[str, PartialResource] = {
k: v for k, v in resource_pointers.items() if isinstance(v, PartialResource)
self._resource_pointers: Mapping[str, ResourceDefinition] = {
k: v for k, v in resource_pointers.items() if (not _is_fully_configured(v))
}

ResourceDefinition.__init__(
Expand Down Expand Up @@ -260,16 +252,16 @@ def initialize_and_run(self, context: InitResourceContext) -> T:

_, config_to_update = _separate_resource_params(context.resource_config)

partial_resources_to_update = {
k: getattr(context.resources, cast(str, v.top_level_key))
for k, v in self._resource_pointers.items()
}

resources_to_update, _ = _separate_resource_params(self.__dict__)
resources_to_update = {
k: v.initialize_and_run(context)
k: _call_resource_fn_with_default(v, context)
for k, v in resources_to_update.items()
if isinstance(v, Resource)
}

partial_resources_to_update = {
k: getattr(context.resources, cast(str, v.get_top_level_key()))
for k, v in self._resource_pointers.items()
if k not in partial_resources_to_update
}

to_update = {**resources_to_update, **partial_resources_to_update, **config_to_update}
Expand Down Expand Up @@ -306,6 +298,16 @@ def __set__(self, obj: Optional[object], value: Union[T, "PartialResource[T]"])
...


def _is_fully_configured(resource: ResourceDefinition) -> bool:
    """Report whether ``resource`` resolves successfully against an empty config.

    The resource's config schema is converted with
    ``convert_user_facing_definition_config_schema`` and wrapped in a
    ``ConfiguredDefinitionConfigSchema`` whose supplied config is ``{}``.
    The result is the ``success`` flag obtained by resolving an empty
    config dict through that wrapper.
    """
    user_facing_schema = convert_user_facing_definition_config_schema(resource.config_schema)
    configured_schema = ConfiguredDefinitionConfigSchema(resource, user_facing_schema, {})
    resolution = configured_schema.resolve_config({})
    return resolution.success


class PartialResource(
Generic[T], ResourceDefinition, AllowDelayedDependencies, MakeConfigCacheable
):
Expand All @@ -325,8 +327,8 @@ def __init__(self, resource_cls: Type[Resource[T]], data: Dict[str, Any]):

# We keep track of any resources we depend on which are not fully configured
# so that we can retrieve them at runtime
self._resource_pointers: Dict[str, ResourceOrPartial] = {
k: v for k, v in resource_pointers.items() if isinstance(v, PartialResource)
self._resource_pointers: Dict[str, ResourceDefinition] = {
k: v for k, v in resource_pointers.items() if (not _is_fully_configured(v))
}

schema = infer_schema_from_config_class(
Expand All @@ -350,6 +352,7 @@ def required_resource_keys(self) -> AbstractSet[str]:


ResourceOrPartial: TypeAlias = Union[Resource[T], PartialResource[T]]
ResourceOrPartialOrBase: TypeAlias = Union[Resource[T], PartialResource[T], T]


V = TypeVar("V")
Expand All @@ -362,7 +365,7 @@ def __set_name__(self, _owner, name):
def __get__(self, obj: "Resource", __owner: Any) -> V:
return getattr(obj, self._name)

def __set__(self, obj: Optional[object], value: ResourceOrPartial[V]) -> None:
def __set__(self, obj: Optional[object], value: ResourceOrPartialOrBase[V]) -> None:
setattr(obj, self._name, value)


Expand Down Expand Up @@ -584,17 +587,37 @@ def infer_schema_from_config_class(

def _separate_resource_params(
    data: Dict[str, Any]
) -> Tuple[Dict[str, Union[Resource, PartialResource, ResourceDefinition]], Dict[str, Any]]:
    """
    Separates out the key/value inputs of fields in a structured config Resource class which
    are themselves Resources and those which are not.

    Args:
        data: Mapping of field name to value.

    Returns:
        A ``(resources, non_resources)`` pair. ``resources`` holds every entry whose value is
        an instance of ``Resource``, ``PartialResource`` or ``ResourceDefinition``;
        ``non_resources`` holds all remaining entries. Each dict keeps the relative
        ordering of ``data``.
    """
    resource_types = (Resource, PartialResource, ResourceDefinition)
    resources: Dict[str, Union[Resource, PartialResource, ResourceDefinition]] = {}
    non_resources: Dict[str, Any] = {}
    # Partition in a single pass so each value is type-checked exactly once.
    for key, value in data.items():
        if isinstance(value, resource_types):
            resources[key] = value
        else:
            non_resources[key] = value
    return resources, non_resources


def _call_resource_fn_with_default(obj: ResourceDefinition, context: InitResourceContext) -> Any:
    """Invoke ``obj``'s resource function, substituting config derived from its schema.

    When the config schema is a ``ConfiguredDefinitionConfigSchema``, an empty
    config is resolved through it and the ``"config"`` entry of the resolved
    value replaces the context's config. Otherwise, when the schema reports a
    default, that default value replaces the context's config. With neither,
    the context is passed through untouched.

    The resource function receives the context only when
    ``is_context_provided`` says it accepts one; otherwise it is called with no
    arguments. Whatever the resource function returns is returned.
    """
    schema = obj.config_schema
    if isinstance(schema, ConfiguredDefinitionConfigSchema):
        resolved = schema.resolve_config({}).value
        context = context.replace_config(resolved["config"])
    elif schema.default_provided:
        context = context.replace_config(schema.default_value)

    resource_fn = obj.resource_fn
    return resource_fn(context) if is_context_provided(resource_fn) else resource_fn()


_Resource = Resource
_PartialResource = PartialResource
_ResourceDep = ResourceDependency
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dagster._core.definitions.events import AssetKey, CoercibleToAssetKey
from dagster._core.definitions.executor_definition import ExecutorDefinition
from dagster._core.definitions.logger_definition import LoggerDefinition
from dagster._core.definitions.resource_definition import ResourceDefinition
from dagster._core.execution.build_resources import wrap_resources_for_execution
from dagster._core.execution.with_resources import with_resources
from dagster._core.instance import DagsterInstance
Expand Down Expand Up @@ -92,11 +93,10 @@ def _create_repository_using_definitions_args(

if loggers:
check.mapping_param(loggers, "loggers", key_type=str, value_type=LoggerDefinition)
from dagster._config.structured_config import AllowDelayedDependencies

if resources:
for rkey, resource in resources.items():
if isinstance(resource, AllowDelayedDependencies):
if isinstance(resource, ResourceDefinition):
resource.set_top_level_key(rkey)

resource_defs = wrap_resources_for_execution(resources or {})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,25 @@ def __init__(
required_resource_keys, "required_resource_keys"
)
self._version = check.opt_str_param(version, "version")
self._top_level_key = None
if version:
experimental_arg_warning("version", "ResourceDefinition.__init__")

def set_top_level_key(self, key: str):
    """
    Records ``key`` as this resource's top-level resource key, i.e. the key it
    is given when passed into the Definitions object.
    """
    self._top_level_key = key

@property
def top_level_key(self) -> Optional[str]:
    """
    The top-level resource key stored on this resource (the key associated
    with it in the Definitions object), or ``None`` when none has been stored.
    """
    return self._top_level_key

@property
def resource_fn(self) -> ResourceFunction:
return self._resource_fn
Expand Down
Loading

0 comments on commit 746be30

Please sign in to comment.