Always bootstrap config
edoakes committed Aug 22, 2019
1 parent 8ec5f23 commit 3c454c2
Showing 5 changed files with 24 additions and 26 deletions.
11 changes: 4 additions & 7 deletions python/ray/autoscaler/autoscaler.py
@@ -287,7 +287,7 @@ def __init__(self, provider, queue, pending, index=None, *args, **kwargs):
     def _launch_node(self, config, count):
         worker_filter = {TAG_RAY_NODE_TYPE: "worker"}
         before = self.provider.non_terminated_nodes(tag_filters=worker_filter)
-        launch_hash = hash_launch_conf(config["worker_nodes"], config["auth"] if "auth" in self.config else None)
+        launch_hash = hash_launch_conf(config["worker_nodes"], config["auth"] if "auth" in config else None)
         self.log("Launching {} nodes.".format(count))
         self.provider.create_node(
             config["worker_nodes"], {
@@ -368,8 +368,7 @@ def __init__(self,
         self.config_path = config_path
         self.reload_config(errors_fatal=True)
         self.load_metrics = load_metrics
-        self.provider = get_node_provider(self.config["provider"],
-                                          self.config["cluster_name"])
+        self.provider = get_node_provider(self.config)

         self.max_failures = max_failures
         self.max_launch_batch = max_launch_batch
@@ -540,7 +539,7 @@ def reload_config(self, errors_fatal=False):
                 new_config = yaml.safe_load(f.read())
             validate_config(new_config)
             new_launch_hash = hash_launch_conf(new_config["worker_nodes"],
-                new_config["auth"] if "auth" in self.config else None)
+                new_config["auth"] if "auth" in new_config else None)
             new_runtime_hash = hash_runtime_conf(new_config["file_mounts"], [
                 new_config["worker_setup_commands"],
                 new_config["worker_start_ray_commands"]
@@ -788,10 +787,8 @@ def validate_config(config, schema=CLUSTER_CONFIG_SCHEMA):


 def fillout_defaults(config):
-    config_copy = config.copy()
     defaults = get_default_config(config["provider"])
-    config_copy["provider"].update(defaults["provider"])
-    defaults.update(config_copy)
+    defaults.update(config)
     merge_setup_commands(defaults)
     dockerize_if_needed(defaults)
     return defaults
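
The simplified fillout_defaults is now a plain shallow dict.update, which subtly changes the merge semantics: a user-supplied provider section replaces the default provider dict wholesale instead of being merged key-by-key (the deleted config_copy["provider"].update line did that). A self-contained illustration with made-up keys:

    defaults = {"provider": {"type": "aws", "region": "us-west-2"},
                "idle_timeout_minutes": 5}
    config = {"provider": {"type": "aws"}, "cluster_name": "demo"}

    defaults.update(config)     # top-level keys from config win outright
    assert defaults["provider"] == {"type": "aws"}   # nested "region" default gone
    assert defaults["idle_timeout_minutes"] == 5     # untouched defaults remain

Presumably the provider-specific bootstrap hooks, which this commit runs unconditionally in get_node_provider, become the place where provider-level keys get filled in instead.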
19 changes: 9 additions & 10 deletions python/ray/autoscaler/commands.py
@@ -76,13 +76,12 @@ def teardown_cluster(config_file, yes, workers_only, override_cluster_name):
     config = yaml.safe_load(open(config_file).read())
     if override_cluster_name is not None:
         config["cluster_name"] = override_cluster_name
-    validate_config(config)
     config = fillout_defaults(config)
+    validate_config(config)

     confirm("This will destroy your cluster", yes)

-    provider = get_node_provider(config["provider"], config["cluster_name"])
-
+    provider = get_node_provider(config)
     try:

         def remaining_nodes():
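
Note the reordering of the first two calls: fillout_defaults now runs before validate_config, so the schema check sees the fully populated config rather than the raw user file. Presumably this lets a minimal YAML that leans on defaults pass validation. The new call order, with a hypothetical minimal config:

    config = {"cluster_name": "demo", "provider": {"type": "aws"}}
    config = fillout_defaults(config)   # merge defaults in first...
    validate_config(config)             # ...then validate the completed dict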
@@ -126,7 +125,7 @@ def kill_node(config_file, yes, hard, override_cluster_name):

     confirm("This will kill a node in your cluster", yes)

-    provider = get_node_provider(config["provider"], config["cluster_name"])
+    provider = get_node_provider(config)
     try:
         nodes = provider.non_terminated_nodes({TAG_RAY_NODE_TYPE: "worker"})
         node = random.choice(nodes)
@@ -169,7 +168,7 @@ def monitor_cluster(cluster_config_file, num_lines, override_cluster_name):
 def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
                             override_cluster_name):
     """Create the cluster head node, which in turn creates the workers."""
-    provider = get_node_provider(config["provider"], config["cluster_name"])
+    provider = get_node_provider(config)
     try:
         head_node_tags = {
             TAG_RAY_NODE_TYPE: "head",
@@ -347,7 +346,7 @@ def exec_cluster(config_file, cmd, docker, screen, tmux, stop, start,
     head_node = _get_head_node(
         config, config_file, override_cluster_name, create_if_needed=start)

-    provider = get_node_provider(config["provider"], config["cluster_name"])
+    provider = get_node_provider(config)
     try:
         updater = NodeUpdaterThread(
             node_id=head_node,
@@ -440,7 +439,7 @@ def rsync(config_file, source, target, override_cluster_name, down):
     head_node = _get_head_node(
         config, config_file, override_cluster_name, create_if_needed=False)

-    provider = get_node_provider(config["provider"], config["cluster_name"])
+    provider = get_node_provider(config)
     try:
         updater = NodeUpdaterThread(
             node_id=head_node,
@@ -474,7 +473,7 @@ def get_head_node_ip(config_file, override_cluster_name):
     if override_cluster_name is not None:
         config["cluster_name"] = override_cluster_name

-    provider = get_node_provider(config["provider"], config["cluster_name"])
+    provider = get_node_provider(config)
     try:
         head_node = _get_head_node(config, config_file, override_cluster_name)
         if config.get("provider", {}).get("use_internal_ips", False) is True:
@@ -494,7 +493,7 @@ def get_worker_node_ips(config_file, override_cluster_name):
     if override_cluster_name is not None:
         config["cluster_name"] = override_cluster_name

-    provider = get_node_provider(config["provider"], config["cluster_name"])
+    provider = get_node_provider(config)
     try:
         nodes = provider.non_terminated_nodes({TAG_RAY_NODE_TYPE: "worker"})

@@ -510,7 +509,7 @@ def _get_head_node(config,
                    config_file,
                    override_cluster_name,
                    create_if_needed=False):
-    provider = get_node_provider(config["provider"], config["cluster_name"])
+    provider = get_node_provider(config)
     try:
         head_node_tags = {
             TAG_RAY_NODE_TYPE: "head",
6 changes: 3 additions & 3 deletions python/ray/autoscaler/kubernetes/config.py
@@ -10,7 +10,7 @@


 def bootstrap_kubernetes(config):
-    config["use_internal_ips"] = True
+    config["provider"]["use_internal_ips"] = True
     config = _configure_namespace(config)

     # TODO: make cluster role if not set
@@ -19,6 +19,6 @@ def bootstrap_kubernetes(config):


 def _configure_namespace(config):
-    if "namespace" not in config:
-        config["namespace"] = DEFAULT_NAMESPACE
+    if "namespace" not in config["provider"]:
+        config["provider"]["namespace"] = DEFAULT_NAMESPACE
     return config
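
Both Kubernetes-specific keys now live under config["provider"] rather than at the top level, matching where the example YAML (next file) declares them. A self-contained sketch of the net effect, with DEFAULT_NAMESPACE stubbed in ("ray" here is only a stand-in for the real constant in this module):

    DEFAULT_NAMESPACE = "ray"   # stand-in value for illustration

    def _configure_namespace(config):
        if "namespace" not in config["provider"]:
            config["provider"]["namespace"] = DEFAULT_NAMESPACE
        return config

    config = {"provider": {"type": "kubernetes"}}
    config["provider"]["use_internal_ips"] = True   # what bootstrap_kubernetes sets
    config = _configure_namespace(config)
    assert config["provider"] == {"type": "kubernetes",
                                  "use_internal_ips": True,
                                  "namespace": "ray"}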
4 changes: 2 additions & 2 deletions python/ray/autoscaler/kubernetes/example-full.yaml
Expand Up @@ -32,7 +32,6 @@ idle_timeout_minutes: 5
# Cloud-provider specific configuration.
provider:
type: kubernetes
use_internal_ips: true
namespace: ray

# Provider-specific config for the head node, e.g. instance type.
@@ -117,7 +116,8 @@ file_mounts: {
 initialization_commands: []

 # List of shell commands to run to set up nodes.
-setup_commands: []
+setup_commands:
+    - "pushd ray && git fetch && git checkout k8s && git reset --hard origin/k8s && popd"

 # Custom commands that will be run on the head node after common setup.
 head_setup_commands: []
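
Two separate edits land in this example file: use_internal_ips disappears because bootstrap_kubernetes now sets it under provider unconditionally (previous file), and setup_commands gains a command pinning nodes to a k8s branch of the ray checkout, which reads like a development convenience captured in the example at this point in history.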
10 changes: 6 additions & 4 deletions python/ray/autoscaler/node_provider.py
@@ -101,18 +101,20 @@ def load_class(path):
     return getattr(module, class_str)


-def get_node_provider(provider_config, cluster_name):
+def get_node_provider(config):
+    provider_config = config["provider"]
     if provider_config["type"] == "external":
         provider_cls = load_class(path=provider_config["module"])
-        return provider_cls(provider_config, cluster_name)
+        return provider_cls(provider_config, config["cluster_name"])

     importer = NODE_PROVIDERS.get(provider_config["type"])

     if importer is None:
         raise NotImplementedError("Unsupported node provider: {}".format(
             provider_config["type"]))
-    _, provider_cls = importer()
-    return provider_cls(provider_config, cluster_name)
+    bootstrap_config, provider_cls = importer()
+    bootstrap_config(config)
+    return provider_cls(provider_config, config["cluster_name"])


 def get_default_config(provider_config):
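
This is the change the commit title refers to: importers registered in NODE_PROVIDERS return a (bootstrap_config, provider_cls) pair, and the bootstrap hook now runs unconditionally before the provider is constructed, rather than being discarded via the old `_, provider_cls` unpacking. A self-contained sketch of the contract (the Fake* names are hypothetical, not Ray's real classes):

    # Minimal mock of the registry contract this diff establishes.
    NODE_PROVIDERS = {}

    class FakeNodeProvider:
        def __init__(self, provider_config, cluster_name):
            self.provider_config = provider_config
            self.cluster_name = cluster_name

    def bootstrap_fake(config):
        # Always invoked before the provider is constructed.
        config["provider"].setdefault("namespace", "default")
        return config

    NODE_PROVIDERS["fake"] = lambda: (bootstrap_fake, FakeNodeProvider)

    def get_node_provider(config):
        provider_config = config["provider"]
        bootstrap_config, provider_cls = NODE_PROVIDERS[provider_config["type"]]()
        bootstrap_config(config)   # the "always bootstrap" step
        return provider_cls(provider_config, config["cluster_name"])

    provider = get_node_provider({"provider": {"type": "fake"},
                                  "cluster_name": "demo"})
    assert provider.provider_config["namespace"] == "default"

As in the diff, the hook's return value is ignored; it mutates config in place, so callers holding a reference see the bootstrapped values.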