3 changes: 3 additions & 0 deletions _pkgdown.yml
@@ -92,6 +92,9 @@ navbar:
- text: "Tuning"
- text: "Tuning Fit and Compile Arguments"
href: articles/tuning_fit_compile_args.html
- text: "Applications"
- text: "Transfer Learning"
href: articles/applications.html
github:
icon: fa-github
href: https://github.com/davidrsch/kerasnip
172 changes: 172 additions & 0 deletions vignettes/applications.Rmd
@@ -0,0 +1,172 @@
---
title: "Transfer Learning with Keras Applications"
output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{Transfer Learning with Keras Applications}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  eval = reticulate::py_module_available("keras")
)
# Suppress verbose Keras output for the vignette
options(keras.fit_verbose = 0)
set.seed(123)
```

## Introduction

Transfer learning is a powerful technique in which a model developed for one task is reused as the starting point for a model on a second task. It is especially popular in computer vision, where models such as `ResNet50`, pre-trained on the massive ImageNet dataset, serve as ready-made feature extractors.

The `kerasnip` package makes it easy to incorporate these pre-trained Keras Applications directly into a `tidymodels` workflow. This vignette will demonstrate how to:

1. Define a `kerasnip` model that uses a pre-trained `ResNet50` as a frozen base layer.
2. Add a new, trainable classification "head" on top of the frozen base.
3. Tune the hyperparameters of the new classification head using standard `tidymodels` tooling (sketched at the end of this vignette).

## Setup

First, we load the necessary packages.

```{r load-packages}
library(kerasnip)
library(tidymodels)
library(keras3)
```

## Data Preparation

We'll use the CIFAR-10 dataset, which consists of 60,000 32x32 color images in 10 classes. `keras3` provides a convenient function to download it.

The `ResNet50` model was pre-trained on ImageNet, which has a different set of classes. Our goal is to fine-tune it to classify the 10 classes in CIFAR-10.

```{r data-prep}
# Load CIFAR-10 dataset
cifar10 <- dataset_cifar10()

# Separate training and test data
x_train <- cifar10$train$x
y_train <- cifar10$train$y
x_test <- cifar10$test$x
y_test <- cifar10$test$y

# Rescale pixel values from [0, 255] to [0, 1]
x_train <- x_train / 255
x_test <- x_test / 255

# Convert outcomes to factors for tidymodels
y_train_factor <- factor(y_train[, 1])
y_test_factor <- factor(y_test[, 1])

# For tidymodels, it's best to work with data frames.
# We'll use a list-column to hold the image arrays.
train_df <- tibble::tibble(
  x = lapply(seq_len(nrow(x_train)), function(i) x_train[i, , , , drop = TRUE]),
  y = y_train_factor
)

test_df <- tibble::tibble(
  x = lapply(seq_len(nrow(x_test)), function(i) x_test[i, , , , drop = TRUE]),
  y = y_test_factor
)

# Use a smaller subset for faster vignette execution
train_df_small <- train_df[1:500, ]
test_df_small <- test_df[1:100, ]
```

## Functional API with a Pre-trained Base

The standard approach for transfer learning is to use the Keras Functional API. We will define a model where:

1. The base is a pre-trained `ResNet50`, with its final classification layer removed (`include_top = FALSE`).
2. The weights of the base are frozen (`trainable = FALSE`) so that only our new layers are trained.
3. A new classification "head" is added, consisting of a flatten layer and a dense output layer.

### Define Layer Blocks

```{r define-functional-blocks}
# Input block: shape is determined automatically from the data
input_block <- function(input_shape) {
  layer_input(shape = input_shape)
}

# ResNet50 base block
resnet_base_block <- function(tensor) {
  # The base model is not trainable; we use it for feature extraction.
  resnet_base <- application_resnet50(
    weights = "imagenet",
    include_top = FALSE
  )
  resnet_base$trainable <- FALSE
  resnet_base(tensor)
}

# New classification head
flatten_block <- function(tensor) {
  tensor |> layer_flatten()
}

output_block_functional <- function(tensor, num_classes) {
  tensor |> layer_dense(units = num_classes, activation = "softmax")
}
```

### Create the `kerasnip` Specification

We connect these blocks using `create_keras_functional_spec()`.

```{r create-functional-spec}
create_keras_functional_spec(
  model_name = "resnet_transfer",
  layer_blocks = list(
    input = input_block,
    resnet_base = inp_spec(resnet_base_block, "input"),
    flatten = inp_spec(flatten_block, "resnet_base"),
    output = inp_spec(output_block_functional, "flatten")
  ),
  mode = "classification"
)
```
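
If you want to check which arguments the new `resnet_transfer()` specification accepts, you can inspect the generated function directly. This is a quick, optional check (`formals()` is base R); the exact argument names depend on the blocks defined above.

```{r inspect-spec, eval=FALSE}
# Arguments generated for the spec: one per block argument,
# plus the fit_* / compile_* arguments used below.
names(formals(resnet_transfer))
```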

### Fit and Evaluate the Model

Now we can use our new `resnet_transfer()` specification within a `tidymodels` workflow.

```{r fit-functional-model, cache=TRUE}
spec_functional <- resnet_transfer(
  fit_epochs = 5,
  fit_validation_split = 0.2
) |>
  set_engine("keras")

rec_functional <- recipe(y ~ x, data = train_df_small)

wf_functional <- workflow() |>
  add_recipe(rec_functional) |>
  add_model(spec_functional)

fit_functional <- fit(wf_functional, data = train_df_small)

# Evaluate on the test set
predictions <- predict(fit_functional, new_data = test_df_small)
bind_cols(predictions, test_df_small) |>
  accuracy(truth = y, estimate = .pred_class)
```

Even with a small training subset and only a few epochs, the pre-trained ResNet50 features give the new classification head a reasonable starting point for accuracy.
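
## Tuning the Classification Head

The third step promised in the introduction is tuning the head. The sketch below adds a hidden dense layer whose size is marked with `tune()`. It assumes, as in the tuning vignette, that `kerasnip` exposes block arguments on the generated spec as `<block>_<argument>` (so the dense block's `units` surfaces as `dense_units`); the block definition, grid values, and resampling scheme are illustrative, and the chunk is not evaluated here.

```{r tune-head, eval=FALSE}
# A tunable hidden layer for the classification head.
dense_block <- function(tensor, units = 64) {
  tensor |> layer_dense(units = units, activation = "relu")
}

# Rebuild the spec with the dense block between flatten and output.
create_keras_functional_spec(
  model_name = "resnet_transfer_tune",
  layer_blocks = list(
    input = input_block,
    resnet_base = inp_spec(resnet_base_block, "input"),
    flatten = inp_spec(flatten_block, "resnet_base"),
    dense = inp_spec(dense_block, "flatten"),
    output = inp_spec(output_block_functional, "dense")
  ),
  mode = "classification"
)

# Mark `dense_units` for tuning (assumed `<block>_<argument>` naming).
spec_tune <- resnet_transfer_tune(
  dense_units = tune(),
  fit_epochs = 5
) |>
  set_engine("keras")

wf_tune <- workflow() |>
  add_recipe(rec_functional) |>
  add_model(spec_tune)

# A small grid and few folds keep the search cheap.
folds <- vfold_cv(train_df_small, v = 3)
grid <- tibble(dense_units = c(64, 128, 256))

tune_res <- tune_grid(wf_tune, resamples = folds, grid = grid)
show_best(tune_res, metric = "accuracy")

remove_keras_spec("resnet_transfer_tune")
```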

## Conclusion

This vignette demonstrated how `kerasnip` bridges pre-trained Keras Applications and the structured, reproducible workflows of `tidymodels`.

The **Functional API** is the most direct way to perform transfer learning by attaching a new head to a frozen base model.

This approach lets you leverage deep learning models trained on massive datasets, often substantially improving performance on smaller, domain-specific tasks.

```{r cleanup, include=FALSE}
remove_keras_spec("resnet_transfer")
```