Skip to content

Commit

Permalink
code
Browse files Browse the repository at this point in the history
  • Loading branch information
schrockn committed Feb 11, 2023
1 parent beaf683 commit 9afb310
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ class cached_property:
)
from dagster._core.storage.io_manager import IOManager, IOManagerDefinition

from . import typing_utils
from .typing_utils import BaseResourceMeta
from .typing_utils import BaseResourceMeta, LateBoundTypesForResourceTypeChecking
from .utils import safe_is_subclass

Self = TypeVar("Self", bound="Resource")
Expand Down Expand Up @@ -823,6 +822,9 @@ def _call_resource_fn_with_default(obj: ResourceDefinition, context: InitResourc
return cast(ResourceFunctionWithoutContext, obj.resource_fn)()


typing_utils._Resource = Resource
typing_utils._PartialResource = PartialResource
typing_utils._ResourceDep = ResourceDependency

LateBoundTypesForResourceTypeChecking.set_actual_types_for_type_checking(
resource_dep_type=ResourceDependency,
resource_type=Resource,
partial_resource_type=PartialResource,
)
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,41 @@
from dagster._config.structured_config import PartialResource





# Since a metaclass is invoked by Resource before Resource or PartialResource is defined, we need to
# define a temporary class to use as a placeholder for use in the initial metaclass invocation.
# When the metaclass is invoked for a Resource subclass, it will use the non-placeholder values.
class LateBoundTypesForResourceTypeChecking:
_ResValue = TypeVar("_ResValue")

_ResValue = TypeVar("_ResValue")
class _Temp(Generic[_ResValue]):
pass

_ResourceDep: Type = _Temp
_Resource: Type = _Temp
_PartialResource: Type = _Temp

class _Temp(Generic[_ResValue]):
pass
@staticmethod
def get_resource_rep_type() -> Type:
return LateBoundTypesForResourceTypeChecking._ResourceDep

@staticmethod
def get_resource_type() -> Type:
return LateBoundTypesForResourceTypeChecking._Resource

_ResourceDep: Type = _Temp
_Resource: Type = _Temp
_PartialResource: Type = _Temp
@staticmethod
def get_partial_resource_type(base: Type) -> Type:
return LateBoundTypesForResourceTypeChecking._PartialResource[base]

@staticmethod
def set_actual_types_for_type_checking(
resource_dep_type: Type, resource_type: Type, partial_resource_type: Type
) -> None:
LateBoundTypesForResourceTypeChecking._ResourceDep = resource_dep_type
LateBoundTypesForResourceTypeChecking._Resource = resource_type
LateBoundTypesForResourceTypeChecking._PartialResource = partial_resource_type


@dataclass_transform()
Expand Down Expand Up @@ -58,17 +79,17 @@ def __new__(cls, name, bases, namespaces, **kwargs):
for field in annotations:
if not field.startswith("__"):
# Check if the annotation is a ResourceDependency
if get_origin(annotations[field]) == _ResourceDep:
if get_origin(annotations[field]) == LateBoundTypesForResourceTypeChecking.get_resource_rep_type():
# arg = get_args(annotations[field])[0]
# If so, we treat it as a Union of a PartialResource and a Resource
# for Pydantic's sake.
annotations[field] = Any
elif safe_is_subclass(annotations[field], _Resource):
elif safe_is_subclass(annotations[field], LateBoundTypesForResourceTypeChecking.get_resource_type()):
# If the annotation is a Resource, we treat it as a Union of a PartialResource
# and a Resource for Pydantic's sake, so that a user can pass in a partially
# configured resource.
base = annotations[field]
annotations[field] = Union[_PartialResource[base], base]
annotations[field] = Union[LateBoundTypesForResourceTypeChecking.get_partial_resource_type(base), base]

namespaces["__annotations__"] = annotations
return super().__new__(cls, name, bases, namespaces, **kwargs)
Expand Down

0 comments on commit 9afb310

Please sign in to comment.