# A Unified Framework for Semantic Segmentation of Earth Observation Imagery
SemanticSeg4EO is a comprehensive framework for semantic segmentation of satellite imagery, supporting both binary and multi-class segmentation through a unified codebase. The system integrates advanced deep learning architectures specifically adapted for remote sensing applications, with emphasis on methodological transparency, reproducibility, and experimental flexibility.
- Overview
- What's New in V2
- Key Features
- Installation
- Quick Start
- Dataset Preparation
- Patch Extraction
- Training System
- Advanced Training Features
- Inference on Large Images
- Architecture Support
- Output Format
- Examples
- Best Practices
- Troubleshooting
- License
- Contact
## Overview

SemanticSeg4EO provides a unified pipeline for Earth Observation (EO) data segmentation, from data preparation to large-scale inference. The framework combines robust preprocessing, advanced training techniques, and seamless patch-based prediction, making it suitable for both research and production applications in land-cover mapping, environmental monitoring, and change detection.
## What's New in V2

Version 2 introduces significant enhancements to training performance and flexibility:
### New Loss Functions

- Focal Loss: Better handling of class imbalance with configurable alpha/gamma
- Tversky Loss: Control the false positive/negative tradeoff with alpha/beta parameters
- Combo Loss: Combined CE + Dice + Focal for maximum flexibility
- Focal-Dice: Recommended for severely imbalanced datasets

### Transfer Learning Controls

- Encoder Freezing: Freeze the pretrained encoder for initial epochs to preserve learned features
- Gradual Unfreezing: Automatic unfreezing after a specified number of epochs

### Training Schedule

- Warmup: Gradual learning rate increase at training start
- Multiple Schedulers: ReduceLROnPlateau, Cosine Annealing, One-Cycle
- Mixed Precision Training (AMP): Faster training with reduced memory usage

### Monitoring and Logging

- Per-class IoU Logging: Detailed metrics for each class during training
- CSV Export: All training metrics saved to CSV for analysis
- Per-class Visualization: Training plots show IoU evolution per class
- Enhanced Checkpoints: Complete configuration saved with models

### Patch Extraction

- Multi-image Extraction: Process multiple image-label pairs automatically
- Pattern Matching: Automatic file pairing with regex patterns
## Key Features

### Unified Framework

- Single codebase for both binary and multi-class segmentation
- Automatic mode detection based on configuration
- Consistent interface across all workflows

### Training

- K-Fold Cross-Validation with comprehensive statistics
- Multi-channel data augmentation tailored for satellite imagery
- Class weighting for imbalanced datasets
- Early stopping and model checkpointing
- Percentile-based normalization (99th percentile robust normalization; see the sketch after this list)

### Model Architectures

- Custom U-Net variants with dropout regularization
- Segmentation Models PyTorch (SMP) integration: UNet, UNet++, DeepLabV3, DeepLabV3+, FPN, PSPNet, MANet, PAN, LinkNet
- TorchVision models support
- Configurable encoders (ResNet, EfficientNet, etc.)

### Large-Image Inference

- Patch-based prediction with seamless reconstruction
- Weighted blending to reduce border artifacts
- Geospatial metadata preservation
- Confidence map generation
- Automatic encoder detection from checkpoints

### Data Preparation

- Automatic patch extraction using shapefile grids
- Train/validation/test splitting with reproducibility
- Multi-channel support (including Sentinel-2 with 10+ bands)
- Batch mode for processing multiple images
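The percentile normalization referenced above can be sketched in a few lines. This is a minimal illustration assuming per-band scaling to the 99th percentile; the framework's actual routine may differ in detail:

```python
import numpy as np

def percentile_normalize(image: np.ndarray, pct: float = 99.0) -> np.ndarray:
    """Robustly scale a (bands, H, W) array to [0, 1] per band.

    Using the 99th percentile instead of the maximum keeps a few
    saturated pixels (clouds, specular water) from crushing the
    dynamic range of the whole band.
    """
    out = np.empty(image.shape, dtype=np.float32)
    for b in range(image.shape[0]):
        hi = np.percentile(image[b], pct)
        out[b] = np.clip(image[b] / max(float(hi), 1e-6), 0.0, 1.0)
    return out
```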
## Installation

### Requirements

- Python ≥ 3.8
- PyTorch ≥ 1.10 (with CUDA for GPU acceleration)
- GPU recommended for training and large-scale inference
### Setup

```bash
# Clone repository
git clone https://github.com/aleguillou1/SemanticSeg4EO.git
cd SemanticSeg4EO

# Install dependencies
pip install -r requirements.txt
```

The main dependencies (`requirements.txt`):

```text
# Core Deep Learning
torch>=1.10.0
torchvision>=0.11.0
segmentation-models-pytorch>=0.3.0
# Geospatial
rasterio>=1.3.0
geopandas>=0.12.0
# Image Processing
tifffile>=2022.5.4
opencv-python>=4.5.0
# Scientific Computing
numpy>=1.21.0
scipy>=1.7.0
scikit-learn>=1.0.0
# Visualization
matplotlib>=3.5.0
# Utilities
tqdm>=4.64.0
```
## Quick Start

### 1. Prepare Your Dataset

```
dataset_root/
└── Patch/
    ├── train/
    │   ├── images/
    │   │   ├── patch_001.tif
    │   │   └── ...
    │   └── labels/
    │       ├── patch_001.tif
    │       └── ...
    ├── validation/
    │   ├── images/
    │   └── labels/
    └── test/
        ├── images/
        └── labels/
```
### 2. Train a Model

```bash
# Multi-class segmentation with Focal Loss (recommended for imbalanced data)
python main.py --mode multiclass --classes 5 --dataset_root /path/to/data \
--model unet++ --loss_type focal --use_class_weights

# With encoder freezing and warmup
python main.py --mode multiclass --classes 5 --dataset_root /path/to/data \
--model unet++ --freeze_encoder --freeze_epochs 5 --warmup_epochs 2

# Binary segmentation
python main.py --mode binary --dataset_root /path/to/data --model unet
```

### 3. Predict on a Large Image

```bash
python Predict_large_image.py --model trained_models/model_final.pth \
--input large_image.tif \
--output prediction.tif
```

## Dataset Preparation

### Data Requirements

- Images: Multi-band GeoTIFF files (e.g., Sentinel-2 with 10+ bands)
- Labels: Single-band GeoTIFF masks
  - Binary: 0 (background) and 1 (foreground)
  - Multi-class: integers from 0 to N-1 (where N = number of classes)
- Spatial alignment: Images and masks must have identical georeferencing (a quick check is sketched below)
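A quick alignment check with rasterio (illustrative; the patch file names are placeholders):

```python
import rasterio

# Image and mask must share CRS, affine transform, and dimensions.
with rasterio.open("images/patch_001.tif") as img, \
     rasterio.open("labels/patch_001.tif") as lbl:
    assert img.crs == lbl.crs, "CRS mismatch"
    assert img.transform == lbl.transform, "Transform mismatch"
    assert (img.width, img.height) == (lbl.width, lbl.height), "Size mismatch"
```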
### Directory Structure

The system expects the following directory structure:
```
dataset_root/
└── Patch/
    ├── train/
    │   ├── images/       # Training images
    │   └── labels/       # Training masks
    ├── validation/
    │   ├── images/       # Validation images
    │   └── labels/       # Validation masks
    └── test/
        ├── images/       # Test images
        └── labels/       # Test masks
```
## Patch Extraction

For large satellite scenes, use the patch extraction module to create training-ready datasets.
### Single-Image Mode

```bash
python Patch_extraction.py single \
--image /path/to/satellite_image.tif \
--label /path/to/ground_truth.tif \
--grid /path/to/grid_shapefile.shp \
--output /path/to/output_dataset \
--patch_size 224 \
--image_channels 10 \
--train_ratio 0.75 \
--val_ratio 0.15 \
--test_ratio 0.10
```

### Batch Mode
```bash
# Automatically find and process Image_1.tif/Label_1.tif, Image_2.tif/Label_2.tif, etc.
python Patch_extraction.py batch \
--data_dir /path/to/images_folder \
--grid /path/to/grid.shp \
--output /path/to/output \
--patch_size 224 \
--image_channels 10 \
--recursive
```

Batch mode expects the following naming convention:

- Images: `Image_1.tif`, `Image_2.tif`, ... or `image_1.tif`, `image_2.tif`, ...
- Labels: `Label_1.tif`, `Label_2.tif`, ... or `label_1.tif`, `label_2.tif`, ...
- Grids (optional, per image): `Grid_1.shp`, `Grid_2.shp`, ...
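Conceptually, grid-driven extraction reads one raster window per grid polygon. A minimal sketch of the idea with geopandas and rasterio (paths are placeholders; this is not the module's actual code):

```python
import geopandas as gpd
import rasterio
from rasterio.windows import from_bounds

grid = gpd.read_file("grid_shapefile.shp")
with rasterio.open("satellite_image.tif") as src:
    for idx, cell in grid.iterrows():
        # Convert the polygon's bounding box to a pixel window
        window = from_bounds(*cell.geometry.bounds, transform=src.transform)
        patch = src.read(window=window)  # (bands, height, width) array
        # ... write `patch` (and the matching label window) to images/ and labels/
```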
### Dataset Inspection

```bash
# Summarize an extracted dataset
python Patch_extraction.py info --output /path/to/dataset

# Visualize sample patches
python Patch_extraction.py visualize \
--output /path/to/output_dataset \
--split train \
--sample_index 0
```

## Training System

The system provides a single entry point (`main.py`) for both segmentation modes with all features:
```bash
python main.py --mode [binary|multiclass] [OPTIONS]
```

### Training Examples

```bash
# Multi-class with Focal Loss and encoder freezing
python main.py --mode multiclass \
--classes 5 \
--dataset_root /path/to/data \
--model unet++ \
--loss_type focal_dice \
--freeze_encoder --freeze_epochs 5 \
--warmup_epochs 2 \
--use_amp \
--log_per_class \
--class_names background water vegetation buildings roads
```
```bash
# 5-fold cross-validation with per-class metrics
python main.py --mode multiclass \
--classes 5 \
--dataset_root /path/to/data \
--model unet++ \
--val_strategy kfold \
--n_splits 5 \
--loss_type focal \
--log_per_class
```

## Advanced Training Features

### Loss Functions

| Loss Type | Description | Best For |
|---|---|---|
| `ce` | Cross Entropy only | Balanced datasets |
| `dice` | Dice Loss only | General segmentation |
| `dice_ce` | Dice + Cross Entropy (default) | Balanced approach |
| `focal` | Focal Loss | Class imbalance |
| `focal_dice` | Focal + Dice | Severe imbalance |
| `tversky` | Tversky Loss | Control FP/FN tradeoff |
| `combo` | CE + Dice + Focal | Maximum flexibility |
```bash
# Using Focal Loss with custom parameters
python main.py --loss_type focal --focal_gamma 2.0 --focal_alpha 0.25

# Using Tversky Loss (weight false negatives more)
python main.py --loss_type tversky --tversky_alpha 0.3 --tversky_beta 0.7
```
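For reference, focal loss rescales each pixel's cross-entropy by (1 - p_t)^gamma, so well-classified pixels contribute little. A minimal multi-class sketch of the standard formulation (the framework's own implementation may differ):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, target: torch.Tensor,
               gamma: float = 2.0, alpha: float = 0.25) -> torch.Tensor:
    """Focal loss for (N, C, H, W) logits and (N, H, W) integer targets."""
    ce = F.cross_entropy(logits, target, reduction="none")  # per-pixel CE
    pt = torch.exp(-ce)            # probability assigned to the true class
    return (alpha * (1.0 - pt) ** gamma * ce).mean()
```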
### Encoder Freezing

Freeze the pretrained encoder to preserve learned features during initial training:

```bash
python main.py --freeze_encoder --freeze_epochs 5
```

This is particularly useful when fine-tuning on small datasets or when the target domain is similar to ImageNet.
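Conceptually, freezing just disables gradients on the encoder's parameters for the first epochs. A hypothetical sketch using the `.encoder` attribute exposed by segmentation-models-pytorch models (not the framework's actual code):

```python
import segmentation_models_pytorch as smp

model = smp.UnetPlusPlus(encoder_name="resnet34", in_channels=10, classes=5)

def set_encoder_trainable(model, trainable: bool) -> None:
    for p in model.encoder.parameters():
        p.requires_grad = trainable

set_encoder_trainable(model, False)  # frozen for the first `--freeze_epochs`
# ... train the decoder only ...
set_encoder_trainable(model, True)   # gradual unfreezing afterwards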
### Warmup

Gradually increase the learning rate from a small value to the target:
```bash
python main.py --warmup_epochs 3 --warmup_lr 1e-6 --learning_rate 5e-4
```
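One standard way to realize such a warmup is PyTorch's `LinearLR`; an illustrative sketch with a stand-in model (not necessarily how `main.py` implements it):

```python
import torch

model = torch.nn.Conv2d(10, 5, 3)  # stand-in model for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)  # target LR
# Ramp the LR linearly from ~1e-6 to 5e-4 over the first 3 epochs
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=2e-3, total_iters=3)

for epoch in range(10):
    # ... one training epoch would run here ...
    optimizer.step()     # the scheduler expects optimizer.step() first
    if epoch < 3:
        warmup.step()    # LR rises toward the target during warmup
```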
### Mixed Precision Training

Enable automatic mixed precision for faster training and reduced memory usage:

```bash
python main.py --use_amp
```
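Under the hood, AMP training in PyTorch follows the standard `autocast`/`GradScaler` pattern. A schematic training step, where `train_loader`, `model`, `criterion`, and `optimizer` are placeholders:

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for images, masks in train_loader:               # placeholder DataLoader
    images, masks = images.cuda(), masks.cuda()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():              # forward pass in reduced precision
        loss = criterion(model(images), masks)
    scaler.scale(loss).backward()                # scale to avoid fp16 underflow
    scaler.step(optimizer)                       # unscale gradients, then step
    scaler.update()
```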
### Per-Class Metrics

Enable detailed per-class IoU logging and visualization:

```bash
python main.py --log_per_class --class_names background water forest urban
```
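Per-class IoU is simply intersection over union of the binary masks for each class. A reference sketch of the metric (the framework computes this during validation; its exact code may differ):

```python
import numpy as np

def per_class_iou(pred: np.ndarray, target: np.ndarray, num_classes: int) -> np.ndarray:
    """Compute IoU per class from two integer label maps.

    Returns NaN for classes absent from both prediction and target.
    """
    ious = np.full(num_classes, np.nan, dtype=np.float64)
    for c in range(num_classes):
        inter = np.logical_and(pred == c, target == c).sum()
        union = np.logical_or(pred == c, target == c).sum()
        if union > 0:
            ious[c] = inter / union
    return ious
```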
### Learning Rate Schedulers

```bash
# ReduceLROnPlateau (default)
python main.py --scheduler_type reduce_plateau

# Cosine Annealing
python main.py --scheduler_type cosine

# One-Cycle Policy
python main.py --scheduler_type one_cycle
```
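As an illustration, the PyTorch equivalent of `--scheduler_type cosine` is `CosineAnnealingLR` stepped once per epoch. A sketch with a stand-in model (not `main.py`'s actual code):

```python
import torch

model = torch.nn.Conv2d(10, 5, 3)  # stand-in model for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for epoch in range(100):
    # ... one training epoch would run here ...
    optimizer.step()
    scheduler.step()  # LR decays along a cosine curve from 5e-4 toward 0
```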
### Key Training Parameters

| Parameter | Description | Default |
|---|---|---|
| `--mode` | Segmentation mode: `binary` or `multiclass` | `multiclass` |
| `--dataset_root` | Path to dataset root directory | Required |
| `--model` | Model architecture name | Required |
| `--classes` | Number of classes (for multiclass) | 5 |
| `--val_strategy` | Validation strategy: `split` or `kfold` | `split` |
| `--loss_type` | Loss function type | `dice_ce` |
| `--epochs` | Number of training epochs | 100 |
| `--batch_size` | Batch size | 8 |
| `--learning_rate` | Learning rate | 5e-4 |
| `--encoder_name` | Encoder backbone name | `resnet34` |
| `--pretrained` | Use pretrained encoder weights | True |
| `--freeze_encoder` | Freeze encoder for initial epochs | False |
| `--freeze_epochs` | Number of epochs to keep the encoder frozen | 5 |
| `--warmup_epochs` | Number of warmup epochs | 0 |
| `--use_amp` | Enable mixed precision training | False |
| `--log_per_class` | Log per-class IoU metrics | True |
| `--class_names` | Names for each class | Auto-generated |
| `--use_class_weights` | Apply class weights for imbalance | True |
| `--n_splits` | Number of folds for cross-validation | 5 |
### Training Outputs

During training, the system generates:
```
trained_models/
├── model_best_loss.pth       # Best validation loss checkpoint
├── model_best_iou.pth        # Best validation IoU checkpoint
├── model_final.pth           # Final model with complete config
├── model_training_plot.png   # Training visualization
├── model_training_log.csv    # Complete metrics history
├── model_per_class_iou.png   # Per-class IoU evolution
└── model_metrics.json        # Complete metrics in JSON format
```
## Inference on Large Images

The `Predict_large_image.py` script handles prediction on arbitrarily large satellite scenes, with automatic encoder detection:

```bash
python Predict_large_image.py --model /path/to/model.pth \
--input /path/to/large_image.tif \
--output /path/to/prediction.tif

# Multi-class with custom parameters
python Predict_large_image.py \
--model /path/to/model.pth \
--input large_image.tif \
--output prediction.tif \
--num_classes 6 \
--patch_size 512 \
--overlap 128 \
--save_confidence \
--device cuda
```

### Inference Parameters

| Parameter | Description | Default |
|---|---|---|
| `--model` | Path to trained model (`.pth` file) | Required |
| `--input` | Input satellite image | Required |
| `--output` | Output segmentation map | Required |
| `--encoder_name` | Encoder backbone (auto-detected) | Auto |
| `--patch_size` | Size of prediction patches | 512 |
| `--overlap` | Overlap between patches | 128 |
| `--num_classes` | Number of output classes | Auto-detected |
| `--threshold` | Confidence threshold (binary) | 0.5 |
| `--save_confidence` | Save confidence map | False |
| `--device` | Computation device | `cuda` |
The predictor applies weighted blending to eliminate border artifacts, tiles patches automatically with configurable overlap, preserves geospatial metadata, and handles nodata values from the source image.
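The blending idea can be sketched briefly: weight each patch's class probabilities with a window that peaks at the patch centre, accumulate weighted probabilities and weights per pixel, and divide at the end. A Hann-window sketch; the framework's actual window function may differ:

```python
import numpy as np

def blend_weights(patch_size: int) -> np.ndarray:
    """2-D Hann-style weight mask: ~1 at the centre, ~0 at the borders."""
    w = np.hanning(patch_size)
    return np.outer(w, w)

# Accumulation over overlapping tiles (shapes illustrative):
#   prob_sum[:, y:y+P, x:x+P] += blend_weights(P) * patch_probs
#   weight_sum[y:y+P, x:x+P]  += blend_weights(P)
# Final blended probabilities:
#   blended = prob_sum / np.maximum(weight_sum, 1e-8)
```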
## Architecture Support

Run the following to see all available models:

```bash
python -c "from model_training import get_available_models; print(get_available_models())"
```

### Custom Models

- `unet-dropout`: Custom U-Net with dropout regularization

### SMP Architectures
- Encoder-Decoder: UNet, UNet++, MANet, Linknet, PAN
- Pyramid Networks: FPN, PSPNet
- DeepLab Family: DeepLabV3, DeepLabV3+
### TorchVision Models

- DeepLabV3 with a ResNet50 backbone

### Supported Encoders
- ResNet: 18, 34, 50, 101, 152
- EfficientNet: b0-b7
- MobileNet: v2, v3
- DenseNet: 121, 169, 201, 264
- VGG: 11, 13, 16, 19
- And more via SMP
## Output Format

### Model Checkpoint Format

Trained models are saved with comprehensive metadata, including the full configuration:

```python
{
'model_state_dict': model_weights,
'config': {
'model_name': 'unet++',
'mode': 'multiclass',
'in_channels': 10,
'num_classes': 6,
'encoder_name': 'resnet34',
'loss_type': 'focal_dice',
'freeze_encoder': True,
'freeze_epochs': 5,
'warmup_epochs': 2,
'use_amp': True,
# ... all other config parameters
},
'performance_metrics': {
'best_val_loss': 0.1234,
'best_val_iou': 0.7890,
'per_class_iou': {...}
}
}
```
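Reading this metadata back is straightforward with `torch.load`; the field names below follow the structure shown above:

```python
import torch

ckpt = torch.load("trained_models/model_final.pth", map_location="cpu")
cfg = ckpt["config"]
print(cfg["model_name"], cfg["encoder_name"], cfg["num_classes"])
print("Best val IoU:", ckpt["performance_metrics"]["best_val_iou"])
# Instantiate the matching architecture, then:
# model.load_state_dict(ckpt["model_state_dict"])
```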
### Prediction Outputs

- Segmentation map: GeoTIFF with class labels
- Confidence map (optional): GeoTIFF with prediction confidence
- Statistics report: Console output with class distribution and confidence metrics
## Examples

### Land-Cover Mapping (Multi-Class)

```bash
# Extract patches from large scenes (batch mode)
python Patch_extraction.py batch \
--data_dir ./raw_data \
--grid grid_polygons.shp \
--output ./landcover_dataset \
--patch_size 256 \
--image_channels 10
# Train with advanced features
python main.py --mode multiclass \
--classes 6 \
--dataset_root ./landcover_dataset \
--model deeplabv3+ \
--encoder_name efficientnet-b4 \
--pretrained \
--loss_type focal_dice \
--freeze_encoder --freeze_epochs 5 \
--warmup_epochs 3 \
--use_amp \
--log_per_class \
--class_names background water forest urban agriculture bare \
--epochs 200
# Predict on new scene
python Predict_large_image.py \
--model ./trained_models/model_final.pth \
--input new_sentinel2_scene.tif \
--output landcover_prediction.tif \
--save_confidence
```

### Water Detection (Binary)

```bash
# Train binary segmentation with Dice loss
python main.py --mode binary \
--dataset_root ./water_dataset \
--model unet++ \
--encoder_name resnet34 \
--pretrained \
--loss_type dice \
--in_channels 10 \
--epochs 150 \
--learning_rate 0.0005
# Predict with custom threshold
python Predict_large_image.py \
--model ./water_models/model_final.pth \
--input sentinel2_water_scene.tif \
--output water_mask.tif \
--threshold 0.3
```

### Handling Severe Class Imbalance

```bash
# Use Focal-Dice loss with high gamma
python main.py --mode multiclass \
--classes 5 \
--dataset_root ./imbalanced_data \
--model unet++ \
--loss_type focal_dice \
--focal_gamma 3.0 \
--use_class_weights \
--freeze_encoder --freeze_epochs 10
```

## Best Practices

### Data Preparation

- Normalize images using the 99th percentile method (already implemented)
- Ensure class balance or use `--use_class_weights` for imbalanced datasets
- Use data augmentation (`--data_augmentation`) for small datasets
- Validate spatial alignment between images and masks

### Training

- Start with pretrained encoders and use `--freeze_encoder` for better transfer learning
- Use Focal Loss (`--loss_type focal` or `--loss_type focal_dice`) for imbalanced datasets
- Enable warmup (`--warmup_epochs 2` to `5`) for more stable training
- Use mixed precision (`--use_amp`) for faster training on modern GPUs
- Use cross-validation (`--val_strategy kfold`) for reliable performance estimation
- Monitor per-class metrics (`--log_per_class`) to identify underperforming classes
### Inference

- Set appropriate overlap (25-50% of patch size) to avoid border artifacts
- Generate confidence maps (`--save_confidence`) for uncertainty analysis
- Process large images in chunks if memory is limited
- Verify geospatial alignment of output predictions

### Performance

- Use GPU acceleration for both training and inference
- Adjust patch size based on GPU memory (256-512 px recommended)
- Enable mixed precision training (`--use_amp`) for a 40-60% speedup
- Use data loaders with pinned memory for faster data transfer
## Troubleshooting

**Dataset not found / no images loaded**
- Cause: Incorrect dataset structure or file extensions
- Solution: Verify the directory structure and ensure files have `.tif` or `.tiff` extensions

**CUDA out of memory**
- Cause: Batch size or patch size too large
- Solution: Reduce `--batch_size` or `--patch_size`, or enable `--use_amp`

**Visible seams in predictions**
- Cause: Insufficient overlap between patches
- Solution: Increase the `--overlap` parameter (recommended: 25-50% of patch size)

**Model fails to load**
- Cause: Mismatch in model parameters or architecture
- Solution: Ensure `--in_channels`, `--num_classes`, and `--encoder_name` match the training configuration (normally auto-detected from the checkpoint)

**Slow inference**
- Cause: Large patch size or CPU inference
- Solution: Reduce the patch size, use a GPU (`--device cuda`), or enable tiling

**Poor performance on minority classes**
- Cause: Dominant background class
- Solution: Use `--loss_type focal_dice`, increase `--focal_gamma`, and enable `--use_class_weights`
### Debugging

For detailed debugging, print full error tracebacks:

```python
import traceback

try:
    ...  # your code here
except Exception as e:
    print(f"Error: {e}")
    traceback.print_exc()
```

## License

This project is licensed under the MIT License. See the LICENSE file for details.
## Contact

For questions, collaborations, or technical support:

Adrien Leguillou
Research Engineer at LETG
Email: adrien.leguillou@univ-brest.fr
## Acknowledgments

This framework builds upon several open-source projects, including PyTorch, Segmentation Models PyTorch (SMP), rasterio, and GeoPandas.
Special thanks to the remote sensing community for datasets and methodologies that inspired this work.