Skip to content

Add Python docstrings for classes, attributes, and methods #129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,13 @@ def trainings(self) -> TrainingCollection:

def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
"""
Run a model in the format owner/name:version.
Run a model and wait for its output.

Args:
model_version: The model version to run, in the format `owner/name:version`
kwargs: The input to the model, as a dictionary
Returns:
The output of the model
"""
# Split model_version into owner, name, version in format owner/name:version
m = re.match(r"^(?P<model>[^/]+/[^:]+):(?P<version>.+)$", model_version)
Expand Down
3 changes: 1 addition & 2 deletions replicate/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@

class Collection(abc.ABC, Generic[Model]):
"""
A base class for representing all objects of a particular type on the
server.
A base class for representing objects of a particular type on the server.
"""

def __init__(self, client: "Client") -> None:
Expand Down
10 changes: 9 additions & 1 deletion replicate/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,16 @@

def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
"""
Lifted straight from cog.files
Upload a file to the server.

Args:
fh: A file handle to upload.
output_file_prefix: A string to prepend to the output file name.
Returns:
str: A URL to the uploaded file.
"""
# Lifted straight from cog.files

fh.seek(0)

if output_file_prefix is not None:
Expand Down
4 changes: 3 additions & 1 deletion replicate/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ def encode_json(
obj: Any, upload_file: Callable[[io.IOBase], str] # noqa: ANN401
) -> Any: # noqa: ANN401
"""
Returns a JSON-compatible version of the object. Effectively the same thing as cog.json.encode_json.
Return a JSON-compatible version of the object.
"""
# Effectively the same thing as cog.json.encode_json.

if isinstance(obj, dict):
return {key: encode_json(value, upload_file) for key, value in obj.items()}
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
Expand Down
28 changes: 28 additions & 0 deletions replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,35 @@


class Model(BaseModel):
"""
A machine learning model hosted on Replicate.
"""

username: str
"""
The name of the user or organization that owns the model.
"""

name: str
"""
The name of the model.
"""

def predict(self, *args, **kwargs) -> None:
"""
DEPRECATED: Use `version.predict()` instead.
"""

raise ReplicateException(
"The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme"
)

@property
def versions(self) -> VersionCollection:
"""
Get the versions of this model.
"""

return VersionCollection(client=self._client, model=self)


Expand All @@ -27,6 +46,15 @@ def list(self) -> List[Model]:
raise NotImplementedError()

def get(self, name: str) -> Model:
"""
Get a model by name.

Args:
name: The name of the model, in the format `owner/model-name`.
Returns:
The model.
"""

# TODO: fetch model from server
# TODO: support permanent IDs
username, name = name.split("/")
Expand Down
80 changes: 73 additions & 7 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,53 @@


class Prediction(BaseModel):
"""
A prediction made by a model hosted on Replicate.
"""

id: str
error: Optional[str]
"""The unique ID of the prediction."""

version: Optional[Version]
"""The version of the model used to create the prediction."""

status: str
"""The status of the prediction."""

input: Optional[Dict[str, Any]]
logs: Optional[str]
"""The input to the prediction."""

output: Optional[Any]
status: str
version: Optional[Version]
started_at: Optional[str]
"""The output of the prediction."""

logs: Optional[str]
"""The logs of the prediction."""

error: Optional[str]
"""The error encountered during the prediction, if any."""

created_at: Optional[str]
"""When the prediction was created."""

started_at: Optional[str]
"""When the prediction was started."""

completed_at: Optional[str]
"""When the prediction was completed, if finished."""

urls: Optional[Dict[str, str]]
"""
URLs associated with the prediction.

The following keys are available:
- `get`: A URL to fetch the prediction.
- `cancel`: A URL to cancel the prediction.
"""

def wait(self) -> None:
"""Wait for prediction to finish."""
"""
Wait for prediction to finish.
"""
while self.status not in ["succeeded", "failed", "canceled"]:
time.sleep(self._client.poll_interval)
self.reload()
Expand All @@ -48,14 +81,23 @@ def output_iterator(self) -> Iterator[Any]:
yield output

def cancel(self) -> None:
"""Cancel a currently running prediction"""
"""
Cancels a running prediction.
"""
self._client._request("POST", f"/v1/predictions/{self.id}/cancel")


class PredictionCollection(Collection):
model = Prediction

def list(self) -> List[Prediction]:
"""
List your predictions.

Returns:
A list of prediction objects.
"""

resp = self._client._request("GET", "/v1/predictions")
# TODO: paginate
predictions = resp.json()["results"]
Expand All @@ -65,6 +107,15 @@ def list(self) -> List[Prediction]:
return [self.prepare_model(obj) for obj in predictions]

def get(self, id: str) -> Prediction:
"""
Get a prediction by ID.

Args:
id: The ID of the prediction.
Returns:
Prediction: The prediction object.
"""

resp = self._client._request("GET", f"/v1/predictions/{id}")
obj = resp.json()
# HACK: resolve this? make it lazy somehow?
Expand All @@ -80,6 +131,21 @@ def create( # type: ignore
webhook_events_filter: Optional[List[str]] = None,
**kwargs,
) -> Prediction:
"""
Create a new prediction for the specified model version.

Args:
version: The model version to use for the prediction.
input: The input data for the prediction.
webhook: The URL to receive a POST request with prediction updates.
webhook_completed: The URL to receive a POST request when the prediction is completed.
webhook_events_filter: List of events to trigger webhooks.
stream: Set to True to enable streaming of prediction output.

Returns:
Prediction: The created prediction object.
"""

input = encode_json(input, upload_file=upload_file)
body = {
"version": version.id,
Expand Down
77 changes: 70 additions & 7 deletions replicate/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,51 @@


class Training(BaseModel):
completed_at: Optional[str]
created_at: Optional[str]
destination: Optional[str]
error: Optional[str]
"""
A training made for a model hosted on Replicate.
"""

id: str
"""The unique ID of the training."""

version: Optional[Version]
"""The version of the model used to create the training."""

destination: Optional[str]
"""The model destination of the training."""

status: str
"""The status of the training."""

input: Optional[Dict[str, Any]]
logs: Optional[str]
"""The input to the training."""

output: Optional[Any]
"""The output of the training."""

logs: Optional[str]
"""The logs of the training."""

error: Optional[str]
"""The error encountered during the training, if any."""

created_at: Optional[str]
"""When the training was created."""

started_at: Optional[str]
status: str
version: Optional[Version]
"""When the training was started."""

completed_at: Optional[str]
"""When the training was completed, if finished."""

urls: Optional[Dict[str, str]]
"""
URLs associated with the training.

The following keys are available:
- `get`: A URL to fetch the training.
- `cancel`: A URL to cancel the training.
"""

def cancel(self) -> None:
"""Cancel a running training"""
Expand All @@ -31,6 +65,13 @@ class TrainingCollection(Collection):
model = Training

def list(self) -> List[Training]:
"""
List your trainings.

Returns:
List[Training]: A list of training objects.
"""

resp = self._client._request("GET", "/v1/trainings")
# TODO: paginate
trainings = resp.json()["results"]
Expand All @@ -40,6 +81,15 @@ def list(self) -> List[Training]:
return [self.prepare_model(obj) for obj in trainings]

def get(self, id: str) -> Training:
"""
Get a training by ID.

Args:
id: The ID of the training.
Returns:
Training: The training object.
"""

resp = self._client._request(
"GET",
f"/v1/trainings/{id}",
Expand All @@ -58,6 +108,19 @@ def create( # type: ignore
webhook_events_filter: Optional[List[str]] = None,
**kwargs,
) -> Training:
"""
Create a new training using the specified model version as a base.

Args:
version: The ID of the base model version that you're using to train a new model version.
input: The input to the training.
destination: The desired model to push to in the format `{owner}/{model_name}`. This should be an existing model owned by the user or organization making the API request.
webhook: The URL to send a POST request to when the training is completed. Defaults to None.
webhook_events_filter: The events to send to the webhook. Defaults to None.
Returns:
The training object.
"""

input = encode_json(input, upload_file=upload_file)
body = {
"input": input,
Expand Down
30 changes: 29 additions & 1 deletion replicate/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,32 @@


class Version(BaseModel):
"""
A version of a model.
"""

id: str
"""The unique ID of the version."""

created_at: datetime.datetime
"""When the version was created."""

cog_version: str
"""The version of the Cog used to create the version."""

openapi_schema: dict
"""An OpenAPI description of the model inputs and outputs."""

def predict(self, **kwargs) -> Union[Any, Iterator[Any]]:
"""
Create a prediction using this model version.

Args:
kwargs: The input to the model.
Returns:
The output of the model.
"""

warnings.warn(
"version.predict() is deprecated. Use replicate.run() instead. It will be removed before version 1.0.",
DeprecationWarning,
Expand Down Expand Up @@ -57,7 +77,12 @@ def __init__(self, client: "Client", model: "Model") -> None:
# doesn't exist yet
def get(self, id: str) -> Version:
"""
Get a specific version.
Get a specific model version.

Args:
id: The version ID.
Returns:
The model version.
"""
resp = self._client._request(
"GET", f"/v1/models/{self._model.username}/{self._model.name}/versions/{id}"
Expand All @@ -70,6 +95,9 @@ def create(self, **kwargs) -> Version:
def list(self) -> List[Version]:
"""
Return a list of all versions for a model.

Returns:
List[Version]: A list of version objects.
"""
resp = self._client._request(
"GET", f"/v1/models/{self._model.username}/{self._model.name}/versions"
Expand Down