Skip to content

Commit

Permalink
feat: env variables in notebook models (#59)
Browse files Browse the repository at this point in the history
## Describe your changes

## Issue ticket number and link

## Checklist before requesting a review
- [x] I have performed a self-review of my code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] If it is a core feature, I have added thorough tests.
  • Loading branch information
michalmrazek authored Nov 1, 2022
2 parents 5b34a11 + e5f165a commit e738788
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/wanna/core/deployment/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class JobResource(GCPResource, Generic[JOB]):
network: Optional[str]
job_config: JOB
encryption_spec: Optional[str]
environment_variables: Optional[Dict[str, str]]


class PushArtifact(BaseModel):
Expand Down
2 changes: 2 additions & 0 deletions src/wanna/core/models/gcp_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class GCPProfileModel(BaseModel, extra=Extra.forbid):
If you get an error, please grant the Service Account with the Cloud KMS CryptoKey Encrypter/Decrypter role
- `docker_repository` - [str] Wanna Docker Repository
- `docker_registry` - [str] (optional) Wanna Docker Registry, usually in format {region}-docker.pkg.dev
- `env_vars` - Dict[str, str] (optional) Environment variables to be propagated to all the notebooks and custom jobs
"""

profile_name: str
Expand All @@ -44,6 +45,7 @@ class GCPProfileModel(BaseModel, extra=Extra.forbid):
kms_key: Optional[str]
docker_repository: str = "wanna"
docker_registry: Optional[str]
env_vars: Optional[Dict[str, str]]

_ = validator("project_id", allow_reuse=True)(validators.validate_project_id)
_ = validator("zone", allow_reuse=True)(validators.validate_zone)
Expand Down
4 changes: 3 additions & 1 deletion src/wanna/core/models/notebook.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Dict, List, Optional

from pydantic import BaseModel, EmailStr, Extra, Field, root_validator, validator

Expand Down Expand Up @@ -50,6 +50,7 @@ class NotebookModel(BaseInstanceModel):
- `no_proxy_access` - [bool] (optional) If true, the notebook instance will not register with the proxy
- `idle_shutdown_timeout` - [int] (optional) Time in minutes, between 10 and 1440. After this time of inactivity,
notebook will be stopped. If the parameter is not set, we don't do anything.
- `env_vars` - Dict[str, str] (optional) Environment variables to be propagated to the notebook
- `backup` - [str] (optional) Name of the bucket where a data backup is copied (no 'gs://' needed in the name).
After creation, any changes (including deletion) made to the data disk contents will be synced to the GCS location
It’s recommended that you enable object versioning for the selected location so you can restore accidentally
Expand All @@ -74,6 +75,7 @@ class NotebookModel(BaseInstanceModel):
no_public_ip: bool = True
no_proxy_access: bool = False
idle_shutdown_timeout: Optional[int]
env_vars: Optional[Dict[str, str]]
backup: Optional[str]

_machine_type = validator("machine_type")(validators.validate_machine_type)
Expand Down
2 changes: 2 additions & 0 deletions src/wanna/core/models/training_custom_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class BaseCustomJobModel(BaseInstanceModel):
- `encryption_spec`- [str] (optional) The Cloud KMS resource identifier. Has the form:
projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key
The key needs to be in the same region as where the compute resource is created
- `env_vars` - Dict[str, str] (optional) Environment variables to be propagated to the job
"""

region: str
Expand All @@ -134,6 +135,7 @@ class BaseCustomJobModel(BaseInstanceModel):
tensorboard_ref: Optional[str]
timeout_seconds: int = 60 * 60 * 24 # 24 hours
encryption_spec: Optional[Any]
env_vars: Optional[Dict[str, str]]

@root_validator(pre=False)
def _set_base_output_directory_if_not_provided( # pylint: disable=no-self-argument,no-self-use
Expand Down
14 changes: 12 additions & 2 deletions src/wanna/core/services/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ def _create_custom_job_resource(
encryption_spec_key_name = (
job_model.encryption_spec if job_model.encryption_spec else self.config.gcp_profile.kms_key
)

env_vars = self.config.gcp_profile.env_vars if self.config.gcp_profile.env_vars else dict()
if job_model.env_vars:
env_vars = {**env_vars, **job_model.env_vars}
return JobResource[CustomJobModel](
name=job_model.name,
project=job_model.project_id,
Expand All @@ -242,6 +244,7 @@ def _create_custom_job_resource(
else None,
network=network,
encryption_spec=encryption_spec_key_name,
environment_variables=env_vars,
)

def _create_training_job_resource(
Expand Down Expand Up @@ -296,7 +299,9 @@ def _create_training_job_resource(
encryption_spec_key_name = (
job_model.encryption_spec if job_model.encryption_spec else self.config.gcp_profile.kms_key
)

env_vars = self.config.gcp_profile.env_vars if self.config.gcp_profile.env_vars else dict()
if job_model.env_vars:
env_vars = {**env_vars, **job_model.env_vars}
return JobResource[TrainingCustomJobModel](
name=job_model.name,
project=job_model.project_id,
Expand All @@ -309,6 +314,7 @@ def _create_training_job_resource(
else None,
network=network,
encryption_spec=encryption_spec_key_name,
environment_variables=env_vars,
)

def _create_worker_pool_spec(self, worker_pool_model: WorkerPoolModel) -> Tuple[str, WorkerPoolSpec]:
Expand Down Expand Up @@ -457,6 +463,9 @@ def write_manifest(
encryption_spec_key_name = (
resource.encryption_spec if resource.encryption_spec else self.config.gcp_profile.kms_key
)
env_vars = self.config.gcp_profile.env_vars if self.config.gcp_profile.env_vars else dict()
if resource.environment_variables:
env_vars = {**env_vars, **resource.environment_variables}
json_dict = {
"name": resource.name,
"project": resource.project,
Expand All @@ -467,6 +476,7 @@ def write_manifest(
"tensorboard": resource.tensorboard,
"network": resource.network,
"encryption_spec": encryption_spec_key_name,
"environment_variables": env_vars,
}
json_dump = json.dumps(
remove_nones(json_dict),
Expand Down
1 change: 1 addition & 0 deletions src/wanna/core/services/managed_notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def _prepare_startup_script(self, nb_instance: ManagedNotebookModel) -> str:
)
else:
tensorboard_resource_name = None

startup_script = templates.render_template(
Path("notebook_startup_script.sh.j2"),
tensorboard_resource_name=tensorboard_resource_name,
Expand Down
8 changes: 8 additions & 0 deletions src/wanna/core/services/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ def _create_instance_request(
notebook_instance.bucket_mounts
or notebook_instance.tensorboard_ref
or notebook_instance.idle_shutdown_timeout
or self.config.gcp_profile.env_vars
or notebook_instance.env_vars
):
script = self._prepare_startup_script(self.instances[0])
blob = upload_string_to_gcs(
Expand Down Expand Up @@ -340,17 +342,23 @@ def _prepare_startup_script(self, nb_instance: NotebookModel) -> str:
Returns:
startup_script
"""
env_vars = self.config.gcp_profile.env_vars if self.config.gcp_profile.env_vars else dict()
if nb_instance.env_vars:
env_vars = {**env_vars, **nb_instance.env_vars}

if nb_instance.tensorboard_ref:
tensorboard_resource_name = self.tensorboard_service.get_or_create_tensorboard_instance_by_name(
nb_instance.tensorboard_ref
)
else:
tensorboard_resource_name = None

startup_script = templates.render_template(
Path("notebook_startup_script.sh.j2"),
bucket_mounts=nb_instance.bucket_mounts,
tensorboard_resource_name=tensorboard_resource_name,
idle_shutdown_timeout=nb_instance.idle_shutdown_timeout,
env_vars=env_vars,
)
return startup_script

Expand Down
9 changes: 8 additions & 1 deletion src/wanna/core/templates/notebook_startup_script.sh.j2
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
#!/bin/bash
{% if bucket_mounts is not none %}

{% if env_vars is not none %}
echo '#!/bin/bash' | sudo tee /etc/profile.d/myenvvars.sh
{% for key, value in env_vars.items() %}
echo 'export {{ key }}={{ value }}' | sudo tee -a /etc/profile.d/myenvvars.sh
{% endfor %}
{% endif %}

{% if bucket_mounts is not none %}
{% for bucket_mount in bucket_mounts %}
sudo mkdir -p {{ bucket_mount["mount_path"] }}/{{ bucket_mount["bucket_name"] }}
sudo chmod -R 777 {{ bucket_mount["mount_path"] }}/{{ bucket_mount["bucket_name"] }}
Expand Down

0 comments on commit e738788

Please sign in to comment.