diff --git a/src/wanna/core/deployment/models.py b/src/wanna/core/deployment/models.py index a4267db..1297ed8 100644 --- a/src/wanna/core/deployment/models.py +++ b/src/wanna/core/deployment/models.py @@ -93,6 +93,7 @@ class JobResource(GCPResource, Generic[JOB]): network: Optional[str] job_config: JOB encryption_spec: Optional[str] + environment_variables: Optional[Dict[str, str]] class PushArtifact(BaseModel): diff --git a/src/wanna/core/models/gcp_profile.py b/src/wanna/core/models/gcp_profile.py index 847585c..23acc60 100644 --- a/src/wanna/core/models/gcp_profile.py +++ b/src/wanna/core/models/gcp_profile.py @@ -30,6 +30,7 @@ class GCPProfileModel(BaseModel, extra=Extra.forbid): If you get an error, please grant the Service Account with the Cloud KMS CryptoKey Encrypter/Decrypter role - `docker_repository` - [str] Wanna Docker Repository - `docker_registry` - [str] (optional) Wanna Docker Registry, usually in format {region}-docker.pkg.dev + - `env_vars` - Dict[str, str] (optional) Environment variables to be propagated to all the notebooks and custom jobs """ profile_name: str @@ -44,6 +45,7 @@ class GCPProfileModel(BaseModel, extra=Extra.forbid): kms_key: Optional[str] docker_repository: str = "wanna" docker_registry: Optional[str] + env_vars: Optional[Dict[str, str]] _ = validator("project_id", allow_reuse=True)(validators.validate_project_id) _ = validator("zone", allow_reuse=True)(validators.validate_zone) diff --git a/src/wanna/core/models/notebook.py b/src/wanna/core/models/notebook.py index 70f6aac..70b5269 100644 --- a/src/wanna/core/models/notebook.py +++ b/src/wanna/core/models/notebook.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Dict, List, Optional from pydantic import BaseModel, EmailStr, Extra, Field, root_validator, validator @@ -50,6 +50,7 @@ class NotebookModel(BaseInstanceModel): - `no_proxy_access` - [bool] (optional) If true, the notebook instance will not register with the proxy - `idle_shutdown_timeout` - [int] (optional) Time in minutes, between 10 and 1440. After this time of inactivity, notebook will be stopped. If the parameter is not set, we don't do anything. + - `env_vars` - Dict[str, str] (optional) Environment variables to be propagated to the notebook - `backup` - [str] (optional) Name of the bucket where a data backup is copied (no 'gs://' needed in the name). After creation, any changes (including deletion) made to the data disk contents will be synced to the GCS location It’s recommended that you enable object versioning for the selected location so you can restore accidentally @@ -74,6 +75,7 @@ class NotebookModel(BaseInstanceModel): no_public_ip: bool = True no_proxy_access: bool = False idle_shutdown_timeout: Optional[int] + env_vars: Optional[Dict[str, str]] backup: Optional[str] _machine_type = validator("machine_type")(validators.validate_machine_type) diff --git a/src/wanna/core/models/training_custom_job.py b/src/wanna/core/models/training_custom_job.py index 5dc6da2..ac71f1f 100644 --- a/src/wanna/core/models/training_custom_job.py +++ b/src/wanna/core/models/training_custom_job.py @@ -125,6 +125,7 @@ class BaseCustomJobModel(BaseInstanceModel): - `encryption_spec`- [str] (optional) The Cloud KMS resource identifier. Has the form: projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key The key needs to be in the same region as where the compute resource is created + - `env_vars` - Dict[str, str] (optional) Environment variables to be propagated to the job """ region: str @@ -134,6 +135,7 @@ class BaseCustomJobModel(BaseInstanceModel): tensorboard_ref: Optional[str] timeout_seconds: int = 60 * 60 * 24 # 24 hours encryption_spec: Optional[Any] + env_vars: Optional[Dict[str, str]] @root_validator(pre=False) def _set_base_output_directory_if_not_provided( # pylint: disable=no-self-argument,no-self-use diff --git a/src/wanna/core/services/jobs.py b/src/wanna/core/services/jobs.py index 8a903bf..5ccc238 100644 --- a/src/wanna/core/services/jobs.py +++ b/src/wanna/core/services/jobs.py @@ -219,7 +219,9 @@ def _create_custom_job_resource( encryption_spec_key_name = ( job_model.encryption_spec if job_model.encryption_spec else self.config.gcp_profile.kms_key ) - + env_vars = self.config.gcp_profile.env_vars if self.config.gcp_profile.env_vars else dict() + if job_model.env_vars: + env_vars = {**env_vars, **job_model.env_vars} return JobResource[CustomJobModel]( name=job_model.name, project=job_model.project_id, @@ -242,6 +244,7 @@ def _create_custom_job_resource( else None, network=network, encryption_spec=encryption_spec_key_name, + environment_variables=env_vars, ) def _create_training_job_resource( @@ -296,7 +299,9 @@ def _create_training_job_resource( encryption_spec_key_name = ( job_model.encryption_spec if job_model.encryption_spec else self.config.gcp_profile.kms_key ) - + env_vars = self.config.gcp_profile.env_vars if self.config.gcp_profile.env_vars else dict() + if job_model.env_vars: + env_vars = {**env_vars, **job_model.env_vars} return JobResource[TrainingCustomJobModel]( name=job_model.name, project=job_model.project_id, @@ -309,6 +314,7 @@ def _create_training_job_resource( else None, network=network, encryption_spec=encryption_spec_key_name, + environment_variables=env_vars, ) def _create_worker_pool_spec(self, worker_pool_model: WorkerPoolModel) -> Tuple[str, WorkerPoolSpec]: @@ -457,6 +463,9 @@ def write_manifest( encryption_spec_key_name = ( resource.encryption_spec if resource.encryption_spec else self.config.gcp_profile.kms_key ) + env_vars = self.config.gcp_profile.env_vars if self.config.gcp_profile.env_vars else dict() + if resource.environment_variables: + env_vars = {**env_vars, **resource.environment_variables} json_dict = { "name": resource.name, "project": resource.project, @@ -467,6 +476,7 @@ def write_manifest( "tensorboard": resource.tensorboard, "network": resource.network, "encryption_spec": encryption_spec_key_name, + "environment_variables": env_vars, } json_dump = json.dumps( remove_nones(json_dict), diff --git a/src/wanna/core/services/managed_notebook.py b/src/wanna/core/services/managed_notebook.py index e8b89a2..3e294c0 100644 --- a/src/wanna/core/services/managed_notebook.py +++ b/src/wanna/core/services/managed_notebook.py @@ -303,6 +303,7 @@ def _prepare_startup_script(self, nb_instance: ManagedNotebookModel) -> str: ) else: tensorboard_resource_name = None + startup_script = templates.render_template( Path("notebook_startup_script.sh.j2"), tensorboard_resource_name=tensorboard_resource_name, diff --git a/src/wanna/core/services/notebook.py b/src/wanna/core/services/notebook.py index 3e2f7e9..06b8883 100644 --- a/src/wanna/core/services/notebook.py +++ b/src/wanna/core/services/notebook.py @@ -272,6 +272,8 @@ def _create_instance_request( notebook_instance.bucket_mounts or notebook_instance.tensorboard_ref or notebook_instance.idle_shutdown_timeout + or self.config.gcp_profile.env_vars + or notebook_instance.env_vars ): script = self._prepare_startup_script(self.instances[0]) blob = upload_string_to_gcs( @@ -340,17 +342,23 @@ def _prepare_startup_script(self, nb_instance: NotebookModel) -> str: Returns: startup_script """ + env_vars = self.config.gcp_profile.env_vars if self.config.gcp_profile.env_vars else dict() + if nb_instance.env_vars: + env_vars = {**env_vars, **nb_instance.env_vars} + if nb_instance.tensorboard_ref: tensorboard_resource_name = self.tensorboard_service.get_or_create_tensorboard_instance_by_name( nb_instance.tensorboard_ref ) else: tensorboard_resource_name = None + startup_script = templates.render_template( Path("notebook_startup_script.sh.j2"), bucket_mounts=nb_instance.bucket_mounts, tensorboard_resource_name=tensorboard_resource_name, idle_shutdown_timeout=nb_instance.idle_shutdown_timeout, + env_vars=env_vars, ) return startup_script diff --git a/src/wanna/core/templates/notebook_startup_script.sh.j2 b/src/wanna/core/templates/notebook_startup_script.sh.j2 index e1f2e30..9677e27 100644 --- a/src/wanna/core/templates/notebook_startup_script.sh.j2 +++ b/src/wanna/core/templates/notebook_startup_script.sh.j2 @@ -1,6 +1,13 @@ #!/bin/bash -{% if bucket_mounts is not none %} +{% if env_vars is not none %} +echo '#!/bin/bash' | sudo tee /etc/profile.d/myenvvars.sh +{% for key, value in env_vars.items() %} +echo 'export {{ key }}={{ value }}' | sudo tee -a /etc/profile.d/myenvvars.sh +{% endfor %} +{% endif %} + +{% if bucket_mounts is not none %} {% for bucket_mount in bucket_mounts %} sudo mkdir -p {{ bucket_mount["mount_path"] }}/{{ bucket_mount["bucket_name"] }} sudo chmod -R 777 {{ bucket_mount["mount_path"] }}/{{ bucket_mount["bucket_name"] }}