Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(controller): rename for finetune info #2995

Merged
merged 6 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions client/starwhale/api/_impl/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from starwhale.consts import DecoratorInjectAttr
from starwhale.base.context import Context
from starwhale.api._impl.model import build as build_starwhale_model
from starwhale.base.client.models.models import FineTune


# TODO: support arguments
Expand Down Expand Up @@ -117,10 +118,12 @@ def _register_ft(
needs=needs,
require_dataset=require_train_datasets or require_validation_datasets,
extra_kwargs=dict(
require_train_datasets=require_train_datasets,
require_validation_datasets=require_validation_datasets,
auto_build_model=auto_build_model,
),
built_in=True,
fine_tune=FineTune(
require_train_datasets=require_train_datasets,
require_validation_datasets=require_validation_datasets,
),
)(func)
setattr(func, DecoratorInjectAttr.FineTune, True)
9 changes: 8 additions & 1 deletion client/starwhale/api/_impl/job/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from starwhale.utils.error import NoSupportError
from starwhale.base.models.model import StepSpecClient
from starwhale.api._impl.evaluation import PipelineHandler
from starwhale.base.client.models.models import RuntimeResource, ParameterSignature
from starwhale.base.client.models.models import (
FineTune,
RuntimeResource,
ParameterSignature,
)


class Handler(StepSpecClient):
Expand Down Expand Up @@ -100,6 +104,7 @@ def register(
expose: int = 0,
require_dataset: bool = False,
built_in: bool = False,
fine_tune: FineTune | None = None,
) -> t.Callable:
"""Register a function as a handler. Enable the function execute by needs handler, run with gpu/cpu/mem resources in server side,
and control replicas of handler run.
Expand All @@ -119,6 +124,7 @@ def register(
If True, You must select datasets when executing on the server or cloud instance.
built_in: [bool, optional] A special flag to distinguish user defined args in handler function from the StarWhale ones.
This should always be False unless you know what it does.
fine_tune: [FineTune, optional The fine tune config for the handler. Default is None.

Example:
```python
Expand Down Expand Up @@ -195,6 +201,7 @@ def decorator(func: t.Callable) -> t.Callable:
require_dataset=require_dataset,
parameters_sig=parameters_sig,
ext_cmd_args=ext_cmd_args,
fine_tune=fine_tune,
)

cls._register(_handler, func)
Expand Down
17 changes: 15 additions & 2 deletions client/starwhale/base/client/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,11 @@ class Env(SwBaseModel):
value: Optional[str] = None


class FineTune(SwBaseModel):
require_train_datasets: Optional[bool] = None
require_validation_datasets: Optional[bool] = None


class ParameterSignature(SwBaseModel):
name: str
required: Optional[bool] = None
Expand All @@ -750,8 +755,7 @@ class StepSpec(SwBaseModel):
job_name: Optional[str] = None
show_name: str
require_dataset: Optional[bool] = None
require_train_datasets: Optional[bool] = None
require_validation_datasets: Optional[bool] = None
fine_tune: Optional[FineTune] = None
container_spec: Optional[ContainerSpec] = None
ext_cmd_args: Optional[str] = None
parameters_sig: Optional[List[ParameterSignature]] = None
Expand Down Expand Up @@ -914,6 +918,14 @@ class ExposedLinkVo(SwBaseModel):
link: str


class JobType(Enum):
evaluation = 'EVALUATION'
train = 'TRAIN'
fine_tune = 'FINE_TUNE'
serving = 'SERVING'
built_in = 'BUILT_IN'


class JobStatus(Enum):
created = 'CREATED'
ready = 'READY'
Expand All @@ -934,6 +946,7 @@ class JobVo(SwBaseModel):
model_version: str = Field(..., alias='modelVersion')
model: ModelVo
job_name: Optional[str] = Field(None, alias='jobName')
job_type: Optional[JobType] = Field(None, alias='jobType')
datasets: Optional[List[str]] = None
dataset_list: Optional[List[DatasetVo]] = Field(None, alias='datasetList')
runtime: RuntimeVo
Expand Down
24 changes: 17 additions & 7 deletions client/tests/sdk/test_job_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from starwhale.base.scheduler import Step, Scheduler, TaskExecutor
from starwhale.base.uri.project import Project
from starwhale.base.models.model import JobHandlers, StepSpecClient
from starwhale.base.client.models.models import RuntimeResource, ParameterSignature
from starwhale.base.client.models.models import (
FineTune,
RuntimeResource,
ParameterSignature,
)


class JobTestCase(unittest.TestCase):
Expand Down Expand Up @@ -782,9 +786,11 @@ def ft2(): ...
ext_cmd_args="",
extra_kwargs={
"auto_build_model": True,
"require_train_datasets": True,
"require_validation_datasets": True,
},
fine_tune=FineTune(
require_train_datasets=True,
require_validation_datasets=True,
),
),
],
"mock_user_module:ft2": [
Expand All @@ -803,9 +809,11 @@ def ft2(): ...
ext_cmd_args="",
extra_kwargs={
"auto_build_model": True,
"require_train_datasets": True,
"require_validation_datasets": True,
},
fine_tune=FineTune(
require_train_datasets=True,
require_validation_datasets=True,
),
),
StepSpecClient(
name="mock_user_module:ft2",
Expand All @@ -828,9 +836,11 @@ def ft2(): ...
ext_cmd_args="",
extra_kwargs={
"auto_build_model": False,
"require_train_datasets": False,
"require_validation_datasets": False,
},
fine_tune=FineTune(
require_train_datasets=False,
require_validation_datasets=False,
),
),
],
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import ai.starwhale.mlops.api.protocol.model.ModelVo;
import ai.starwhale.mlops.api.protocol.runtime.RuntimeVo;
import ai.starwhale.mlops.api.protocol.user.UserVo;
import ai.starwhale.mlops.domain.job.JobType;
import ai.starwhale.mlops.domain.job.status.JobStatus;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
Expand Down Expand Up @@ -60,6 +61,9 @@ public class JobVo implements Serializable {
@JsonProperty("jobName")
private String jobName;

@JsonProperty("jobType")
private JobType jobType;

@JsonProperty("datasets")
@Valid
private List<String> datasets;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ public JobVo convert(JobEntity jobEntity) throws ConvertException {

return JobVo.builder()
.id(idConvertor.convert(jobEntity.getId()))
.jobType(jobEntity.getType())
.uuid(jobEntity.getJobUuid())
.owner(UserVo.fromEntity(jobEntity.getOwner(), idConvertor))
.modelName(jobEntity.getModelName())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,15 @@ public class StepSpec {

@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public static class ExtraParams {
@JsonProperty("auto_build_model")
private Boolean autoBuildModel;
public static class FineTune {
@JsonProperty("require_train_datasets")
private Boolean requireTrainDatasets;
@JsonProperty("require_validation_datasets")
private Boolean requireValidationDatasets;
}

@JsonProperty("extra_kwargs")
private ExtraParams extraParams;
@JsonProperty("fine_tune")
private FineTune finetune;

@JsonProperty("container_spec")
ContainerSpec containerSpec;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ public void testParseFromYamlContent() throws JsonProcessingException {
+ " resources: []\n"
+ " name: mnist.evaluator:MNISTInference.ppl\n"
+ " replicas: 1\n"
+ " extra_kwargs:\n"
+ " auto_build_model: true\n"
+ " fine_tune:\n"
+ " require_train_datasets: true\n"
+ " require_validation_datasets: false\n"
+ " env:\n"
Expand All @@ -81,7 +80,7 @@ public void testParseFromYamlContent() throws JsonProcessingException {
new Env("EVAL_DATASET", "imagenet"),
new Env("EVAL_MODEL", "resnet50")
));
Assertions.assertTrue(stepMetaDatas.get(2).getExtraParams().getRequireTrainDatasets());
Assertions.assertFalse(stepMetaDatas.get(2).getExtraParams().getRequireValidationDatasets());
Assertions.assertTrue(stepMetaDatas.get(2).getFinetune().getRequireTrainDatasets());
Assertions.assertFalse(stepMetaDatas.get(2).getFinetune().getRequireValidationDatasets());
}
}
Loading