diff --git a/sdk/python/kubeflow/training/api/training_client.py b/sdk/python/kubeflow/training/api/training_client.py index 901a9e9028..507bd306d1 100644 --- a/sdk/python/kubeflow/training/api/training_client.py +++ b/sdk/python/kubeflow/training/api/training_client.py @@ -354,6 +354,7 @@ def create_job( env_vars: Optional[ Union[Dict[str, str], List[Union[models.V1EnvVar, models.V1EnvVar]]] ] = None, + pip_args: Optional[List[str]] = None, ): """Create the Training Job. Job can be created using one of the following options: @@ -418,7 +419,9 @@ def create_job( https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md) or a kubernetes.client.models.V1EnvFromSource (documented here: https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md) - + pip_args: List of args to pass to pip install that applies to all packages specified in + packages_to_install. For a full list of args, see the pip documentation + https://pip.pypa.io/en/stable/cli/pip_install/ Raises: ValueError: Invalid input parameters. TimeoutError: Timeout to create Job. @@ -486,6 +489,7 @@ def create_job( train_func_parameters=parameters, packages_to_install=packages_to_install, pip_index_url=pip_index_url, + pip_args=pip_args, ) # Get Training Container template. diff --git a/sdk/python/kubeflow/training/utils/utils.py b/sdk/python/kubeflow/training/utils/utils.py index 5389f10baf..31565bd2b8 100644 --- a/sdk/python/kubeflow/training/utils/utils.py +++ b/sdk/python/kubeflow/training/utils/utils.py @@ -110,13 +110,13 @@ def has_condition(conditions: List[models.V1JobCondition], condition_type: str) def get_script_for_python_packages( - packages_to_install: List[str], pip_index_url: str + packages_to_install: List[str], pip_index_url: str, pip_args: Optional[List[str]] ) -> str: """ Get init script to install Python packages from the given pip index URL. """ packages_str = " ".join([str(package) for package in packages_to_install]) - + pip_args_str = " ".join(pip_args) if pip_args is not None else "" script_for_python_packages = textwrap.dedent( f""" if ! [ -x "$(command -v pip)" ]; then @@ -124,7 +124,7 @@ def get_script_for_python_packages( fi PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet \ - --no-warn-script-location --index-url {pip_index_url} {packages_str} + --no-warn-script-location --index-url {pip_index_url} {pip_args_str} {packages_str} """ ) @@ -137,6 +137,7 @@ def get_command_using_train_func( train_func_parameters: Optional[Dict[str, Any]] = None, packages_to_install: Optional[List[str]] = None, pip_index_url: str = constants.DEFAULT_PIP_INDEX_URL, + pip_args: Optional[List[str]] = None ) -> Tuple[List[str], List[str]]: """ Get container args and command from the given training function and parameters. @@ -180,7 +181,7 @@ def get_command_using_train_func( # Install Python packages if that is required. if packages_to_install is not None: exec_script = ( - get_script_for_python_packages(packages_to_install, pip_index_url) + get_script_for_python_packages(packages_to_install, pip_index_url, pip_args) + exec_script )