fix: Expose BuiltinTrainer API to users #28

Open · wants to merge 5 commits into main
10 changes: 6 additions & 4 deletions python/kubeflow/trainer/api/trainer_client.py
@@ -18,7 +18,7 @@
 import random
 import string
 import uuid
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 from kubeflow_trainer_api import models
 from kubeflow.trainer.constants import constants
@@ -153,20 +153,22 @@ def train(
         self,
         runtime: types.Runtime = types.DEFAULT_RUNTIME,
         initializer: Optional[types.Initializer] = None,
-        trainer: Optional[types.CustomTrainer] = None,
+        trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
     ) -> str:
         """
         Create the TrainJob. You can configure these types of training task:

         - Custom Training Task: Training with a self-contained function that encapsulates
             the entire model training process, e.g. `CustomTrainer`.
+        - Config-driven Task with Existing Trainer: Training with a trainer that already includes
+            the post-training logic, requiring only parameter adjustments, e.g. `BuiltinTrainer`.

         Args:
             runtime (`types.Runtime`): Reference to one of existing Runtimes.
             initializer (`Optional[types.Initializer]`):
                 Configuration for the dataset and model initializers.
-            trainer (`Optional[types.CustomTrainer]`):
-                Configuration for Custom Training Task.
+            trainer (`Optional[types.CustomTrainer, types.BuiltinTrainer]`):
+                Configuration for Custom Training Task or Config-driven Task with Existing Trainer.

         Returns:
             str: The unique name of the TrainJob that has been generated.
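As a rough usage sketch (not part of this diff): with the widened signature, a config-driven job would be submitted roughly as below. The `BuiltinTrainer`/`TorchTuneConfig` field names and the runtime name are illustrative assumptions; in practice the runtime would come from the client's runtime listing.

    from kubeflow.trainer import TrainerClient, types

    client = TrainerClient()
    job_name = client.train(
        runtime=types.Runtime(name="torchtune-llama3.2-1b"),  # placeholder runtime name
        trainer=types.BuiltinTrainer(
            config=types.TorchTuneConfig(),  # illustrative; only tuning parameters go here
        ),
    )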
3 changes: 3 additions & 0 deletions python/kubeflow/trainer/constants/constants.py
@@ -52,6 +52,9 @@
 # Also, it represents the `trainjob-ancestor-step` label value for the model initializer step.
 MODEL_INITIALIZER = "model-initializer"

+# The env name for the access token of dataset/model initializer.
+INITIALIZER_ENV_ACCESS_TOKEN = "ACCESS_TOKEN"
+
 # The default path to the users' workspace.
 # TODO (andreyvelich): Discuss how to keep this path is sync with pkg.initializers.constants
 WORKSPACE_PATH = "/workspace"
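For context, a hedged sketch of how an initializer process could consume this variable (the real initializer code lives under pkg.initializers and is not shown here; the Hugging Face download call is just one possible consumer):

    import os

    # Illustrative only: pick up the token injected by the SDK (see utils.py below).
    access_token = os.environ.get("ACCESS_TOKEN")
    if access_token:
        # e.g. forward it to huggingface_hub.snapshot_download(..., token=access_token)
        pass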
2 changes: 1 addition & 1 deletion python/kubeflow/trainer/types/types.py
@@ -175,7 +175,7 @@ class Trainer:
 @dataclass
 class Runtime:
     name: str
-    trainer: Trainer
+    trainer: Optional[Trainer] = None
     pretrained_model: Optional[str] = None

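With `trainer` now optional, a `Runtime` can be constructed from its name alone, e.g. `types.Runtime(name="torchtune-llama3.2-1b")` (placeholder name); code that still requires a trainer has to handle `None`, which is what the new guard in utils.py below does.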
29 changes: 22 additions & 7 deletions python/kubeflow/trainer/utils/utils.py
@@ -256,6 +256,10 @@ def get_entrypoint_using_train_func(
     """
     Get the Trainer command and args from the given training function and parameters.
     """
+    # Check if the runtime has a trainer.
+    if not runtime.trainer:
+        raise ValueError(f"Runtime must have a trainer: {runtime}")
+
     # Check if training function is callable.
     if not callable(train_func):
         raise ValueError(
@@ -365,7 +369,8 @@ def get_args_using_torchtune_config(
         else initializer.dataset.storage_uri
     )
     storage_uri_parsed = urlparse(storage_uri)
-    relative_path = re.sub(r"^/[^/]+", "", storage_uri_parsed.path)
+    parts = storage_uri_parsed.path.strip("/").split("/")
+    relative_path = "/".join(parts[1:]) if len(parts) > 1 else "."

     if "." in relative_path:
         args.append(
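A quick worked illustration of the new parsing (a standalone sketch; `relative_path_of` is a hypothetical helper, not part of the diff). `urlparse` puts the `hf://` org into `netloc`, so the first path segment is the repo name and gets dropped; when nothing is left, `"."` is returned where the old regex yielded an empty string:

    from urllib.parse import urlparse

    def relative_path_of(storage_uri: str) -> str:
        parts = urlparse(storage_uri).path.strip("/").split("/")
        return "/".join(parts[1:]) if len(parts) > 1 else "."

    # "hf://tatsu-lab/alpaca"                 -> path "/alpaca"            -> "."
    # "hf://tatsu-lab/alpaca/data/train.json" -> "/alpaca/data/train.json" -> "data/train.json"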
@@ -493,7 +498,13 @@ def get_dataset_initializer(
             dataset.storage_uri
             if dataset.storage_uri.startswith("hf://")
             else "hf://" + dataset.storage_uri
-        )
+        ),
+        env=[
+            models.IoK8sApiCoreV1EnvVar(
+                name=constants.INITIALIZER_ENV_ACCESS_TOKEN,
+                value=dataset.access_token,
+            ),
+        ]
     )

     return dataset_initializer
@@ -514,7 +525,13 @@ def get_model_initializer(
             model.storage_uri
             if model.storage_uri.startswith("hf://")
             else "hf://" + model.storage_uri
-        )
+        ),
+        env=[
+            models.IoK8sApiCoreV1EnvVar(
+                name=constants.INITIALIZER_ENV_ACCESS_TOKEN,
+                value=model.access_token,
+            ),
+        ]
     )

     return model_initializer
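Illustrative effect of the two env blocks above (a sketch; the `Initializer`/`HuggingFaceDatasetInitializer` field names are assumed from the existing types and the token is a fake placeholder): when the user supplies an access token, the generated initializer spec now carries it as the `ACCESS_TOKEN` env var.

    initializer = types.Initializer(
        dataset=types.HuggingFaceDatasetInitializer(
            storage_uri="hf://tatsu-lab/alpaca",
            access_token="hf_xxx",  # placeholder
        ),
    )
    # get_dataset_initializer(initializer.dataset) then emits, roughly:
    #   env:
    #   - name: ACCESS_TOKEN
    #     value: hf_xxx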
@@ -559,12 +576,10 @@ def get_args_in_dataset_preprocess_config(
     if dataset_preprocess_config.source:
         if not isinstance(dataset_preprocess_config.source, types.DataFormat):
             raise ValueError(
-                f"Invalid data format: {dataset_preprocess_config.source}."
+                f"Invalid data format: {dataset_preprocess_config.source.value}."
            )

-        args.append(f"dataset.source={dataset_preprocess_config.source}")
-
-        # Override the data dir or data files if it is provided.
+        args.append(f"dataset.source={dataset_preprocess_config.source.value}")

     # Override the split field if it is provided.
     if dataset_preprocess_config.split:
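Why `.value`: with a plain `Enum`, formatting the member itself yields its qualified name rather than the bare string torchtune expects. A minimal stand-in (the real `DataFormat` lives in types.py and is not shown in this diff):

    from enum import Enum

    class DataFormat(Enum):  # illustrative stand-in
        json = "json"

    f"dataset.source={DataFormat.json}"        # -> "dataset.source=DataFormat.json"
    f"dataset.source={DataFormat.json.value}"  # -> "dataset.source=json"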