
[Core] Fix the optimization for IP fetching in sky launch #2400

Merged · 39 commits merged into master from optimize-head-ip on Aug 17, 2023

Conversation

@Michaelvll (Collaborator) commented Aug 14, 2023

The previous optimization for IP fetching did not work: we did not retrieve the internal IP from the provision output, so sky launch would always fetch the IP addresses, incurring additional overhead.

This PR fixes the optimization. The profiling below shows ~9s faster launch for a new cluster and ~13s faster sky launch on an existing cluster. Note that the optimization currently covers single-node clusters only, but since many users run single-node clusters, it should be helpful.
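Conceptually, the fix makes the provision step hand its IPs to the cluster handle so later steps hit the cache. A minimal sketch of that idea (the class and signature here are simplified and hypothetical, not the actual SkyPilot API):

from typing import List, Optional

class ClusterHandle:
    """Simplified stand-in for the backend's cluster handle."""

    def __init__(self) -> None:
        self.cached_external_ips: Optional[List[str]] = None
        self.cached_internal_ips: Optional[List[str]] = None

    def update_cluster_ips(
            self,
            external_ips: Optional[List[str]] = None,
            internal_ips: Optional[List[str]] = None) -> None:
        # If the provision output already reported the IPs, cache them
        # directly; only query the cloud again when they are missing.
        if external_ips is not None:
            self.cached_external_ips = external_ips
        if internal_ips is not None:
            self.cached_internal_ips = internal_ips
        # (The real method falls back to querying the cloud when no
        # IPs are provided; omitted in this sketch.)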

Profiling (average over 5 runs, on GCP with 2 CPUs, time sky launch -y -d):

  • Original:
    sky launch a new cluster: 2m 41s
    sky launch an existing cluster: 1m 26s
  • Current PR:
    sky launch a new cluster: 2m 32s
    sky launch an existing cluster: 1m 13s

TODO:

  • Test TPU pod

Tested (run the relevant ones):

  • Code formatting: bash format.sh
  • Any manual or new tests for this PR (please specify below)
  • All smoke tests: pytest tests/test_smoke.py
  • All smoke tests: pytest tests/test_smoke.py --aws
  • Relevant individual smoke tests: pytest tests/test_smoke.py::test_fill_in_the_name
  • Backward compatibility tests: bash tests/backward_comaptibility_tests.sh

@cblmemo (Collaborator) left a comment

The code looks good to me! Left several nits:

sky/backends/cloud_vm_ray_backend.py
# Optimization: Try parse internal head ip from 'ray start' stdout.
# The line looks like: 'Local node IP: <internal head_ip>\n'
position = stdout.rfind('Local node IP')
line = stdout[position + 1:].partition('\n')[0]
@cblmemo (Collaborator):

Why +1 here?

@Michaelvll (Collaborator, Author):

Good point! Changed to stdout[position:] instead.
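For reference, a small self-contained sketch of the corrected parsing (the sample stdout string is illustrative; the real one comes from 'ray start'):

stdout = 'Some ray output...\nLocal node IP: 10.128.0.2\nMore output\n'
# Slice from the match itself rather than one character past it (the
# '+ 1' flagged above), then take the rest of that line.
position = stdout.rfind('Local node IP')
line = stdout[position:].partition('\n')[0]
internal_ip = line.split(':', maxsplit=1)[1].strip()
print(internal_ip)  # -> 10.128.0.2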

internal_ips = self._internal_ips
if internal_ips is not None:
    return internal_ips
self.update_cluster_ips(max_attempts=max_attempts)
internal_ips = self._internal_ips
assert internal_ips is not None, 'update_cluster_ips failed.'
return internal_ips
@cblmemo (Collaborator):

Shall we raise an error here?

@Michaelvll (Collaborator, Author):

It should be an internal error, as self.cached_internal_ips should not be None after the update. We should probably still use an assert.
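To make the distinction concrete, a hypothetical helper (check_ips is not in the PR; exceptions.FetchIPError is the exception type this PR uses, see the snippet further below): raise a typed, user-facing error when the cloud genuinely returns nothing, and assert on conditions that only a bug in our own code could violate.

from typing import List, Optional

from sky import exceptions

def check_ips(external_ips: Optional[List[str]],
              cached_ips: Optional[List[str]]) -> None:
    # User-facing failure: the cloud returned no head IP, so raise a
    # typed exception that callers are expected to catch and retry on.
    if external_ips is None or len(external_ips) == 0:
        raise exceptions.FetchIPError(
            reason=exceptions.FetchIPError.Reason.HEAD)
    # Internal invariant: after update_cluster_ips() the cache must be
    # populated; None here means a bug in our own code, hence assert.
    assert cached_ips is not None, 'update_cluster_ips failed.'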

self.update_cluster_ips(max_attempts=max_attempts)
external_ips = self._external_ips
assert external_ips is not None, 'update_cluster_ips failed.'
return external_ips
@cblmemo (Collaborator):

ditto

if external_ips is None or len(external_ips) == 0:
    raise exceptions.FetchIPError(
        reason=exceptions.FetchIPError.Reason.HEAD)
# TODO(zhwu): check the correctness of stopped TPU VM
@cblmemo (Collaborator):

Shall we add a TODO for the case in #2304?

@Michaelvll (Collaborator, Author):

I reverted it to the previous implementation to make it faster. : )

@Michaelvll (Collaborator, Author) commented:

I fixed some problems with the TPU pod. It now passes all smoke tests for GCP and AWS. PTAL @cblmemo.

@cblmemo (Collaborator) left a comment

The code looks great to me! Left one nit and two questions about parts I'm not sure I fully understand:

  • Why do we need to manually pass the port argument when initializing SSHCommandRunner?
  • I noticed that our IPs are ephemeral. Is there a possibility that prev_handle might contain a stale IP if we sky stop an UP cluster?

if isinstance(self.launched_resources.cloud, clouds.Kubernetes):
    head_port = backend_utils.get_head_ssh_port(
        self, use_cache=False, max_attempts=max_attempts)
# TODO(romilb): Multinode doesn't work with Kubernetes yet.
@cblmemo (Collaborator):

Shall we keep this TODO?

@Michaelvll (Collaborator, Author) commented Aug 17, 2023

Thanks for the review @cblmemo!

Why do we need to manually pass the port argument when initializing SSHCommandRunner?

The main reason is type checking. It seems that if we do not pass that argument manually, mypy complains because **ssh_credentials is Dict[str, str], while the remaining parameter in that function (i.e., port) has type int.
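A hedged illustration of that mypy behavior (the constructor signature below is a simplification, not the actual SSHCommandRunner):

from typing import Dict

class SSHCommandRunner:
    def __init__(self, ip: str, port: int = 22,
                 ssh_user: str = '', ssh_private_key: str = '') -> None:
        self.ip, self.port = ip, port

ssh_credentials: Dict[str, str] = {
    'ssh_user': 'ubuntu',
    'ssh_private_key': '~/.ssh/sky-key',
}

# With only the **unpack, mypy must match the str-valued dict against
# every remaining keyword parameter, including the int-typed 'port':
#     SSHCommandRunner('1.2.3.4', **ssh_credentials)  # mypy error
# Passing port explicitly leaves only str parameters for the unpack:
runner = SSHCommandRunner('1.2.3.4', port=22, **ssh_credentials)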

I noticed that our IPs are ephemeral. Is there a possibility that prev_handle might contain a stale IP if we sky stop an UP cluster?

It is OK to pass the stale IP, because we will update the IPs later in the code once the VM is provisioned.

handle.update_cluster_ips(max_attempts=_FETCH_IP_MAX_ATTEMPTS,

And in the update, we will check whether the IPs from ray up match the ones in the cache. If they do not match, we will use the new ones and update the internal IPs as well.

def is_provided_ips_valid(ips: Optional[List[Optional[str]]]) -> bool:
    return (ips is not None and len(ips) == self.num_node_ips and
            all(ip is not None for ip in ips))

if is_provided_ips_valid(external_ips):
    logger.debug(f'Using provided external IPs: {external_ips}')
    cluster_external_ips = typing.cast(List[str], external_ips)
else:
    cluster_external_ips = backend_utils.get_node_ips(
        self.cluster_yaml,
        self.launched_nodes,
        handle=self,
        head_ip_max_attempts=max_attempts,
        worker_ip_max_attempts=max_attempts,
        get_internal_ips=False)

if self.cached_external_ips == cluster_external_ips:
    # Optimization: If the cached external IPs are the same as the
    # newly fetched external IPs, we can skip fetching the internal
    # IPs, since the cached IPs are up-to-date.
    logger.debug('Skipping the fetching of internal IPs as the cached '
                 'external IPs matches the newly fetched ones.')
    return

@cblmemo (Collaborator) left a comment
Looks great to me!

Michaelvll merged commit 06a927c into master on Aug 17, 2023 · 17 checks passed
Michaelvll deleted the optimize-head-ip branch on August 17, 2023 at 21:48