Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(core): Update with_overrides signatures and type hints #2323

Merged
merged 5 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 65 additions & 60 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import datetime
import typing
from typing import Any, List
from typing import Any, Dict, List, Optional, Union

from flyteidl.core import tasks_pb2

from flytekit.core.resources import Resources, convert_resources_to_resource_model
from flytekit.core.utils import _dnsify
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.loggers import logger
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
Expand Down Expand Up @@ -123,27 +124,41 @@
def metadata(self) -> _workflow_model.NodeMetadata:
return self._metadata

def with_overrides(self, *args, **kwargs):
if "node_name" in kwargs:
def with_overrides(
self,
node_name: Optional[str] = None,
aliases: Optional[Dict[str, str]] = None,
requests: Optional[Resources] = None,
limits: Optional[Resources] = None,
timeout: Optional[Union[int, datetime.timedelta]] = None,
retries: Optional[int] = None,
interruptible: Optional[bool] = None,
name: Optional[str] = None,
task_config: Optional[Any] = None,
container_image: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
cache: Optional[bool] = None,
cache_version: Optional[str] = None,
cache_serialize: Optional[bool] = None,
*args,
**kwargs,
):
if node_name is not None:
# Convert the node name into a DNS-compliant.
# https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-subdomain-names
v = kwargs["node_name"]
assert_not_promise(v, "node_name")
self._id = _dnsify(v)
assert_not_promise(node_name, "node_name")
self._id = _dnsify(node_name)

if "aliases" in kwargs:
alias_dict = kwargs["aliases"]
if not isinstance(alias_dict, dict):
if aliases is not None:
if not isinstance(aliases, dict):
raise AssertionError("Aliases should be specified as dict[str, str]")
self._aliases = []
for k, v in alias_dict.items():
for k, v in aliases.items():
self._aliases.append(_workflow_model.Alias(var=k, alias=v))

if "requests" in kwargs or "limits" in kwargs:
requests = kwargs.get("requests")
if requests is not None or limits is not None:
if requests and not isinstance(requests, Resources):
raise AssertionError("requests should be specified as flytekit.Resources")
limits = kwargs.get("limits")
if limits and not isinstance(limits, Resources):
raise AssertionError("limits should be specified as flytekit.Resources")

Expand All @@ -159,62 +174,52 @@
assert_no_promises_in_resources(resources)
self._resources = resources

if "timeout" in kwargs:
timeout = kwargs["timeout"]
if timeout is None:
self._metadata._timeout = datetime.timedelta()
elif isinstance(timeout, int):
self._metadata._timeout = datetime.timedelta(seconds=timeout)
elif isinstance(timeout, datetime.timedelta):
self._metadata._timeout = timeout
else:
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if "retries" in kwargs:
retries = kwargs["retries"]
if timeout is None:
self._metadata._timeout = datetime.timedelta()
elif isinstance(timeout, int):
self._metadata._timeout = datetime.timedelta(seconds=timeout)
elif isinstance(timeout, datetime.timedelta):
self._metadata._timeout = timeout

Check warning on line 182 in flytekit/core/node.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/node.py#L182

Added line #L182 was not covered by tests
else:
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if retries is not None:
assert_not_promise(retries, "retries")
self._metadata._retries = (
_literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries)
)

if "interruptible" in kwargs:
v = kwargs["interruptible"]
assert_not_promise(v, "interruptible")
self._metadata._interruptible = kwargs["interruptible"]
if interruptible is not None:
assert_not_promise(interruptible, "interruptible")
self._metadata._interruptible = interruptible

if "name" in kwargs:
self._metadata._name = kwargs["name"]
if name is not None:
self._metadata._name = name

if "task_config" in kwargs:
if task_config is not None:
logger.warning("This override is beta. We may want to revisit this in the future.")
new_task_config = kwargs["task_config"]
if not isinstance(new_task_config, type(self.run_entity._task_config)):
if not isinstance(task_config, type(self.run_entity._task_config)):
raise ValueError("can't change the type of the task config")
self.run_entity._task_config = new_task_config

if "container_image" in kwargs:
v = kwargs["container_image"]
assert_not_promise(v, "container_image")
self._container_image = v

if "accelerator" in kwargs:
v = kwargs["accelerator"]
assert_not_promise(v, "accelerator")
self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=v.to_flyte_idl())

if "cache" in kwargs:
v = kwargs["cache"]
assert_not_promise(v, "cache")
self._metadata._cacheable = kwargs["cache"]

if "cache_version" in kwargs:
v = kwargs["cache_version"]
assert_not_promise(v, "cache_version")
self._metadata._cache_version = kwargs["cache_version"]

if "cache_serialize" in kwargs:
v = kwargs["cache_serialize"]
assert_not_promise(v, "cache_serialize")
self._metadata._cache_serializable = kwargs["cache_serialize"]
self.run_entity._task_config = task_config

if container_image is not None:
assert_not_promise(container_image, "container_image")
self._container_image = container_image

if accelerator is not None:
assert_not_promise(accelerator, "accelerator")
self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=accelerator.to_flyte_idl())

if cache is not None:
assert_not_promise(cache, "cache")
self._metadata._cacheable = cache

if cache_version is not None:
assert_not_promise(cache_version, "cache_version")
self._metadata._cache_version = cache_version

if cache_serialize is not None:
assert_not_promise(cache_serialize, "cache_serialize")
self._metadata._cache_serializable = cache_serialize

return self

Expand Down
42 changes: 40 additions & 2 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import collections
import datetime
import inspect
import typing
from copy import deepcopy
Expand Down Expand Up @@ -33,13 +34,15 @@
)
from flytekit.exceptions import user as _user_exceptions
from flytekit.exceptions.user import FlytePromiseAttributeResolveException
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literals_models
from flytekit.models import types as _type_models
from flytekit.models import types as type_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Primitive
from flytekit.models.task import Resources
from flytekit.models.types import SimpleType


Expand Down Expand Up @@ -497,10 +500,45 @@ def __and__(self, other):
def __or__(self, other):
raise ValueError("Cannot perform Logical OR of Promise with other")

def with_overrides(self, *args, **kwargs):
def with_overrides(
self,
node_name: Optional[str] = None,
aliases: Optional[Dict[str, str]] = None,
requests: Optional[Resources] = None,
limits: Optional[Resources] = None,
timeout: Optional[Union[int, datetime.timedelta]] = None,
retries: Optional[int] = None,
interruptible: Optional[bool] = None,
name: Optional[str] = None,
task_config: Optional[Any] = None,
container_image: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
cache: Optional[bool] = None,
cache_version: Optional[str] = None,
cache_serialize: Optional[bool] = None,
*args,
**kwargs,
):
if not self.is_ready:
# TODO, this should be forwarded, but right now this results in failure and we want to test this behavior
self.ref.node.with_overrides(*args, **kwargs)
self.ref.node.with_overrides( # type: ignore
node_name=node_name,
aliases=aliases,
requests=requests,
limits=limits,
timeout=timeout,
retries=retries,
interruptible=interruptible,
name=name,
task_config=task_config,
container_image=container_image,
accelerator=accelerator,
cache=cache,
cache_version=cache_version,
cache_serialize=cache_serialize,
*args,
**kwargs,
)
return self

def __repr__(self):
Expand Down
2 changes: 1 addition & 1 deletion flytekit/models/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


class RetryStrategy(_common.FlyteIdlEntity):
def __init__(self, retries):
def __init__(self, retries: int):
"""
:param int retries: Number of retries to attempt on recoverable failures. If retries is 0, then
only one attempt will be made.
Expand Down
2 changes: 1 addition & 1 deletion flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def __init__(
self,
path: typing.Union[str, os.PathLike],
downloader: typing.Callable = noop,
remote_path: typing.Optional[typing.Union[os.PathLike, bool]] = None,
remote_path: typing.Optional[typing.Union[os.PathLike, str, bool]] = None,
):
"""
FlyteFile's init method.
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def my_wf(a: str) -> str:
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(task_config=None)

my_wf()
my_wf(a=2)


def test_override_image():
Expand Down
Loading