Skip to content

RSDK-7191: add mltraining wrappers #602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 67 additions & 5 deletions src/viam/app/ml_training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
from grpclib.client import Channel

from viam import logging
from viam.proto.app.data import Filter
from viam.proto.app.mltraining import (
CancelTrainingJobRequest,
DeleteCompletedTrainingJobRequest,
GetTrainingJobRequest,
GetTrainingJobResponse,
ListTrainingJobsRequest,
ListTrainingJobsResponse,
MLTrainingServiceStub,
ModelType,
SubmitCustomTrainingJobRequest,
SubmitCustomTrainingJobResponse,
SubmitTrainingJobRequest,
SubmitTrainingJobResponse,
TrainingJobMetadata,
TrainingStatus,
)
Expand Down Expand Up @@ -66,13 +70,62 @@ def __init__(self, channel: Channel, metadata: Mapping[str, str]):
async def submit_training_job(
    self,
    org_id: str,
    dataset_id: str,
    model_name: str,
    model_version: str,
    model_type: ModelType.ValueType,
    tags: List[str],
) -> str:
    """Submit a training job.

    Builds a ``SubmitTrainingJobRequest`` from the arguments and sends it
    through the ML training service stub together with the client metadata.

    Args:
        org_id (str): the id of the org to submit the training job to
        dataset_id (str): the id of the dataset
        model_name (str): the model name
        model_version (str): the model version
        model_type (ModelType.ValueType): the model type
        tags (List[str]): the tags

    Returns:
        str: the id of the training job
    """

    request = SubmitTrainingJobRequest(
        dataset_id=dataset_id,
        organization_id=org_id,
        model_name=model_name,
        model_version=model_version,
        model_type=model_type,
        tags=tags,
    )
    response: SubmitTrainingJobResponse = await self._ml_training_client.SubmitTrainingJob(request, metadata=self._metadata)
    return response.id

async def submit_custom_training_job(
    self, org_id: str, dataset_id: str, registry_item_id: str, model_name: str, model_version: str
) -> str:
    """Submit a custom training job.

    Builds a ``SubmitCustomTrainingJobRequest`` from the arguments and sends
    it through the ML training service stub together with the client metadata.

    Args:
        org_id (str): the id of the org to submit the training job to
        dataset_id (str): the id of the dataset
        registry_item_id (str): the id of the registry item
        model_name (str): the model name
        model_version (str): the model version

    Returns:
        str: the id of the training job
    """

    request = SubmitCustomTrainingJobRequest(
        dataset_id=dataset_id,
        registry_item_id=registry_item_id,
        organization_id=org_id,
        model_name=model_name,
        model_version=model_version,
    )
    response: SubmitCustomTrainingJobResponse = await self._ml_training_client.SubmitCustomTrainingJob(request, metadata=self._metadata)
    return response.id

async def get_training_job(self, id: str) -> TrainingJobMetadata:
"""Gets training job data.
Expand All @@ -83,7 +136,7 @@ async def get_training_job(self, id: str) -> TrainingJobMetadata:
id="INSERT YOUR JOB ID")

Args:
id (str): id of the requested training job.
id (str): the id of the requested training job.

Returns:
viam.proto.app.mltraining.TrainingJobMetadata: training job data.
Expand Down Expand Up @@ -140,3 +193,12 @@ async def cancel_training_job(self, id: str) -> None:

request = CancelTrainingJobRequest(id=id)
await self._ml_training_client.CancelTrainingJob(request, metadata=self._metadata)

async def delete_completed_training_job(self, id: str) -> None:
    """Delete a completed training job from the database, whether the job succeeded or failed.

    Args:
        id (str): the id of the training job
    """

    delete_request = DeleteCompletedTrainingJobRequest(id=id)
    await self._ml_training_client.DeleteCompletedTrainingJob(
        delete_request,
        metadata=self._metadata,
    )
15 changes: 13 additions & 2 deletions tests/mocks/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,7 @@ def __init__(self, job_id: str, training_metadata: TrainingJobMetadata):
async def SubmitTrainingJob(self, stream: Stream[SubmitTrainingJobRequest, SubmitTrainingJobResponse]) -> None:
request = await stream.recv_message()
assert request is not None
self.dataset_id = request.dataset_id
self.org_id = request.organization_id
self.model_name = request.model_name
self.model_version = request.model_version
Expand All @@ -1031,7 +1032,14 @@ async def SubmitTrainingJob(self, stream: Stream[SubmitTrainingJobRequest, Submi
await stream.send_message(SubmitTrainingJobResponse(id=self.job_id))

async def SubmitCustomTrainingJob(self, stream: Stream[SubmitCustomTrainingJobRequest, SubmitCustomTrainingJobResponse]) -> None:
    # Store the fields of the received request on the mock so tests can
    # inspect what the client sent, then reply with the mock's job id.
    request = await stream.recv_message()
    assert request is not None
    self.dataset_id = request.dataset_id
    self.registry_item_id = request.registry_item_id
    self.org_id = request.organization_id
    self.model_name = request.model_name
    self.model_version = request.model_version
    await stream.send_message(SubmitCustomTrainingJobResponse(id=self.job_id))

async def GetTrainingJob(self, stream: Stream[GetTrainingJobRequest, GetTrainingJobResponse]) -> None:
request = await stream.recv_message()
Expand All @@ -1055,7 +1063,10 @@ async def CancelTrainingJob(self, stream: Stream[CancelTrainingJobRequest, Cance
async def DeleteCompletedTrainingJob(
    self, stream: Stream[DeleteCompletedTrainingJobRequest, DeleteCompletedTrainingJobResponse]
) -> None:
    # Remember which job id the client asked to delete so tests can check it,
    # then acknowledge with an empty response.
    request = await stream.recv_message()
    assert request is not None
    self.delete_id = request.id
    await stream.send_message(DeleteCompletedTrainingJobResponse())


class MockBilling(BillingServiceBase):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,11 @@ async def test_list_datasets_by_organization_id(self, service: MockDataset):
datasets = await client.list_datasets_by_organization_id(ORG_ID)
assert service.org_id == ORG_ID
assert datasets == DATASETS

@pytest.mark.asyncio
async def test_rename_dataset(self, service: MockDataset):
    # Rename through the client, then confirm the mock saw the same id and name.
    async with ChannelFor([service]) as channel:
        data_client = DataClient(channel, DATA_SERVICE_METADATA)
        await data_client.rename_dataset(ID, NAME)
        assert service.id == ID
        assert service.name == NAME
26 changes: 25 additions & 1 deletion tests/test_ml_training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
ID = "id"
TRAINING_JOB_ID = "training-job-id"
CANCEL_ID = "cancel-id"
DELETE_ID = "delete-id"
JOB_ID = "job-id"
ORG_ID = "org-id"
DATASET_ID = "dataset-id"
REGISTRY_ITEM_ID = "registry-item-id"
MODEL_ID = "model-id"
MODEL_NAME = "model-name"
MODEL_VERSION = "model-version"
Expand Down Expand Up @@ -66,7 +69,21 @@ async def test_cancel_training_job(self, service: MockMLTraining):

@pytest.mark.asyncio
async def test_submit_training_job(self, service: MockMLTraining):
    # Submit through the client and check both the returned job id and the
    # request fields the mock service recorded.
    async with ChannelFor([service]) as channel:
        client = MLTrainingClient(channel, ML_TRAINING_SERVICE_METADATA)
        job_id = await client.submit_training_job(
            org_id=ORG_ID, dataset_id=DATASET_ID, model_name=MODEL_NAME, model_version=MODEL_VERSION, model_type=MODEL_TYPE, tags=TAGS
        )
        assert job_id == JOB_ID
        assert service.dataset_id == DATASET_ID
        assert service.org_id == ORG_ID
        assert service.model_name == MODEL_NAME
        assert service.model_version == MODEL_VERSION

@pytest.mark.asyncio
async def test_custom_submit_training_job(self, service: MockMLTraining):
    # Submit a custom job through the client and check both the returned job
    # id and the request fields the mock service recorded.
    async with ChannelFor([service]) as channel:
        client = MLTrainingClient(channel, ML_TRAINING_SERVICE_METADATA)
        job_id = await client.submit_custom_training_job(
            org_id=ORG_ID, dataset_id=DATASET_ID, registry_item_id=REGISTRY_ITEM_ID, model_name=MODEL_NAME, model_version=MODEL_VERSION
        )
        assert job_id == JOB_ID
        assert service.dataset_id == DATASET_ID
        assert service.registry_item_id == REGISTRY_ITEM_ID
        assert service.org_id == ORG_ID
        assert service.model_name == MODEL_NAME
        assert service.model_version == MODEL_VERSION

@pytest.mark.asyncio
async def test_get_training_job(self, service: MockMLTraining):
Expand All @@ -85,3 +102,10 @@ async def test_list_training_jobs(self, service: MockMLTraining):
assert training_jobs[0] == TRAINING_METADATA
assert service.training_status == TRAINING_STATUS
assert service.org_id == ORG_ID

@pytest.mark.asyncio
async def test_delete_completed_training_job(self, service: MockMLTraining):
    # Delete through the client, then confirm the mock recorded the same job id.
    async with ChannelFor([service]) as channel:
        ml_training_client = MLTrainingClient(channel, ML_TRAINING_SERVICE_METADATA)
        await ml_training_client.delete_completed_training_job(DELETE_ID)
        assert service.delete_id == DELETE_ID
Loading