Skip to content

Commit

Permalink
nit
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <pingsutw@gmail.com>
  • Loading branch information
pingsutw committed Apr 3, 2024
1 parent 773fe75 commit d17c290
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
26 changes: 13 additions & 13 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,20 @@ def with_overrides(
*args,
**kwargs,
):
if node_name:
if node_name is not None:
# Convert the node name into a DNS-compliant.
# https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-subdomain-names
assert_not_promise(node_name, "node_name")
self._id = _dnsify(node_name)

if aliases:
if aliases is not None:
if not isinstance(aliases, dict):
raise AssertionError("Aliases should be specified as dict[str, str]")
self._aliases = []
for k, v in aliases.items():
self._aliases.append(_workflow_model.Alias(var=k, alias=v))

if requests or limits:
if requests is not None or limits is not None:
if requests and not isinstance(requests, Resources):
raise AssertionError("requests should be specified as flytekit.Resources")
if limits and not isinstance(limits, Resources):
Expand All @@ -174,7 +174,7 @@ def with_overrides(
assert_no_promises_in_resources(resources)
self._resources = resources

if timeout:
if timeout is not None:
if timeout is None:
self._metadata._timeout = datetime.timedelta()
elif isinstance(timeout, int):
Expand All @@ -183,42 +183,42 @@ def with_overrides(
self._metadata._timeout = timeout
else:
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if retries:
if retries is not None:
assert_not_promise(retries, "retries")
self._metadata._retries = (
_literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries)
)

if interruptible:
if interruptible is not None:
assert_not_promise(interruptible, "interruptible")
self._metadata._interruptible = interruptible

if name:
if name is not None:
self._metadata._name = name

if task_config:
if task_config is not None:
logger.warning("This override is beta. We may want to revisit this in the future.")
if not isinstance(task_config, type(self.run_entity._task_config)):
raise ValueError("can't change the type of the task config")
self.run_entity._task_config = task_config

if container_image:
if container_image is not None:
assert_not_promise(container_image, "container_image")
self._container_image = container_image

if accelerator:
if accelerator is not None:
assert_not_promise(accelerator, "accelerator")
self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=accelerator.to_flyte_idl())

if cache:
if cache is not None:
assert_not_promise(cache, "cache")
self._metadata._cacheable = cache

if cache_version:
if cache_version is not None:
assert_not_promise(cache_version, "cache_version")
self._metadata._cache_version = cache_version

if cache_serialize:
if cache_serialize is not None:
assert_not_promise(cache_serialize, "cache_serialize")
self._metadata._cache_serializable = cache_serialize

Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,6 @@ def with_overrides(
if not self.is_ready:
# TODO, this should be forwarded, but right now this results in failure and we want to test this behavior
self.ref.node.with_overrides(
*args,
node_name=node_name,
aliases=aliases,
requests=requests,
Expand All @@ -518,6 +517,7 @@ def with_overrides(
cache=cache,
cache_version=cache_version,
cache_serialize=cache_serialize,
*args,
**kwargs,
)
return self
Expand Down

0 comments on commit d17c290

Please sign in to comment.