High-performance deep learning models for removing AV1 compression artifacts from images and video frames.
This project provides production-ready models for AV1 artifact removal using two complementary approaches:
- Conditional U-Net (av1_conditional_unet_restorer_v2.py): A single, universal model that handles the entire CRF quality range (23-63) by accepting the CRF value as a condition. This is the recommended high-quality, flexible solution.
- Nano Models (av1_nano_*.py): A suite of ultra-lightweight, specialized models, each trained for a narrow CRF "bucket" (e.g., CRF 34-43). These are optimized for maximum speed and real-time video processing.
Both architectures are built on SOTA principles, including GroupNorm, GELU, efficient blocks, and modern upsampling techniques to ensure high-quality, artifact-free restoration.
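For illustration, a residual block in this style might look like the following minimal sketch (the class name is hypothetical; the project's actual shared blocks live in blocks.py):

```python
import torch
import torch.nn as nn

class EfficientResBlock(nn.Module):
    """Minimal sketch of a GroupNorm + GELU residual block.
    Hypothetical example, not the exact implementation in blocks.py."""
    def __init__(self, channels: int, groups: int = 8):
        super().__init__()
        # channels is assumed divisible by `groups` for GroupNorm.
        self.body = nn.Sequential(
            nn.GroupNorm(groups, channels),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.GroupNorm(groups, channels),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection preserves detail and stabilizes training.
        return x + self.body(x)
```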
Your AV1 dataset should be organized as follows for the training and testing scripts:
av1_data/                       # DIV2K + Flickr2K master dataset
├── train/                      # 90% of master dataset
│   ├── lq/                     # AV1-compressed images (.avif)
│   │   ├── crf_23/preset_4/    # crf_XX/preset_y/image_crfXX_pY.avif
│   │   ├── crf_24/preset_4/
│   │   └── ... (up to crf_63)
│   └── hq/                     # High-quality reference images (.png)
│
├── val/                        # 10% of master dataset
│   ├── lq/                     # AV1-compressed images
│   └── hq/                     # High-quality reference images
│
└── test/                       # Separate test set (e.g., DIV2K_valid)
    ├── lq/                     # AV1-compressed images
    └── hq/                     # High-quality reference images
Purpose of each split:
- train/: Used for model training.
- val/: Used for validation during training (for early stopping and selecting the best model).
- test/: Used for final inference evaluation with scripts/restore_av1.py.
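To make the naming convention concrete, here is a sketch of how an LQ path is derived from an HQ image stem; the helper function is illustrative and not part of the codebase:

```python
from pathlib import Path

def lq_path(root: str, stem: str, crf: int, preset: int) -> Path:
    """Build the expected LQ path for a given HQ image stem, e.g.
    av1_data/train/lq/crf_30/preset_4/0001_crf30_p4.avif
    (illustrative helper, not part of the repo)."""
    return Path(root) / f"crf_{crf}" / f"preset_{preset}" / f"{stem}_crf{crf}_p{preset}.avif"

print(lq_path("av1_data/train/lq", "0001", 30, 4))
```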
Purpose: A single model handles the entire CRF range (23-63) with compression-aware adaptive restoration.
- Dual Conditioning Modes:
- CRF-Only (128-dim): Fast, efficient, and the primary mode.
- CRF+Preset (192-dim): Optional mode for maximum quality.
- 5-Level U-Net: Deep multi-scale feature extraction for robust artifact removal.
- FiLM Conditioning: Adaptive feature modulation based on compression parameters.
- Bottleneck Attention: Global context modeling via efficient channel-wise self-attention.
- SOTA Upsampling: Uses Bilinear Upsample + Conv to eliminate checkerboard artifacts.
- Residual Learning: Predicts the artifact correction, preserving original image details.
- Memory-Efficient Tiling: Built-in inference logic handles images of any size.
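FiLM conditioning applies a per-channel affine transform whose scale and shift are predicted from the compression-condition embedding. A minimal sketch (module and tensor names are illustrative):

```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Feature-wise Linear Modulation: scale and shift features per channel
    from the conditioning embedding (illustrative sketch)."""
    def __init__(self, cond_dim: int, channels: int):
        super().__init__()
        self.to_gamma_beta = nn.Linear(cond_dim, 2 * channels)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # cond: [B, cond_dim] -> gamma, beta: [B, C, 1, 1]
        gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=1)
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)
        return x * (1 + gamma) + beta

film = FiLM(cond_dim=128, channels=64)
x, cond = torch.randn(2, 64, 32, 32), torch.randn(2, 128)
print(film(x, cond).shape)  # torch.Size([2, 64, 32, 32])
```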
Input [B,3,H,W] + CRF [B,1] (+ Preset [B,1])
        │
        ▼
Conditioning Embedder (128/192-dim)
        │
        ▼
┌──────────────────────────────────────────┐
│          5-Level U-Net Backbone          │
├──────────────────────────────────────────┤
│ Head (ch[0]) → EfficientResBlocks        │
│   │                                      │
│   ▼ Skip 0                               │
│ Enc1 (ch[1]) → FiLM → Blocks ↓2×         │
│   │                                      │
│   ▼ Skip 1                               │
│ Enc2 (ch[2]) → FiLM → Blocks ↓2×         │
│   │                                      │
│   ▼ Skip 2                               │
│ Enc3 (ch[3]) → FiLM → Blocks ↓2×         │
│   │                                      │
│   ▼ Skip 3                               │
│ Bottleneck (ch[4]) ↓2×                   │
│   ├─ Pre-Attn (Blocks)                   │
│   ├─ SimpleSelfAttention (Channel-wise)  │
│   ├─ FiLM Conditioning                   │
│   └─ Post-Attn (Blocks)                  │
│   │                                      │
│   ▼ ↑2× (Bilinear + Conv)                │
│ Dec3 (ch[3]) ⊕ Skip 3 → Blocks           │
│   │                                      │
│   ▼ ↑2× (Bilinear + Conv)                │
│ Dec2 (ch[2]) ⊕ Skip 2 → Blocks           │
│   │                                      │
│   ▼ ↑2× (Bilinear + Conv)                │
│ Dec1 (ch[1]) ⊕ Skip 1 → Blocks           │
│   │                                      │
│   ▼ ↑2× (Bilinear + Conv)                │
│ Tail (ch[0]) ⊕ Skip 0 → Residual         │
└──────────────────────────────────────────┘
        │
        ▼
Output = Input + Residual (clamped)
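The Memory-Efficient Tiling feature listed above can be pictured with this simplified sketch. The model's built-in tiling handles overlap blending and conditioning more carefully; the function and the model(patch, crf) call signature here are illustrative only:

```python
import torch

@torch.no_grad()
def restore_tiled(model, img, crf, tile=512, overlap=32):
    """Simplified tiled-inference sketch: run the model on overlapping
    tiles and paste results back. In this sketch overlapping regions are
    simply overwritten; the built-in logic blends them properly."""
    _, _, h, w = img.shape
    out = torch.zeros_like(img)
    step = tile - overlap
    for y in range(0, h, step):
        for x in range(0, w, step):
            y1, x1 = min(y + tile, h), min(x + tile, w)
            patch = img[:, :, y:y1, x:x1]
            out[:, :, y:y1, x:x1] = model(patch, crf)
    return out
```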
These configurations are precisely engineered to match target parameter counts.
| Size | Target | Actual Params | Use Case |
|---|---|---|---|
| nano | ~2M | 2.30M | Minimal viable conditional model |
| tiny | ~5M | 4.96M | Lightweight, fast iteration |
| small | ~10M | 9.21M | Standard balanced |
| base | ~12M | 12.79M | RECOMMENDED DEFAULT |
| large | ~20M | 19.69M | Enhanced restoration |
| huge | ~32M | 32.34M | High quality (Slow) |
| pro | ~50M | 50.31M | Maximum quality / Research |
*Note: CRF+Preset mode adds < 0.3M parameters.*
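A quick way to sanity-check these counts against any instantiated model (a generic sketch):

```python
import torch.nn as nn

def count_params(model: nn.Module) -> float:
    """Trainable parameter count in millions."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6

# Example: a 'base' conditional U-Net should report ~12.79M.
# print(f"{count_params(model):.2f}M")
```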
Purpose: Ultra-lightweight models trained on narrow CRF ranges for real-time video processing.
Train separate specialized models for each compression tier:
| CRF Range | Compression | Strategy |
|---|---|---|
| 23-33 | Light | Texture preservation |
| 34-43 | Moderate | Balanced restoration |
| 44-53 | Heavy | Aggressive correction |
| 54-63 | Extreme | Full reconstruction |
At inference, a router selects the appropriate model based on the input CRF, resulting in faster and more accurate restoration for that specific range.
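A minimal router sketch, using the bucket boundaries from the table above (the dispatch function and the models dict are illustrative, not part of the repo):

```python
# CRF buckets from the table above; `models` maps a bucket to a loaded model.
BUCKETS = [(23, 33), (34, 43), (44, 53), (54, 63)]

def select_model(crf: int, models: dict):
    """Pick the specialized nano model whose CRF bucket contains `crf`."""
    for lo, hi in BUCKETS:
        if lo <= crf <= hi:
            return models[(lo, hi)]
    raise ValueError(f"CRF {crf} outside supported range 23-63")
```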
A. Nano U-Net (av1_nano_unet_restorer.py)
- Best balance of quality and speed
- 3-Level shallow U-Net (vs 5 in full U-Net)
- Depthwise separable convolutions + ECA attention
- Sizes: nano (0.2M), tiny (0.5M), small (1.2M), base (2.5M), large (6.0M), huge (11.2M)
Input Image
     │
     ▼
Head + ResBlocks
     │
     ▼ [skip0]
Encoder-1 (downsample 2×)
     │
     ▼ [skip1]
Encoder-2 (downsample 2×)
     │
     ▼ [skip2]
Encoder-3 (downsample 2×)
     │
     ▼
Decoder-3 (upsample 2×) ⊕ [skip2]
     │
     ▼
Decoder-2 (upsample 2×) ⊕ [skip1]
     │
     ▼
Decoder-1 (upsample 2×) ⊕ [skip0]
     │
     ▼
Tail → Residual
     │
     ▼
Restored = Input + Residual
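The depthwise separable convolution + ECA attention combination listed in the Nano U-Net features can be sketched as follows (illustrative, not the exact nano implementation):

```python
import torch
import torch.nn as nn

class DepthwiseSeparableConv(nn.Module):
    """3x3 depthwise conv followed by a 1x1 pointwise conv (sketch)."""
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.depthwise = nn.Conv2d(in_ch, in_ch, 3, padding=1, groups=in_ch)
        self.pointwise = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

class ECA(nn.Module):
    """Efficient Channel Attention: gate channels via a 1D conv over
    globally pooled channel descriptors (sketch)."""
    def __init__(self, k: int = 3):
        super().__init__()
        self.conv = nn.Conv1d(1, 1, k, padding=k // 2, bias=False)

    def forward(self, x):
        # [B, C, H, W] -> pooled [B, C] -> 1D conv across channels -> gate
        y = x.mean(dim=(2, 3))                    # [B, C]
        y = self.conv(y.unsqueeze(1)).squeeze(1)  # [B, C]
        return x * torch.sigmoid(y)[:, :, None, None]
```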
B. Nano ResNet (av1_nano_resnet_restorer.py)
- Maximum speed (no downsampling)
- Processes at native resolution
- Multi-scale feature extraction head
- Sizes: nano (0.7M), tiny (1.2M), small (2.1M), base (3.3M), huge (6.5M)
Input Image
     │
     ▼
Head Conv
     │
     ▼
Multi-Scale Feature Extractor
(parallel 3×3, 5×5, 7×7 paths)
     │
     ▼
N × Residual Blocks
(with long skip connection)
     │
     ▼
Tail → Residual
     │
     ▼
Restored = Input + Residual
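The multi-scale feature extractor above (parallel 3×3 / 5×5 / 7×7 paths) might look like this in sketch form (illustrative names and fusion choice):

```python
import torch
import torch.nn as nn

class MultiScaleHead(nn.Module):
    """Parallel 3x3 / 5x5 / 7x7 conv paths, fused by a 1x1 conv (sketch)."""
    def __init__(self, channels: int):
        super().__init__()
        self.paths = nn.ModuleList(
            nn.Conv2d(channels, channels, k, padding=k // 2) for k in (3, 5, 7)
        )
        self.fuse = nn.Conv2d(3 * channels, channels, 1)

    def forward(self, x):
        # Concatenate the three receptive-field views, then fuse.
        return self.fuse(torch.cat([p(x) for p in self.paths], dim=1))
```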
C. Nano FBCNN (av1_nano_fbcnn_restorer.py)
- FBCNN architecture adapted for AV1
- Single-scale processing, quality-focused
- ~1.8M params
D. Nano Mamba (av1_nano_mamba_restorer.py)
- Experimental Hybrid CNN + State Space Model (SSM)
- Global receptive field
- ~2.0M params
git clone https://github.com/sohamazing/av1-restorer.git
cd av1-restorer
The setup script automatically detects and configures for CUDA (NVIDIA), MPS (Apple Silicon), or CPU.
chmod +x setup_env.sh
./setup_env.sh
# Activate the environment
conda activate aura
This is the complete workflow to generate the train, val, and test splits.
# Download DIV2K training data (for train/val)
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
unzip DIV2K_train_HR.zip
# Download DIV2K validation data (for test)
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
unzip DIV2K_valid_HR.zip
python scripts/degrade_av1.py \
--input_dir ./DIV2K_train_HR \
--output_dir ./av1_data/master/lq \
--crf_range 23 63 \
--preset_range 4 4 \
--num_workers 8
# Create symlink to HQ images
ln -s $(pwd)/DIV2K_train_HR ./av1_data/master/hq
python scripts/degrade_av1.py \
--input_dir ./DIV2K_valid_HR \
--output_dir ./av1_data/test/lq \
--crf_range 23 63 \
--preset_range 4 4 \
--num_workers 8
# Create symlink to HQ images
ln -s $(pwd)/DIV2K_valid_HR ./av1_data/test/hq
python scripts/split_av1_dataset.py \
--input_lq ./av1_data/master/lq \
--input_hq ./av1_data/master/hq \
--output_train_lq ./av1_data/train/lq \
--output_train_hq ./av1_data/train/hq \
--output_val_lq ./av1_data/val/lq \
--output_val_hq ./av1_data/val/hq \
--split_ratio 0.9
This workflow trains a single model on the full CRF range. The example config below uses the tiny size for illustration; the same template scales up (e.g., large @ 19.7M params for a high-quality model).
Create configs/conditional_unet/conditional_config_tiny_crf23-63.yaml:
# ============================================================
# Conditional U-Net Configuration (CRF-Only Mode)
# ============================================================
project:
  name: "AV1-Restorer"
  experiment_name: "conditional_unet_tiny_crf23-63"
  log_to_wandb: true

system:
  device: "auto"            # cuda/mps/cpu
  seed: 24
  mixed_precision: true     # true for CUDA, false for MPS
  num_workers: 8

model:
  type: "unet"
  size: "tiny"              # Select: nano, tiny, small, base, large, huge, pro

dataset:
  crf_range: [23, 63]       # Full CRF spectrum
  preset_range: [4, 4]      # Single value = CRF-Only mode
  norm_range: [-1, 1]       # Image normalization

data:
  train_lq_root: "./av1_data/train/lq"
  train_hq_root: "./av1_data/train/hq"
  val_lq_root: "./av1_data/val/lq"
  val_hq_root: "./av1_data/val/hq"
  lq_ext: ".avif"
  hq_ext: ".png"

# ============================================================
# Curriculum Learning (Progressive Training)
# ============================================================
curriculum:
  - # Stage 1: Small patches (learn local patterns)
    patch_size: 128
    batch_size: 64          # Adjust based on VRAM
    epochs: 100
  - # Stage 2: Medium patches (learn broader context)
    patch_size: 256
    batch_size: 16
    epochs: 50
  - # Stage 3: Large patches (refine full context)
    patch_size: 512
    batch_size: 4
    epochs: 10

# ============================================================
# Loss Configuration (SOTA Balanced)
# ============================================================
loss:
  charbonnier: {enabled: true, weight: 1.0}   # Robust L1 variant
  perceptual: {enabled: true, weight: 0.05}   # 'vgg' (faster) or 'lpips'
  ms_ssim: {enabled: true, weight: 0.15}      # pip install pytorch-msssim
  frequency: {enabled: true, weight: 0.01}    # 1.0 for magnitude, 0.0 for phase

# ============================================================
# Optimizer & Scheduler
# ============================================================
optimizer:
  type: "adamw"
  lr: 0.0001
  use_ema: true
  ema_decay: 0.9999

scheduler:
  type: "cosine"
  warmup_steps: 1000
  min_lr: 1.0e-7

# ============================================================
# Training Settings
# ============================================================
training:
  grad_clip_norm: 1.0
  validate_every_n_epochs: 1
  log_every_n_steps: 50

checkpoint:
  dir: "./checkpoints/conditional_unet_tiny_crf23-63"
  save_every_n_epochs: 1
- Train from scratch
python av1_restorer/train_av1_conditional_restorer.py \
--config configs/conditional_unet/conditional_config_tiny_crf23-63.yaml
- Resume from latest checkpoint
python av1_restorer/train_av1_conditional_restorer.py \
--config configs/conditional_unet/conditional_config_tiny_crf23-63.yaml \
--resume latest
- Resume with W&B tracking
python av1_restorer/train_av1_conditional_restorer.py \
--config configs/conditional_unet/conditional_config_tiny_crf23-63.yaml \
--resume best \
--wandb_id <your-wandb-run-id>
- Console: Real-time loss, LR, and validation metrics.
- Weights & Biases: Comprehensive experiment tracking.
- Checkpoints: Auto-saved to the directory set in checkpoint.dir.
Key Metrics to Watch:
- loss_charbonnier: Main reconstruction loss (should decrease).
- metric_ms_ssim: Structural similarity (0-1, higher is better).
- val/improvement_l1: L1 improvement over baseline (should be positive and increasing).
- val/restored_l1 vs. val/baseline_l1: Restored L1 should become lower than baseline.
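For reference, val/improvement_l1 can be read as the baseline (LQ vs. HQ) L1 error minus the restored (output vs. HQ) L1 error. A sketch of the computation; the trainer's exact implementation may differ:

```python
import torch

def l1_improvement(lq: torch.Tensor, restored: torch.Tensor, hq: torch.Tensor) -> float:
    """baseline_l1 - restored_l1: positive means the model beats
    the compressed input (sketch; trainer details may differ)."""
    baseline_l1 = (lq - hq).abs().mean()
    restored_l1 = (restored - hq).abs().mean()
    return (baseline_l1 - restored_l1).item()
```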
This workflow trains multiple lightweight models, each specialized for a specific CRF range.
Create a config file for each CRF bucket (e.g., configs/nano_models/nano_unet_small_crf34-43.yaml).
# ... (project, system, data sections as above) ...
model:
  type: "nano_unet"         # or nano_resnet
  size: "small"             # nano, tiny, small, etc.

dataset:
  crf_range: [34, 43]       # <<< NARROW CRF BUCKET

# ... (rest of config) ...
Launch a separate training run for each config file.
# Train Model A (CRF 23-33)
python av1_restorer/train_av1_nano_restorer.py \
--config configs/nano_models/nano_unet_small_crf23-33.yaml
# Train Model B (CRF 34-43)
python av1_restorer/train_av1_nano_restorer.py \
--config configs/nano_models/nano_unet_small_crf34-43.yaml
Use scripts/restore_av1.py to run your trained models.
| Argument | Short | Description |
|---|---|---|
| --checkpoint <path> | -c | Required. Path to the .pth checkpoint file. |
| --input_dir <path> | -d | Required. Path to a directory of LQ images. |
| --output_dir <path> | | Required. Path to save restored images. |
| --auto | | Auto-detect CRF/Preset from filenames (e.g., ..._crf30_p4.avif). |
| --test | | Enable test mode: runs metrics against an HQ directory. |
| --hq_dir <path> | | (Test Mode) Path to the corresponding HQ images. |
| --device <name> | | auto, cuda, mps, or cpu. Defaults to auto. |
| --tile <size> | | Tile size for large images (e.g., 512). |
| --overwrite | | Overwrite existing files in the output directory. |
| --dry_run | | Log actions without processing. |
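The --auto mode relies on the ..._crfXX_pY naming convention; parsing it might look like the following sketch (illustrative, not the script's exact code):

```python
import re
from pathlib import Path

def parse_crf_preset(path: str) -> tuple[int, int]:
    """Extract (crf, preset) from names like 'image_crf30_p4.avif' (sketch)."""
    m = re.search(r"_crf(\d+)_p(\d+)", Path(path).stem)
    if m is None:
        raise ValueError(f"Cannot parse CRF/preset from {path!r}")
    return int(m.group(1)), int(m.group(2))

print(parse_crf_preset("my_image_crf45_p4.avif"))  # (45, 4)
```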
This is the recommended command for evaluating a model's performance.
python scripts/restore_av1.py \
--checkpoint checkpoints/conditional_unet_large/best.pth \
--input_dir ./av1_data/test/lq \
--output_dir ./results/large_model_test_metrics \
--hq_dir ./av1_data/test/hq \
--test \
--auto
This is the standard use case for batch-processing a folder of images.
python scripts/restore_av1.py \
--checkpoint checkpoints/conditional_unet_large/best.pth \
--input_dir /path/to/my_compressed_images \
--output_dir /path/to/my_restored_images \
--auto
For a single image, pass the CRF and preset explicitly instead of using --auto:
python scripts/restore_av1.py \
--checkpoint checkpoints/conditional_unet_large/best.pth \
--input /path/to/my_image_crf45_p4.avif \
--output /path/to/restored_image.png \
--crf 45 \
--preset 4
This project will compare training and inference performance across two distinct hardware setups:
- Nvidia L4 VM (Cloud Engine): Representing a typical cloud-based GPU environment.
- M4 Max Macbook Pro (Local Machine): Representing high-end local ARM-based hardware (MPS).
Metrics will be gathered during the training and testing phases to evaluate real-world speed and efficiency.
- CUDA out of memory / OOM
- 1. Set Env Var: export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
- 2. Reduce batch_size in your YAML config.
- 3. Enable Gradient Checkpointing: In av1_conditional_unet_restorer_v2.py, import from torch.utils.checkpoint import checkpoint and wrap body calls in your forward pass (e.g., e1 = checkpoint(self.encoder1['body'], e1, use_reentrant=False)).
- NaN/Inf Loss
- Disable AMP on MPS: If on Apple Silicon, set mixed_precision: false in your YAML.
- Check Loss Weights: A perceptual weight > 0.2 can cause instability. Start low (0.05).
- Lower Learning Rate: Change lr: 1.0e-4 to 5.0e-5.
- Slow/Negative L1 Improvement
- Balance Losses: Your perceptual loss weight is likely too high. Reduce perceptual.weight (e.g., to 0.05).
- Add Structural Loss: Install and enable ms_ssim (pip install pytorch-msssim). This is critical for structural integrity.
- Checkerboard Artifacts in Output
- Ensure you are using av1_conditional_unet_restorer_v2.py, which uses Bilinear Upsample + Conv to fix this.
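The checkerboard-free upsampling pattern (bilinear upsample followed by a convolution, in place of a transposed convolution) looks like this in sketch form:

```python
import torch.nn as nn

def upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
    """Bilinear upsample + 3x3 conv: avoids the uneven kernel overlap of
    transposed convolutions that causes checkerboard artifacts (sketch)."""
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
    )
```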
- Test Dataset Loading:
python -c "
from utils.av1_dataset import AV1Dataset

dataset = AV1Dataset(
    lq_root_dir='av1_data/train/lq',
    hq_root_dir='av1_data/train/hq',
    hq_ext='.png',
    patch_size=128,
    crf_range=(23, 63),
    preset_range=(4, 4),
    norm_range=(-1, 1)
)
print(f'Dataset size: {len(dataset)}')
dataset.print_statistics()

sample = dataset[0]
lq, hq, crf, preset = sample['lq'], sample['hq'], sample['crf'], sample['preset']
print(f'Sample keys: {list(sample.keys())}')
print(f'LQ shape: {lq.shape}, range: [{lq.min():.3f}, {lq.max():.3f}]')
print(f'HQ shape: {hq.shape}, range: [{hq.min():.3f}, {hq.max():.3f}]')
print(f'CRF: {crf.item()}, Preset: {preset.item()}')
"
- Model Loading:
# Create conditional restorer for all size configs
python av1_restorer/models/av1_conditional_unet_restorer_v2.py
- Dry Run Training (2 epochs):
# Create a dry_run.yaml config with epochs: 2
python av1_restorer/train_av1_conditional_restorer.py \
--config configs/dry_run.yaml
aura/
├── av1_restorer/
│   ├── models/
│   │   ├── av1_conditional_unet_restorer_v2.py  # SOTA Conditional U-Net
│   │   ├── av1_nano_unet_restorer.py            # Nano U-Net
│   │   ├── av1_nano_resnet_restorer.py          # Nano ResNet (Fastest)
│   │   ├── av1_nano_fbcnn_restorer.py
│   │   ├── av1_nano_mamba_restorer.py
│   │   └── blocks.py                            # Shared building blocks
│   │
│   ├── train_av1_conditional_restorer.py        # Trainer for Conditional U-Net
│   └── train_av1_nano_restorer.py               # (Hypothetical) Trainer for Nano models
│
├── utils/
│   ├── av1_dataset.py                           # Dataloader
│   └── loss.py                                  # CombinedLoss function
│
├── scripts/
│   ├── degrade_av1.py                           # Creates LQ dataset
│   ├── split_av1_dataset.py                     # Splits train/val
│   └── restore_av1.py                           # Inference script
│
├── configs/
│   ├── conditional_unet/                        # Configs for Conditional U-Net
│   └── nano_models/                             # Configs for Nano models
│
├── av1_data/                                    # Master dataset (DIV2K + Flickr2K)
│   ├── train/                                   # 90% of master dataset
│   │   ├── lq/                                  # AV1-compressed images (.avif)
│   │   │   ├── crf_23/preset_4/
│   │   │   ├── crf_24/preset_4/
│   │   │   └── ... (up to crf_63)
│   │   └── hq/                                  # High-quality reference images (.png)
│   │
│   ├── val/                                     # 10% of master dataset
│   │   ├── lq/                                  # AV1-compressed images
│   │   └── hq/                                  # High-quality reference images
│   │
│   └── test/                                    # Separate test set (e.g., DIV2K_valid)
│       ├── lq/                                  # AV1-compressed images
│       └── hq/                                  # High-quality reference images
│
└── checkpoints/                                 # Saved models