Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 0 additions & 39 deletions src/flyte/_internal/controllers/remote/_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import concurrent.futures
import os
import threading
import time
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
Expand Down Expand Up @@ -154,7 +153,6 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg
if tctx is None:
raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
current_action_id = tctx.action
trace_enabled = self._should_trace_sequence(_task_call_seq)

# In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks
# It is not allowed to change the code bundle (for regular code bundles) in the middle of a run.
Expand All @@ -169,14 +167,11 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg
)

_ctx = ctx.new_in_driver_literal_conversion(True) if ctx.is_task_context() else nullcontext()
sdk_inputs_start = time.monotonic()
with _ctx:
inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
sdk_inputs_ms = (time.monotonic() - sdk_inputs_start) * 1000

root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd()
# Don't set output path in sec context because node executor will set it
sdk_serialize_start = time.monotonic()
new_serialization_context = SerializationContext(
project=current_action_id.project,
domain=current_action_id.domain,
Expand All @@ -194,17 +189,12 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
tctx, task_spec, inputs_hash, _task_call_seq
)
sdk_serialize_ms = (time.monotonic() - sdk_serialize_start) * 1000
logger.info(f"Sub action {sub_action_id} output path {sub_action_output_path}")

serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
serialized_input_bytes = len(serialized_inputs)
inputs_uri = io.inputs_path(sub_action_output_path)
storage_put_start = time.monotonic()
await upload_inputs_with_retry(serialized_inputs, inputs_uri, max_bytes=_task.max_inline_io_bytes)
storage_put_ms = (time.monotonic() - storage_put_start) * 1000

sdk_cache_start = time.monotonic()
md = task_spec.task_template.metadata
ignored_input_vars = []
if len(md.cache_ignore_input_vars) > 0:
Expand All @@ -220,7 +210,6 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg
ignored_input_vars,
inputs.proto_inputs,
)
sdk_cache_ms = (time.monotonic() - sdk_cache_start) * 1000

# Clear to free memory
serialized_inputs = None # type: ignore
Expand All @@ -244,41 +233,13 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg
cache_key=cache_key,
queue=_task.queue,
)
self._mark_action_for_trace(action.name)
if trace_enabled:
self._trace_log(
action.name,
"sdk_prepare",
kind="sdk_only",
seq=_task_call_seq,
task=_task.name,
sdk_inputs_ms=f"{sdk_inputs_ms:.1f}",
sdk_serialize_ms=f"{sdk_serialize_ms:.1f}",
sdk_cache_ms=f"{sdk_cache_ms:.1f}",
input_bytes=serialized_input_bytes,
)
self._trace_log(
action.name,
"storage_put_inputs",
kind="storage_api",
elapsed_ms=f"{storage_put_ms:.1f}",
input_bytes=serialized_input_bytes,
)

try:
logger.info(
f"Submitting action Run:[{action.run_name}, Parent:[{action.parent_action_name}], "
f"task:[{_task.name}], action:[{action.name}]"
)
submit_start = time.monotonic()
n = await self.submit_action(action)
if trace_enabled:
self._trace_log(
action.name,
"submit_action_done",
kind="mixed",
elapsed_ms=f"{(time.monotonic() - submit_start) * 1000:.1f}",
)
logger.info(f"Action for task [{_task.name}] action id: {action.name}, completed!")
except asyncio.CancelledError:
# If the action is cancelled, we need to cancel the action on the server as well
Expand Down
65 changes: 0 additions & 65 deletions src/flyte/_internal/controllers/remote/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
import sys
import threading
import time
from asyncio import Event
from typing import Awaitable, Coroutine, Optional

Expand Down Expand Up @@ -79,10 +78,6 @@ def __init__(
self._informer_start_wait_timeout = thread_wait_timeout_sec
max_qps = int(os.getenv("_F_MAX_QPS", "100"))
self._rate_limiter = AsyncLimiter(max_qps, 1.0)
self._trace_submit = os.getenv("_F_TRACE_SUBMIT", "").lower() in {"1", "true", "yes", "on"}
self._trace_submit_limit = int(os.getenv("_F_TRACE_SUBMIT_LIMIT", "10"))
self._trace_actions: set[str] = set()
self._trace_lock = threading.Lock()

# Thread management
self._thread = None
Expand All @@ -92,28 +87,6 @@ def __init__(
self._thread_com_lock = threading.Lock()
self._start()

def _should_trace_sequence(self, seq: int) -> bool:
    """Tell whether the task call with sequence number ``seq`` should be traced.

    The result is truthy only when ``self._trace_submit`` is set and ``seq``
    does not exceed ``self._trace_submit_limit``.
    """
    enabled = self._trace_submit
    if not enabled:
        # Tracing is switched off: hand back the (falsy) flag itself.
        return enabled
    return seq <= self._trace_submit_limit

def _mark_action_for_trace(self, action_name: str):
    """Record ``action_name`` in ``self._trace_actions`` so it gets traced.

    Does nothing when ``self._trace_submit`` is off. The name is only added
    while the set holds fewer than ``self._trace_submit_limit`` entries; the
    size check and the insert both happen under ``self._trace_lock``.
    """
    if self._trace_submit:
        with self._trace_lock:
            has_room = len(self._trace_actions) < self._trace_submit_limit
            if has_room:
                self._trace_actions.add(action_name)

def _trace_enabled_for(self, action_name: str) -> bool:
    """Report whether ``action_name`` is currently in the traced set.

    Returns ``False`` without touching the lock when ``self._trace_submit``
    is off; otherwise the membership test on ``self._trace_actions`` is done
    while holding ``self._trace_lock``.
    """
    if self._trace_submit:
        with self._trace_lock:
            return action_name in self._trace_actions
    return False

def _trace_log(self, action_name: str, phase: str, **fields):
    """Print one ``submit_trace`` line for ``action_name`` if it is traced.

    The line carries the action name, the ``phase`` label and every extra
    keyword argument rendered as ``key=value``, separated by single spaces.
    Nothing is printed when ``self._trace_enabled_for(action_name)`` is falsy.
    """
    if self._trace_enabled_for(action_name):
        rendered = [f"{key}={value}" for key, value in fields.items()]
        payload = " ".join(rendered)
        line = f"submit_trace action={action_name} phase={phase} {payload}"
        # rstrip removes the trailing space left behind when no fields are given.
        print(line.rstrip(), flush=True)

# ---------------- Public sync methods, we can add more sync methods if needed
@log
def submit_action_sync(self, action: Action) -> Action:
Expand Down Expand Up @@ -305,8 +278,6 @@ async def _bg_finalize_informer(
async def _bg_submit_action(self, action: Action) -> Action:
"""Submit a resource and await its completion, returning the final state"""
logger.debug(f"{threading.current_thread().name} Submitting action {action.name}")
trace_enabled = self._trace_enabled_for(action.name)
informer_start = time.monotonic()
informer = await self._informers.get_or_create(
action.action_id.run,
action.parent_action_name,
Expand All @@ -316,31 +287,13 @@ async def _bg_submit_action(self, action: Action) -> Action:
timeout=self._informer_start_wait_timeout,
actions_service=self._actions_service,
)
if trace_enabled:
watch_api = "actions.watch_for_updates" if self._actions_service else "state.watch"
self._trace_log(
action.name,
"informer_ready",
kind="controlplane_api",
api=watch_api,
elapsed_ms=f"{(time.monotonic() - informer_start) * 1000:.1f}",
)
queue_submit_start = time.monotonic()
await informer.submit(action)
if trace_enabled:
self._trace_log(
action.name,
"queue_submit",
kind="sdk_only",
elapsed_ms=f"{(time.monotonic() - queue_submit_start) * 1000:.1f}",
)

logger.debug(f"{threading.current_thread().name} Waiting for completion of {action.name}")
# Wait for completion. For trace actions apply a timeout so a
# transient watch failure (e.g. gRPC deserialization returning None)
# doesn't block the caller indefinitely. Task actions may legitimately
# run for hours, so they wait without a timeout.
wait_start = time.monotonic()
if action.type == "trace":
_trace_timeout = float(os.getenv("_F_TRACE_COMPLETION_TIMEOUT", "60"))
try:
Expand All @@ -353,13 +306,6 @@ async def _bg_submit_action(self, action: Action) -> Action:
await informer.fire_completion_event(action.name)
else:
await informer.wait_for_action_completion(action.name)
if trace_enabled:
self._trace_log(
action.name,
"wait_for_completion",
kind="lifecycle_wait",
elapsed_ms=f"{(time.monotonic() - wait_start) * 1000:.1f}",
)
logger.info(f"{threading.current_thread().name} Action {action.name} completed")

# Get final resource state and clean up
Expand Down Expand Up @@ -416,9 +362,7 @@ async def _bg_launch(self, action: Action):
Attempt to launch an action.
"""
if not action.is_started():
limiter_wait_start = time.monotonic()
async with self._rate_limiter:
limiter_wait_ms = (time.monotonic() - limiter_wait_start) * 1000
task: run_definition_pb2.TaskAction | None = None
trace: run_definition_pb2.TraceAction | None = None
if action.type == "task":
Expand Down Expand Up @@ -447,7 +391,6 @@ async def _bg_launch(self, action: Action):
trace = action.trace

logger.debug(f"Attempting to launch action: {action.name}, actions? {bool(self._actions_service)}")
launch_start = time.monotonic()
try:
if self._actions_service:
await self._actions_service.enqueue(
Expand Down Expand Up @@ -479,14 +422,6 @@ async def _bg_launch(self, action: Action):
timeout_ms=int(self._enqueue_timeout * 1000),
)
logger.info(f"Successfully launched action: {action.name}")
self._trace_log(
action.name,
"enqueue_action",
kind="controlplane_api",
api="actions.enqueue" if self._actions_service else "queue.enqueue_action",
limiter_wait_ms=f"{limiter_wait_ms:.1f}",
elapsed_ms=f"{(time.monotonic() - launch_start) * 1000:.1f}",
)
except httpx.TransportError as e:
# Transport-level failure (e.g. ConnectTimeout reaching the IDP during auth refresh,
# ReadTimeout, DNS failure). These never produced an HTTP response, so they bypass
Expand Down
Loading