Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Add official Python 3.13 support to the CI workflow and package release.
- Add `examples/torch/mlm-metadata.yaml` example that provides a minimal metadata example for
a PyTorch model which can be validated using the MLM Schema without the need to be fully
compliant with the STAC Specification.
- Add `ModelDataVariable` to `stac_model` for corresponding `mlm:input` and `mlm:output` definitions as the JSON schema.
- Add `variables` properties to [Model Input Object](README.md#model-input-object)
to allow specifying the relevant data variables used by the model,
Expand Down
16 changes: 16 additions & 0 deletions README_STAC_MODEL.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ stac-model

This will make [this example item](./examples/item_basic.json) for an example model.

## Validating Model Metadata

An alternative use of `stac_model` is to validate config files containing model metadata using the `MLModelProperties` schema.

Given a YAML or JSON file with the structure in [examples/torch/mlm-metadata.yaml](./examples/torch/mlm-metadata.yaml), the model metadata can be validated as follows:

```python
import yaml
from stac_model.schema import MLModelProperties

with open("examples/mlm-metadata.yaml", "r", encoding="utf-8") as f:
metadata = yaml.safe_load(f)

MLModelProperties.model_validate(metadata["properties"])
```

## 📈 Releases

You can see the list of available releases on the [GitHub Releases][github-releases] page.
Expand Down
44 changes: 44 additions & 0 deletions examples/torch/mlm-metadata.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
$schema: "https://stac-extensions.github.io/mlm/v1.4.0/schema.json"
properties:
name: Unet
architecture: Unet
artifact_type: torch.export.save
framework: torch
framework_version: 2.7.0
accelerator: cuda
total_parameters: 1234567
tasks: [semantic-segmentation]
input:
- name: Imagery
bands: [red, blue, green]
input:
shape: [-1, 3, 512, 512]
dim_order: [batch, channel, height, width]
data_type: float32
pre_processing_function:
format: torch.export.load
expression:
id: export-transform
name: transforms
href: path/to/archive.pt2
type: "application/octet-stream; framework=pytorch; profile=ExportedProgram"
output:
- name: segmentation-output
tasks: [semantic-segmentation]
result:
shape: [-1, 6, 512, 512]
dim_order: [batch, classes, height, width]
data_type: float32
classification:classes:
- value: 0
name: Class 0
- value: 1
name: Class 1
- value: 2
name: Class 2
- value: 3
name: Class 3
- value: 4
name: Class 4
- value: 5
name: Class 5
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ dev-dependencies = [
"ruff<1.0.0,>=0.2.2",
"bump-my-version>=0.21",
"types-python-dateutil>=2.9.0.20241003",
"types-pyyaml>=6.0.12.20250516",
"requests>=2.32.4",
]

Expand Down
8 changes: 7 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pystac
import pytest
import yaml

from stac_model.base import JSON
from stac_model.examples import eurosat_resnet as make_eurosat_resnet
Expand Down Expand Up @@ -60,7 +61,12 @@ def mlm_validator(
@pytest.fixture
def mlm_example(request: "SubRequest") -> dict[str, JSON]:
with open(os.path.join(EXAMPLES_DIR, request.param), mode="r", encoding="utf-8") as example_file:
data = json.load(example_file)
if request.param.endswith(".json"):
data = json.load(example_file)
elif request.param.endswith(".yaml"):
data = yaml.safe_load(example_file)
else:
raise ValueError(f"Unsupported file format for example: {request.param}")
return cast(dict[str, JSON], data)


Expand Down
14 changes: 14 additions & 0 deletions tests/torch/test_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import os

import pytest

from stac_model.schema import MLModelProperties


@pytest.mark.parametrize(
"mlm_example",
[os.path.join("torch", "mlm-metadata.yaml")],
indirect=True,
)
def test_mlm_metadata_only_yaml_validation(mlm_example):
MLModelProperties.model_validate(mlm_example["properties"])
Loading