Description
In addition to existing feature registries, FastAI.jl will be getting a model registry.
The model registry should allow
- domain and third-party packages to add additional models when loaded;
- using models from different "backend" deep learning libraries, i.e. Flux.jl, Lux.jl, PyTorch, JAX; and
- searching for models suitable for a specific task, fitting a computational budget,
or built in a specific deep learning library.
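To make the first requirement concrete, here is a minimal sketch of how a package could add entries to a shared registry when it is loaded. Everything in it is an assumption made for illustration: `ModelEntry`, `MODEL_REGISTRY`, `register!` and the field names are hypothetical and not part of FastAI.jl.

```julia
# Hypothetical sketch: none of these names are FastAI.jl API.
struct ModelEntry
    id::String          # e.g. "metalhead/resnet18/backbone"
    backend::Symbol     # e.g. :flux, :lux, :pytorch, :jax
    nparams::Int        # number of trainable parameters
    loadfn::Function    # called as loadfn(; pretrained) to build the model
end

# A global table that any package can append to.
const MODEL_REGISTRY = Dict{String, ModelEntry}()

function register!(entry::ModelEntry)
    haskey(MODEL_REGISTRY, entry.id) && error("model id already registered: $(entry.id)")
    MODEL_REGISTRY[entry.id] = entry
    return entry
end

# A third-party package would typically call this from its `__init__`,
# so the entries appear as soon as the package is loaded.
register!(ModelEntry("example/tiny/backbone", :flux, 1_200,
                     (; pretrained = false) -> (:tiny_model, pretrained)))

entry = MODEL_REGISTRY["example/tiny/backbone"]
model = entry.loadfn(pretrained = false)
```

Storing a loader function rather than the model itself keeps registration cheap: nothing is built or downloaded until an entry is actually loaded.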
Examples
Some code examples to show how the model registry can be used:
Loading models
Load a pretrained ResNet implemented in Metalhead.jl for transfer learning:
load(models()["metalhead/resnet18/head"], pretrained=true)
Load the ResNet as an untrained backbone for a different task:
load(models()["metalhead/resnet18/backbone"], pretrained=false)
Searching for models
Find models that take in preprocessed images:
filter(models(), input=ImageTensor{2})
Or find a suitable model for a supervised learning task directly:
task = ImageSegmentation(_)  # or SupervisedTask(_), TabularClassificationSingle(_), ...
filter(models(), input=task.blocks.x, output=task.blocks.y)
List models implemented in PyTorch:
filter(models(), backend=:pytorch)
Find models of a certain size:
filter(models(), input=ImageTensor{2}, nparams=<(1_000_000))
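The searches above mix plain values (`backend=:pytorch`) with predicates (`nparams=<(...)`). One possible way to support both is sketched below. It assumes entries are simple named tuples, and `filtermodels` and `matches` are hypothetical helpers, not the registry's actual filtering code.

```julia
# Hypothetical sketch: `filtermodels` and the entry fields are illustrative only.
entries = [
    (id = "a/small", backend = :flux,    nparams = 500_000),
    (id = "a/large", backend = :flux,    nparams = 25_000_000),
    (id = "b/small", backend = :pytorch, nparams = 800_000),
]

# A filter value is either a predicate (called on the field) or a plain value
# (compared with `==`).
matches(f::Function, value) = f(value)
matches(expected, value) = expected == value

function filtermodels(entries; filters...)
    return filter(entries) do entry
        # An entry is kept only if every keyword filter accepts its field.
        all(matches(f, getproperty(entry, name)) for (name, f) in pairs(filters))
    end
end

small = filtermodels(entries; nparams = <(1_000_000))
small_flux = filtermodels(entries; nparams = <(1_000_000), backend = :flux)
```

Because `<(1_000_000)` returns a callable that is a subtype of `Function`, ordinary dispatch is enough to tell predicates apart from plain values.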
Training workflow
Since models in the registry are associated with block information, we can use them
to automatically construct task-specific models using the taskmodel API (possibly
extended by an additional backend argument).
config = models()["torchvision/resnet18/backbone"]
backbone = load(config)
task = ImageSegmentation(_)
# build the task-specific model
model = taskmodel(task,           # includes info about required input and target block for task
                  config.backend, # dispatch on the DL library used, here :pytorch
                  backbone,
                  config.input,   # backbone input block: `ImageTensor{2}(3)`
                  config.output)  # backbone output block: `ConvFeatures{2}(512)`
learner = tasklearner(task, data; model)
fit!(learner, 10)
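Dispatching on the backend, as the comment on `config.backend` suggests, could be done by wrapping the symbol in `Val`. The following is only a sketch of that idea: `buildmodel` is a hypothetical stand-in for a backend-aware `taskmodel`, and the returned strings are placeholders for real models.

```julia
# Hypothetical sketch: `buildmodel` stands in for a backend-aware `taskmodel`.
# Wrap the backend symbol in `Val` so that each backend gets its own method.
buildmodel(backend::Symbol, backbone) = buildmodel(Val(backend), backbone)

buildmodel(::Val{:flux}, backbone) = "Flux head on $backbone"
buildmodel(::Val{:pytorch}, backbone) = "PyTorch head on $backbone"

# Fallback for backends that no loaded package has added a method for.
buildmodel(::Val{B}, backbone) where {B} =
    throw(ArgumentError("no model builder registered for backend :$B"))

result = buildmodel(:pytorch, "resnet18")
```

With this layout, a package that wraps another deep learning library only needs to add one method for its own `Val{...}` type, without touching the registry code.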