Skip to content

Add training #82

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions replicate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
run = default_client.run
models = default_client.models
predictions = default_client.predictions
trainings = default_client.trainings
5 changes: 5 additions & 0 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from replicate.exceptions import ModelError, ReplicateError
from replicate.model import ModelCollection
from replicate.prediction import PredictionCollection
from replicate.training import TrainingCollection


class Client:
Expand Down Expand Up @@ -107,6 +108,10 @@ def models(self) -> ModelCollection:
def predictions(self) -> PredictionCollection:
return PredictionCollection(client=self)

@property
def trainings(self) -> TrainingCollection:
return TrainingCollection(client=self)

def run(self, model_version, **kwargs) -> Union[Any, Iterator[Any]]:
"""
Run a model in the format owner/name:version.
Expand Down
78 changes: 78 additions & 0 deletions replicate/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import re
import time
from typing import Any, Dict, Iterator, List, Optional

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.exceptions import ModelError, ReplicateException
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.version import Version


class Training(BaseModel):
completed_at: Optional[str]
created_at: Optional[str]
destination: Optional[str]
error: Optional[str]
id: str
input: Optional[Dict[str, Any]]
logs: Optional[str]
output: Optional[Any]
started_at: Optional[str]
status: str
version: str

def cancel(self):
"""Cancel a running training"""
self._client._request("POST", f"/v1/trainings/{self.id}/cancel")


class TrainingCollection(Collection):
model = Training

def create(
self,
version: str,
input: Dict[str, Any],
destination: str,
webhook: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
) -> Training:
input = encode_json(input, upload_file=upload_file)
body = {
"input": input,
"destination": destination,
}
if webhook is not None:
body["webhook"] = webhook
if webhook_events_filter is not None:
body["webhook_events_filter"] = webhook_events_filter

# Split version in format "username/model_name:version_id"
match = re.match(
r"^(?P<username>[^/]+)/(?P<model_name>[^:]+):(?P<version_id>.+)$", version
)
if not match:
raise ReplicateException(
f"version must be in format username/model_name:version_id"
)
username = match.group("username")
model_name = match.group("model_name")
version_id = match.group("version_id")

resp = self._client._request(
"POST",
f"/v1/models/{username}/{model_name}/versions/{version_id}/trainings",
json=body,
)
obj = resp.json()
return self.prepare_model(obj)

def get(self, id: str) -> Training:
resp = self._client._request(
"GET",
f"/v1/trainings/{id}",
)
obj = resp.json()
return self.prepare_model(obj)