Model registry #246

Open
@lorenzoh

In addition to its existing feature registries, FastAI.jl will get a model registry.

The model registry should allow:

  • domain and third-party packages to add models when they are loaded;
  • using models from different "backend" deep learning libraries, e.g.
    Flux.jl, Lux.jl, PyTorch, and JAX; and
  • searching for models that suit a specific task, fit a computational budget,
    or are built with a specific deep learning library.
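A minimal sketch of one way such a registry could be organized, in plain Julia without FastAI.jl: each entry stores the metadata the list above calls for (backend, input and output blocks, parameter count) together with a loader closure, and a package adds its entries when it is loaded. `ModelConfig`, `register!`, `models` and the block stand-ins `ImageTensor` and `ConvFeatures` are hypothetical placeholders, not FastAI.jl's actual types or API.

```julia
# Hypothetical stand-ins for FastAI.jl block types (not the real definitions).
struct ImageTensor{N}
    nchannels::Int
end

struct ConvFeatures{N}
    nfeatures::Int
end

# One registry entry: searchable metadata plus a closure that builds the model.
# All field names are assumptions made for this sketch.
struct ModelConfig
    id::String
    backend::Symbol    # e.g. :flux, :lux, :pytorch, :jax
    input::Any         # block the model consumes
    output::Any        # block the model produces
    nparams::Int
    loadfn::Function   # called as loadfn(; pretrained::Bool)
end

# The registry itself is a plain Dict keyed by model ID.
const MODELS = Dict{String,ModelConfig}()
models() = MODELS

# What a domain or third-party package could call from its `__init__` function
# so that its models become available as soon as the package is loaded.
function register!(config::ModelConfig)
    haskey(MODELS, config.id) && error("model \"$(config.id)\" is already registered")
    MODELS[config.id] = config
    return config
end

# Example registration with placeholder metadata.
register!(ModelConfig(
    "metalhead/resnet18/backbone", :flux,
    ImageTensor{2}(3), ConvFeatures{2}(512),
    0,  # placeholder parameter count
    (; pretrained = false) -> "resnet18 backbone (pretrained = $pretrained)"))
```

Keeping the loader as a closure means a backend's heavy dependencies are only touched when a model is actually loaded, not when the registry is searched.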

Examples

The following code examples show how the model registry could be used.

Loading models

Load a pretrained ResNet implemented in Metalhead.jl for transfer learning:

load(models()["metalhead/resnet18/head"], pretrained=true)

Load the same ResNet as an untrained backbone for a different task:

load(models()["metalhead/resnet18/backbone"], pretrained=false)
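Assuming each variant (`head`, `backbone`) is registered as its own entry, `load` only needs to forward the `pretrained` flag to that entry's loader. In this hypothetical sketch, `buildresnet18` stands in for a real Metalhead.jl constructor and returns a `NamedTuple` describing the model rather than an actual network.

```julia
# Minimal registry entry for this sketch: an ID plus a loader closure.
struct ModelConfig
    id::String
    loadfn::Function   # called as loadfn(; pretrained::Bool)
end

# Forward the `pretrained` flag to the entry's loader.
load(config::ModelConfig; pretrained::Bool = false) = config.loadfn(; pretrained)

# Hypothetical stand-in for a real constructor: it only describes what it would build.
buildresnet18(variant::Symbol; pretrained::Bool) =
    (arch = :resnet18, variant = variant, pretrained = pretrained)

# One entry per variant, all sharing the same underlying constructor.
const MODELS = Dict(
    id => ModelConfig(id, (; pretrained = false) -> buildresnet18(variant; pretrained))
    for (id, variant) in [("metalhead/resnet18/head", :head),
                          ("metalhead/resnet18/backbone", :backbone)]
)
models() = MODELS
```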

Searching for models

Find models that take in preprocessed images:

filter(models(), input=ImageTensor{2})
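One way `filter` could work is to match each keyword against the field of the same name in every entry. The rule assumed below: a type such as `ImageTensor{2}` matches by `isa`, a function is applied as a predicate, and any other value must compare equal. Wrapping the entries in a `ModelRegistry` type lets the sketch add a `Base.filter` method without type piracy; all names and entries are hypothetical.

```julia
# Hypothetical stand-in for FastAI.jl's image block (not the real definition).
struct ImageTensor{N}
    nchannels::Int
end

# Searchable metadata for one entry; field names are assumptions.
struct ModelConfig
    backend::Symbol
    input::Any
    nparams::Int
end

# Wrapping the Dict in our own type lets us add a `Base.filter` method without type piracy.
struct ModelRegistry
    entries::Dict{String,ModelConfig}
end

# Assumed rule for one criterion: types match by `isa`, functions act as
# predicates, and any other value must compare equal.
matches(value, query::Type) = value isa query
matches(value, query::Function) = query(value)
matches(value, query) = value == query

# Keep the entries whose fields satisfy every keyword criterion.
function Base.filter(reg::ModelRegistry; criteria...)
    keep(cfg) = all(matches(getfield(cfg, name), query) for (name, query) in criteria)
    return ModelRegistry(filter(entry -> keep(last(entry)), reg.entries))
end

# Sorted model IDs, for display.
ids(reg::ModelRegistry) = sort!(collect(keys(reg.entries)))

# Three made-up entries with placeholder metadata.
const MODELS = ModelRegistry(Dict(
    "example/conv2d" => ModelConfig(:flux, ImageTensor{2}(3), 500_000),
    "example/conv3d" => ModelConfig(:pytorch, ImageTensor{3}(1), 2_000_000),
    "example/mlp" => ModelConfig(:lux, :tabular, 10_000)))
models() = MODELS
```

With this rule, `filter(models(), input=ImageTensor)` would also match 3D image models, while `ImageTensor{2}` narrows the search to 2D ones.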

Or find a suitable model for a supervised learning task directly:

task = ImageSegmentation(_)  # or e.g. SupervisedTask(_), TabularClassificationSingle(_)
filter(models(), input=task.blocks.x, output=task.blocks.y)
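Matching against a task could reduce to comparing the task's blocks with each entry's input and output blocks. This hypothetical sketch uses plain equality and made-up types (`ImageTensor`, `MaskTensor`, `HypotheticalTask`); a real registry would likely need a looser compatibility check, for example for models that accept any number of input channels.

```julia
# Hypothetical stand-ins for block types and a learning task (not FastAI.jl's).
struct ImageTensor{N}
    nchannels::Int
end

struct MaskTensor{N}
    nclasses::Int
end

struct HypotheticalTask
    blocks::NamedTuple   # (x = input block, y = target block)
end

# Made-up registry entries.
registry = NamedTuple[
    (id = "example/unet2d", input = ImageTensor{2}(3), output = MaskTensor{2}(5)),
    (id = "example/classifier2d", input = ImageTensor{2}(3), output = :classlabel),
]

# A model fits a task if it consumes the task's input block and produces its target block.
fits(entry, task) = entry.input == task.blocks.x && entry.output == task.blocks.y

segtask = HypotheticalTask((x = ImageTensor{2}(3), y = MaskTensor{2}(5)))
suitable = [entry.id for entry in registry if fits(entry, segtask)]
```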

List models implemented in PyTorch:

filter(models(), backend=:pytorch)

Find models for 2D images with fewer than one million parameters:

filter(models(), input=ImageTensor{2}, nparams=<(1_000_000))
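The `nparams` criterion relies on Julia's curried comparison operators: `<(n)` returns a `Base.Fix2` object, a one-argument function that checks `x < n`, which a registry can apply as a predicate to each entry's parameter count. A quick check with made-up parameter counts:

```julia
# `<(n)` is Julia's curried comparison: it returns `Base.Fix2(<, n)`,
# a one-argument function that checks `x < n`.
fewer_than_1m = <(1_000_000)

# Placeholder parameter counts for three made-up models.
nparams = Dict("small" => 250_000, "medium" => 999_999, "large" => 11_000_000)

# The check a registry filter could apply to each entry's parameter count.
within_budget = sort!([name for (name, n) in nparams if fewer_than_1m(n)])
```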

Training workflow

Since registry entries carry block information, they can be used to construct
task-specific models automatically through the taskmodel API (possibly
extended with an additional backend argument).

config = models()["torchvision/resnet18/backbone"]
backbone = load(config)

task = ImageSegmentation(_)
# build the task-specific model
model = taskmodel(task,            # carries the input and target blocks the task requires
                  config.backend,  # dispatch on the DL library used, here :pytorch
                  backbone,
                  config.input,    # backbone input block: `ImageTensor{2}(3)`
                  config.output)   # backbone output block: `ConvFeatures{2}(512)`

learner = tasklearner(task, data; model)
fit!(learner, 10)
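The extra backend argument could be a `Symbol` that `taskmodel` turns into a `Val`, so that each backend integration adds its own method without the core package depending on every deep learning library. A hypothetical sketch: `taskmodel` here is a stand-in defined from scratch, not FastAI.jl's function, and the Flux.jl method only records what it would build.

```julia
# Hypothetical stand-in for a learning task.
struct SegmentationTaskStub end

# Route on the backend symbol via `Val`, so that each backend package
# (Flux.jl, Lux.jl, a PyTorch bridge, ...) can add its own method.
taskmodel(task, backend::Symbol, backbone, inblock, outblock) =
    taskmodel(task, Val(backend), backbone, inblock, outblock)

# Fallback for backends that have no implementation yet.
taskmodel(task, ::Val{B}, backbone, inblock, outblock) where {B} =
    error("no `taskmodel` method for backend :$B")

# What a Flux.jl integration might define; here it only describes the model it would build.
taskmodel(task::SegmentationTaskStub, ::Val{:flux}, backbone, inblock, outblock) =
    (backend = :flux, backbone = backbone, head = "segmentation head on $(outblock)")
```

Dispatching on `Val(backend)` keeps backend-specific code out of the core package; the trade-off is a dynamic dispatch when the backend symbol is only known at run time.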

Metadata


Assignees: No one assigned
Labels: api-proposal (Implementation or suggestion for new APIs and improvements to existing APIs), enhancement (New feature or request), plans (Long-term plans)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
