
Feature/registry ml archivers #346

Merged
vik-rant merged 15 commits into dev from feature/registry-ml-archivers on Feb 25, 2026

Conversation

@Yasserelhaddar
Collaborator

Summary

Adds archivers for popular ML frameworks to the mindtrace-registry, enabling seamless save/load of models via the Registry system.

Why We Need This

The Registry currently supports basic types and Ultralytics models. ML workflows commonly use models from other frameworks that require manual serialization:

  • HuggingFace Transformers - Vision models and LLMs, with PEFT/LoRA adapter support
  • timm - PyTorch Image Models for classification/feature extraction
  • ONNX - Framework-agnostic model interchange format
  • TensorRT - Optimized inference engines for NVIDIA GPUs

These archivers provide:

  • Unified API - Same registry.save() / registry.load() for all model types
  • Config preservation - Architecture, hyperparameters, and metadata automatically extracted
  • Adapter support - PEFT/LoRA weights saved alongside base models
  • Optional dependencies - Graceful handling when frameworks aren't installed
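As a hedged sketch of the last point, the graceful-handling pattern can be as simple as a helper that turns a missing optional import into an actionable error. The helper name and the extras names below are illustrative, not the actual mindtrace-registry API:

```python
import importlib


def require(module_name: str, extra: str):
    """Return the imported module, or raise an ImportError naming the extra to install.

    Illustrative helper; not the real mindtrace-registry internals.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"'{module_name}' is required for this archiver; "
            f"install it with `pip install mindtrace-registry[{extra}]`"
        ) from exc


# An archiver's save()/load() would call require("timm", "timm") up front,
# so users see an install hint instead of a bare ModuleNotFoundError.
```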

Usage Examples

HuggingFace Model

from transformers import AutoModelForImageClassification
from mindtrace.registry import Registry
import mindtrace.registry.archivers.huggingface  # Register archivers

model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
# Fine-tune with PEFT/LoRA...

registry = Registry()
registry.save("vit-classifier:v1", model)  # Saves model + adapter weights

loaded = registry.load("vit-classifier:v1")

timm Model

import timm
from mindtrace.registry import Registry
import mindtrace.registry.archivers.timm

model = timm.create_model("resnet50", pretrained=True, num_classes=10)

registry = Registry()
registry.save("resnet:v1", model)
loaded = registry.load("resnet:v1")

ONNX Model

import onnx
from mindtrace.registry import Registry
import mindtrace.registry.archivers.onnx

model = onnx.load("exported_model.onnx")

registry = Registry()
registry.save("onnx-model:v1", model)
loaded = registry.load("onnx-model:v1")

TensorRT Engine

import tensorrt as trt
from mindtrace.registry import Registry
import mindtrace.registry.archivers.tensorrt

# engine = ... (build or load TensorRT engine)

registry = Registry()
registry.save("trt-engine:v1", engine)
loaded = registry.load("trt-engine:v1")

What's Included

Archiver                      | Types Handled          | Serialization Format
HuggingFaceModelArchiver      | PreTrainedModel + PEFT | HF format + adapter/ dir
HuggingFaceProcessorArchiver  | Tokenizers, Processors | HF format
TimmModelArchiver             | timm models            | config.json + model.pt
OnnxModelArchiver             | onnx.ModelProto        | model.onnx + metadata.json
TensorRTEngineArchiver        | trt.ICudaEngine        | engine.trt + metadata.json
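The archivers above share a simple contract: save() writes one or more artifacts under the archiver's URI, and load() reconstructs the object from them. A minimal sketch of that contract, using the config.json artifact from the timm row (the class shape is an assumption, not the real Archiver base class):

```python
import json
from pathlib import Path


class ConfigArchiverSketch:
    """Illustrative save/load contract; not the real mindtrace Archiver base."""

    def __init__(self, uri: str):
        self.uri = Path(uri)

    def save(self, config: dict) -> None:
        # Each archiver writes its artifacts (here just config.json) under the URI.
        self.uri.mkdir(parents=True, exist_ok=True)
        (self.uri / "config.json").write_text(json.dumps(config))

    def load(self) -> dict:
        # load() restores the object from the same artifacts.
        return json.loads((self.uri / "config.json").read_text())
```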

Dependencies

  • Added optional dependency groups in pyproject.toml:
    • huggingface: transformers, peft
    • timm: timm
    • onnx: onnx
    • tensorrt: tensorrt

Notes

  • TensorRT engines are GPU-specific - An engine built on one GPU architecture may not work on another
  • Optional imports - All archivers gracefully handle missing dependencies with clear error messages
  • Auto-registration - Importing the archiver module automatically registers it with the Registry
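The auto-registration point can be sketched as a module-level dispatch table keyed by type: importing the archiver module runs the registration call at import time, and the registry later resolves an archiver by walking the object's MRO. All names here are illustrative, not the real mindtrace internals:

```python
# Illustrative dispatch table; names do not match the real implementation.
_ARCHIVERS: dict = {}


def register_archiver(associated_type: type, archiver_cls: type) -> None:
    _ARCHIVERS[associated_type] = archiver_cls


def resolve_archiver(obj: object) -> type:
    # Walk the MRO so subclasses dispatch to the most specific registration.
    for cls in type(obj).__mro__:
        if cls in _ARCHIVERS:
            return _ARCHIVERS[cls]
    raise TypeError(f"no archiver registered for {type(obj).__name__}")


class DictArchiver: ...


# This line runs at import time, which is why importing the module is enough.
register_archiver(dict, DictArchiver)
```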

Add archivers for HuggingFace transformers models and processors:

- HuggingFaceModelArchiver: handles all PreTrainedModel subclasses
  - Dynamic architecture detection from config.json
  - PEFT/LoRA adapter save/load support
  - Auto-registers for PreTrainedModel base class

- HuggingFaceProcessorArchiver: handles tokenizers and processors
  - Supports AutoProcessor, AutoTokenizer, AutoImageProcessor
  - Auto-detects processor type on load
  - Registers for PreTrainedTokenizerBase, ProcessorMixin, etc.

Includes unit tests for both archivers.
Tested with ViT, Mask2Former, and BERT models.
- Add TimmModelArchiver for PyTorch Image Models (timm) library
- Saves model config (architecture, num_classes, etc.) and state_dict
- Loads models using timm.create_model() with saved configuration
- Add optional dependencies in pyproject.toml for ML frameworks:
  - timm, huggingface (transformers+peft), ultralytics, ml (all)
- Add unit tests with full coverage
- Add OnnxModelArchiver for ONNX models (ModelProto)
- Saves model.onnx file and metadata.json with opset info
- Extracts input/output names, producer info, graph metadata
- Add onnx + onnxruntime to optional dependencies
- Add unit tests with full coverage
- Add TensorRTEngineArchiver for TensorRT inference engines
- Saves engine.trt file and metadata.json with binding info
- Extracts input/output tensor names, shapes, dtypes
- Add CUDA-specific optional deps: tensorrt-cu11, tensorrt-cu12
- Remove generic 'ml' extra (too hardware-specific)
- Add unit tests (skip when TensorRT unavailable)
- Use fallback ASSOCIATED_TYPES when optional libs not installed
- HuggingFace: use nn.Module as fallback, add ImportError checks
- ONNX/TensorRT: use object as fallback to prevent ZenML errors
- All archivers now raise ImportError with clear message if lib missing
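The fallback-ASSOCIATED_TYPES change can be sketched as follows: when the optional library is missing, the archiver class still needs a type tuple to associate with, so a harmless placeholder is used and the real failure is deferred to save time as a clear ImportError. Class and variable names here are illustrative:

```python
from typing import Any, ClassVar, Tuple, Type

try:
    from onnx import ModelProto as _OnnxType  # optional dependency
    _ONNX_AVAILABLE = True
except ImportError:
    _OnnxType = object  # fallback so class definition/registration still succeeds
    _ONNX_AVAILABLE = False


class OnnxModelArchiverSketch:
    # With the fallback, ASSOCIATED_TYPES is always a valid tuple of types.
    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (_OnnxType,)

    def save(self, model: Any) -> None:
        if not _ONNX_AVAILABLE:
            raise ImportError(
                "onnx is not installed; install it with "
                "`pip install mindtrace-registry[onnx]`"
            )
```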
@Yasserelhaddar Yasserelhaddar self-assigned this Jan 20, 2026
@Yasserelhaddar Yasserelhaddar added enhancement New feature or request mindtrace-registry Issues raised from registry module in mindtrace package labels Jan 20, 2026
Add module-level pytest.mark.skipif decorators to skip tests when
optional dependencies (transformers, timm, onnx, tensorrt) are not
installed. This fixes CI failures where tests attempted to patch
modules that don't exist.

Also applied ruff formatting to archiver source files.
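The module-level skip described above can be sketched with pytest's documented mechanisms (module names are from this PR; either commented form skips every test in the file when the dependency is absent):

```python
# Guard for a test module; uses only documented pytest/stdlib mechanisms.
import importlib.util

HAS_TRANSFORMERS = importlib.util.find_spec("transformers") is not None

# In the test module itself, one would then write either:
#   pytestmark = pytest.mark.skipif(not HAS_TRANSFORMERS, reason="transformers not installed")
# or simply:
#   transformers = pytest.importorskip("transformers")
```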
Contributor

missing archiver registration at the end

Contributor

missing archiver registration

Contributor

missing archiver registration

Comment on lines +171 to +179
def _register_timm_archiver():
    """Register the timm archiver if timm is available."""
    if not _TIMM_AVAILABLE:
        return

    # We can't easily get a base class for all timm models,
    # so we register a custom type checker.
    # For now, users need to import this module to enable timm archiving;
    # the archiver will be selected based on the _is_timm_model check.
Contributor

The register function isn't actually doing anything. The check happens in save(), but the registry needs to know about the type beforehand so that save() is ever reached.

Comment on lines +27 to +28
tensorrt-cu11 = ["tensorrt-cu11>=8.6"]
tensorrt-cu12 = ["tensorrt-cu12>=10.0"]
Contributor

uv sync (without extras) ended up spending ~10 minutes building the tensorrt packages without actually installing them.

@vik-rant
Contributor

@Yasserelhaddar @JeremyWurbs , given that these archivers introduce very "models"-related dependencies, does it make sense to move them to the mindtrace-models package instead (along with the pre-existing ultralytics archivers)? e.g. from mindtrace.models.archivers import OnnxModelArchiver

"""

# HuggingFace models are nn.Module subclasses
ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (nn.Module,)
Contributor

The HF model archiver is associated with all nn.Modules. Shouldn't it be associated with PreTrainedModel instead?

def __init__(self, uri: str, **kwargs):
    super().__init__(uri=uri, **kwargs)

def _is_hf_model(self, model: Any) -> bool:
Contributor

This function isn't used anywhere.

"""

# timm models are nn.Module but we identify them via pretrained_cfg attribute
ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (nn.Module,)
Contributor

This one is also associated with nn.Module. Let's add integration tests with minimal but real models (skipped when deps aren't installed, like the unit tests). E.g., what happens when a generic nn.Module is saved in the registry — which archiver would be selected?

    self.logger.debug(f"Saved PEFT adapter to {adapter_dir}")

except ImportError:
    # PEFT not installed, skip adapter saving
Contributor

This should be an error; otherwise we are silently saving an incomplete model, which is confusing behaviour for users.

Comment on lines +89 to +95
    config["architecture"] = model.pretrained_cfg.get("architecture", "unknown")
    # Store full pretrained_cfg for reference
    config["pretrained_cfg"] = {
        k: v for k, v in model.pretrained_cfg.items() if isinstance(v, (str, int, float, bool, list, tuple))
    }
elif hasattr(model, "default_cfg") and model.default_cfg:
    config["architecture"] = model.default_cfg.get("architecture", "unknown")
Contributor

What happens when loading a model that was saved with architecture "unknown"?

    self.logger.debug(f"Loaded HuggingFace processor from {self.uri}")
    return processor
except Exception:
    pass
Contributor

Logging the exception would help with debugging; the RuntimeError raised at the end might not carry the relevant details.

@vik-rant vik-rant linked an issue Feb 17, 2026 that may be closed by this pull request
…spatch

Archiver modules must be imported after register_default_materializers()
so their _register_*_archiver() functions populate the dispatch table.
Without this, models fell through to ZenML generic materializers.
- Register PeftModel for correct dispatch (MRO skips PreTrainedModel)
- Save PeftModel via deep-copy + merge_and_unload for clean state dict
- Preserve adapter config/weights in adapter/ dir for provenance
- Skip adapter re-injection on load when weights were already merged
- Minor fixes to optional dependency guards in other ML archivers
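The deep-copy + merge_and_unload save path can be sketched like this (merge_and_unload is the real peft API for folding LoRA deltas into the base weights; the surrounding function is illustrative):

```python
import copy


def merged_state_dict(peft_model):
    """Merge LoRA deltas into the base weights without mutating the caller's model."""
    # Deep-copy first: merge_and_unload modifies the model it is called on,
    # and we want the in-memory model to keep its adapters after saving.
    merged = copy.deepcopy(peft_model).merge_and_unload()
    # The resulting clean state dict is what gets serialized to the registry.
    return merged.state_dict()
```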
…ion tests

- 10 new unit tests for PeftModel detection, merge-and-save, merged-load
  skip, dispatch routing, and PeftModel registration
- Integration dispatch tests verify all ML archivers resolve correctly
Covers 24 roundtrip tests across timm, HuggingFace, PEFT, processors,
ONNX, TensorRT, and Ultralytics with output verification.
@vik-rant vik-rant merged commit 92dbc0c into dev Feb 25, 2026
4 checks passed
@vik-rant vik-rant deleted the feature/registry-ml-archivers branch February 25, 2026 18:10


Development

Successfully merging this pull request may close these issues.

Model bank and Archiver Implementations

2 participants