This repository contains the implementation of a color prediction system for birds in the VOCIM dataset. The system uses a combination of TinyViT for initial color predictions and a Graph Neural Network (GNN) for refining these predictions while ensuring bijectivity within each frame. Based on work by Xiaoran Chen (SDSC).
The system consists of two main components:
- TinyViT Backbone:
  - Processes individual bird crops
  - Outputs a probability distribution over colors
  - Can use a heatmap mask for better feature extraction
- ColorGNN:
  - Takes TinyViT's predictions for birds in the same frame
  - Creates a bipartite graph between birds and colors
  - Refines predictions while ensuring bijectivity
  - Uses the Hungarian algorithm for optimal assignments
How it works:

- Data Processing:
  - Images are cropped to individual birds
  - Frame IDs are extracted from filenames (see the parsing sketch after this list)
  - Birds are grouped by frame for GNN processing
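For illustration, a minimal frame-ID parser, assuming filenames follow the `img00332_bird_1.png` pattern shown later in this README (the actual parsing in dataset.py may differ):

```python
import re

def parse_crop_filename(filename):
    # Hypothetical parser for names like 'img00332_bird_1.png':
    # the digits after 'img' identify the frame, the trailing
    # index identifies the bird within that frame.
    match = re.match(r"img(\d+)_bird_(\d+)\.", filename)
    frame_id, bird_id = int(match.group(1)), int(match.group(2))
    return frame_id, bird_id
```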
- Model Processing:
  - TinyViT processes each crop individually
  - For each frame:
    - Get the top-K colors from TinyViT for each bird
    - Create a bipartite graph using only the top-K colors (see the sketch after this list)
    - The GNN processes the graph to learn relationships
    - Combine GNN scores with TinyViT probabilities
    - Apply the Hungarian algorithm for final assignments
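A minimal sketch of the top-K bipartite graph construction in torch_geometric's `edge_index` convention, assuming bird nodes come first and color nodes are offset by the number of birds (the repository's node layout may differ):

```python
import torch

def build_bipartite_edges(tinyvit_probs, k):
    # tinyvit_probs: (num_birds, num_colors) softmax outputs for one frame.
    # Bird nodes are 0..num_birds-1; color nodes are offset by num_birds.
    num_birds = tinyvit_probs.size(0)
    topk_colors = tinyvit_probs.topk(k, dim=1).indices    # (num_birds, k)
    birds = torch.arange(num_birds).repeat_interleave(k)  # each bird k times
    colors = topk_colors.reshape(-1) + num_birds          # offset color ids
    return torch.stack([birds, colors])                   # edge_index, shape (2, num_birds * k)
```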
- Score Combination:
  - The GNN outputs a matrix of shape (num_birds, num_colors)
  - Scores are weighted by TinyViT probabilities (a minimal sketch follows this list)
  - Higher TinyViT confidence → stronger GNN influence
  - Lower TinyViT confidence → weaker GNN influence
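One plausible reading of this weighting is an element-wise product of the GNN scores and the TinyViT distribution (a sketch, not necessarily the repository's exact formula):

```python
import torch

def combine_scores(tinyvit_probs, gnn_scores):
    # Both tensors have shape (num_birds, num_colors). Multiplying
    # element-wise means a confident TinyViT prediction amplifies the
    # GNN's score for that color, while a flat, uncertain distribution
    # damps the GNN's influence.
    return gnn_scores * tinyvit_probs
```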
- Bijectivity Constraint:
  - The Hungarian algorithm ensures one-to-one assignments (a minimal assignment sketch follows the feature list below)
  - Each bird gets a unique color
  - Each color is used at most once per frame
Key features:

- Frame-Based Processing: Birds from the same frame are processed together
- Top-K Selection: Only considers TinyViT's top-K color predictions
- Bipartite Graph: Represents relationships between birds and colors
- Score Weighting: GNN scores are weighted by TinyViT confidence
- Bijective Assignments: Ensures unique color assignments per frame
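A minimal sketch of the bijective assignment step, using SciPy's Hungarian solver on one frame's combined scores:

```python
from scipy.optimize import linear_sum_assignment

def assign_colors(combined_scores):
    # combined_scores: (num_birds, num_colors) array for a single frame.
    # maximize=True finds the one-to-one bird-to-color matching with the
    # highest total score, so each color is used at most once per frame.
    bird_idx, color_idx = linear_sum_assignment(combined_scores, maximize=True)
    return dict(zip(bird_idx.tolist(), color_idx.tolist()))
```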
Usage:

- Training:

  ```bash
  python train.py --config config.py
  ```

- Evaluation:

  ```bash
  python eval.py --model_path path/to/model --data_path path/to/data
  ```

Dependencies:

- PyTorch
- torch-geometric
- timm
- numpy
- PIL
- yaml
Key parameters in config.py:
- use_heatmap_mask: Whether to use the heatmap mask for TinyViT
- sigma_val: Sigma value for the heatmap mask (see the sketch after this list)
- Model architecture parameters
- Training parameters
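For illustration, a Gaussian mask whose spread is controlled by a sigma parameter like sigma_val (a hypothetical helper; the repository's mask generation may differ):

```python
import torch

def gaussian_heatmap(height, width, center, sigma):
    # 2-D Gaussian centred on a keypoint (e.g. the bird's backpack);
    # a larger sigma spreads the mask over a wider region of the crop.
    ys = torch.arange(height, dtype=torch.float32).unsqueeze(1)
    xs = torch.arange(width, dtype=torch.float32).unsqueeze(0)
    cy, cx = center
    return torch.exp(-((ys - cy) ** 2 + (xs - cx) ** 2) / (2 * sigma ** 2))
```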
The dataset should be organized as follows:
- Images are cropped to individual birds
- Filenames contain frame IDs (e.g., 'img00332_bird_1.png')
- YAML files map bird identities to colors
- JSON annotations contain bounding boxes and metadata
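The two YAML mappings can be loaded with PyYAML; a minimal sketch, assuming simple key-value structures (the actual file contents may differ):

```python
import yaml

with open("newdata_bird_identity.yaml") as f:
    bird_identity = yaml.safe_load(f)  # e.g. bird ID -> identity label
with open("newdata_colormap.yaml") as f:
    colormap = yaml.safe_load(f)       # e.g. color name -> class index
```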
- The heatmap mask only affects TinyViT's feature extraction
- GNN processes TinyViT's outputs, not the original images
- Bijectivity is enforced at the frame level
- The system can handle varying numbers of birds per frame
The system is designed to:
- Process bird images with backpack annotations
- Apply heatmap masks to focus on relevant regions
- Train a deep learning model to predict color categories
- Evaluate model performance on test datasets
- Enforce unique color assignments per frame using linear assignment
Project structure:

- model.py: Contains the main model architecture and training logic
- dataset.py: Handles data loading and preprocessing
- dataloader.py: Manages data batching and loading
- train.py: Main training script
- eval.py: Evaluation script
- linear_assignment_eval.py: Evaluation with linear assignment for unique color constraints
- utils.py: Utility functions
- config.py: Configuration file containing global parameters and settings for the project, including:
  - Heatmap mask parameters (sigma values)
  - Model configuration settings
  - Data processing parameters
- split.py, split_by_color_video.py, split_by_video.py: Data splitting utilities
- sampler.py: Custom data sampling implementation
- utils_and_analysis/: Additional utility scripts and analysis tools
The project expects data in the following format:
- JSON annotations file containing image paths and labels
- YAML file for bird identity mapping (newdata_bird_identity.yaml)
- YAML file for color mapping (newdata_colormap.yaml)
Use the provided splitting scripts to prepare your dataset:
- split.py: General data splitting
- split_by_color_video.py: Split data by color and video
- split_by_video.py: Split data by video
- split_by_videosplit_file.py: Split data using a predefined split file
To train the model:
- Prepare your dataset using the splitting scripts
- Run the training script:
```bash
./train.sh
```

The training script supports:
- Early stopping
- Model checkpointing
- Learning rate scheduling
- Logging of training metrics
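A minimal sketch of how these features typically fit together, with hypothetical train_one_epoch/validate helpers (train.py's actual loop may differ):

```python
import torch

def fit(model, train_loader, val_loader, optimizer, scheduler,
        train_one_epoch, validate, max_epochs=100, patience=10):
    best_val, bad_epochs = float("inf"), 0
    for epoch in range(max_epochs):
        train_one_epoch(model, train_loader, optimizer)
        val_loss = validate(model, val_loader)
        scheduler.step(val_loss)                        # learning-rate scheduling
        if val_loss < best_val:
            best_val, bad_epochs = val_loss, 0
            torch.save(model.state_dict(), "best.pt")   # checkpoint best model
        else:
            bad_epochs += 1
            if bad_epochs >= patience:                  # early stopping
                return
```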
To evaluate a trained model:
```bash
./eval.sh
```

The evaluation script will:
- Load the trained model
- Run inference on the test set
- Calculate accuracy and other metrics
- Save predictions to a JSON file
For evaluation with linear assignment:
```bash
python linear_assignment_eval.py
```

The linear assignment evaluation:
- Groups predictions by frame
- Creates cost matrices using model probabilities
- Applies the Hungarian algorithm to enforce unique color assignments
- Reports both direct and linear assignment accuracies
The model uses a deep learning architecture that:
- Takes RGB images as input
- Optionally uses heatmap masks to focus on relevant regions
- Processes images through a neural network
- Outputs color category predictions
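A minimal sketch of the backbone stage, assuming a TinyViT variant available through timm (the repository may use a different variant or a custom classification head):

```python
import timm
import torch

NUM_COLORS = 8  # hypothetical number of color classes

# Pretrained TinyViT with a head sized for the color classes.
backbone = timm.create_model("tiny_vit_21m_224", pretrained=True,
                             num_classes=NUM_COLORS)

def predict_colors(crops):
    # crops: (B, 3, 224, 224) bird crops, optionally already multiplied
    # by the heatmap mask described above.
    return torch.softmax(backbone(crops), dim=1)
```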
The model includes a Graph Neural Network (GNN) component that:
- Enhances the base TinyViT architecture
- Uses separate dropout for the GNN component
- Creates a bipartite graph between birds and colors
- Processes relationships between birds and colors in the same frame
- Currently achieves 94.24% accuracy on the ambiguous subset
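A minimal message-passing sketch over the bird-color bipartite graph using torch_geometric's GCNConv (the real ColorGNN's layers, feature construction, and dropout rate may differ); the (num_birds, num_colors) score matrix could then be read off, e.g., as dot products between bird and color node embeddings:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class ColorGNNSketch(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, dropout=0.2):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.dropout = dropout  # separate dropout for the GNN component

    def forward(self, x, edge_index):
        # x stacks bird-node and color-node features in one tensor;
        # edge_index connects each bird to its top-K candidate colors.
        h = F.relu(self.conv1(x, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.conv2(h, edge_index)  # refined node embeddings
```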
The linear assignment evaluation enforces the constraint that each color can only be used once per frame:
- Groups all birds from the same frame together
- Creates a cost matrix for each frame where:
  - Rows represent bird predictions
  - Columns represent available colors
  - Cell values are 1 minus the model's softmax probability for that color
- Uses the Hungarian algorithm to find optimal unique color assignments
- Computes accuracy across all frames
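A minimal sketch of this evaluation, assuming predictions are already grouped by frame (the script's actual data handling may differ):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def linear_assignment_accuracy(frames):
    # frames: iterable of (probs, labels) pairs, one per frame, where probs
    # is a (num_birds, num_colors) softmax array and labels is a (num_birds,)
    # array of true color indices (hypothetical structure).
    correct = total = 0
    for probs, labels in frames:
        cost = 1.0 - probs                        # low cost = high probability
        rows, cols = linear_sum_assignment(cost)  # unique color per bird
        correct += int((cols == labels[rows]).sum())
        total += len(rows)
    return correct / total
```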
Our results demonstrate that the mask-guided TinyViT outperforms a ResNet-50 baseline, achieving 97.54% accuracy on unambiguous scenes and 92.81% accuracy on crowded (ambiguous) scenes. With the addition of the GNN and the linear assignment step, overall accuracy improves to 98%.