
Multi-Layer Neural Networks as Replacement for Pooling Operations

Implementation of the paper "Multi Layer Neural Networks as Replacement for Pooling Operations" by Wolfgang Fuhl and Enkelejda Kasneci (AAAI 2021).

Overview

This project implements the paper's novel approach of replacing traditional pooling operations (max/average pooling) with learnable perceptron-based pooling layers in Convolutional Neural Networks. The method demonstrates that single perceptrons or small neural networks can effectively serve as pooling operators while maintaining computational efficiency.

Paper's Key Contribution

Instead of using fixed pooling operations like max or average pooling, the paper proposes:

  • Single Perceptron Pooling: One perceptron per pooling operation (only 10 additional parameters)
  • NN-4-1 Pooling: 4 neurons → 1 neuron (50 additional parameters)
  • NN-16-1 Pooling: 16 neurons → 1 neuron (194 additional parameters)

Critical Insight: The paper found that removing activation functions (like ReLU) from the pooling perceptrons significantly improves performance.
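
As a rough illustration of the idea, here is a minimal sketch of a single-perceptron pooling layer in PyTorch. The class name PerceptronPool2d, the 2×2 window with stride 2, and sharing one weight set across all channels are assumptions made for this sketch, not necessarily the paper's exact configuration:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PerceptronPool2d(nn.Module):
    """One learnable linear combination per k×k window. Deliberately has NO activation function."""
    def __init__(self, kernel_size=2, stride=2):
        super().__init__()
        self.k, self.s = kernel_size, stride
        # One weight per window position plus a bias, shared across channels.
        self.weight = nn.Parameter(torch.full((kernel_size * kernel_size,), 1.0 / kernel_size ** 2))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, c, h, w = x.shape
        patches = F.unfold(x, kernel_size=self.k, stride=self.s)  # (N, C*k*k, L)
        patches = patches.view(n, c, self.k * self.k, -1)         # (N, C, k*k, L)
        out = (patches * self.weight.view(1, 1, -1, 1)).sum(2) + self.bias
        out_h = (h - self.k) // self.s + 1
        return out.view(n, c, out_h, -1)

Initialized at 1/k², the layer starts out as plain average pooling and learns a better weighting during training.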

Implementation Details

Experiment 2: CIFAR-100 Classification

Following Table 3 from the paper (page 6):

Method                Accuracy   Additional Params
Average Pooling       75.40%     0
Max Pooling           75.36%     0
Perceptron            76.06%     10
NN-4-1                76.21%     50
NN-16-1               77.14%     194
Strided Conv (ReLU)   77.53%     184,608

Architecture

  • Model: 14-layer CNN (based on Kobayashi 2019a); see the layer-stack sketch after this list
  • Layers:
    • 3 Convolutional layers (32 → 64 → 128 filters)
    • 3 Batch Normalization layers
    • 3 NN-16-1 Pooling layers (custom implementation)
    • 1 Fully Connected layer (100 classes)
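
A rough sketch of this layer stack is shown below. The 3×3 convolution kernels, the padding, the ReLU placement, and the pool_factory argument are assumptions for illustration only; the actual Net class in train.py (which takes a pool_type string, see Usage) is authoritative.

import torch
import torch.nn as nn

class Net(nn.Module):
    """Three Conv → BatchNorm → ReLU → pooling blocks followed by a 100-way classifier."""
    def __init__(self, pool_factory=lambda: nn.MaxPool2d(2)):
        super().__init__()
        def block(c_in, c_out):
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                pool_factory(),  # swap in PerceptronPool2d / NeuralNetPool2d here
            )
        self.features = nn.Sequential(block(3, 32), block(32, 64), block(64, 128))
        self.classifier = nn.Linear(128 * 4 * 4, 100)  # 32 → 16 → 8 → 4 spatial size after three 2× poolings

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # robust flattening (see the bug fixes below)
        return self.classifier(x)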

Training Configuration (Per Paper Specifications)

  • Optimizer: SGD with momentum (0.9)
  • Learning Rate:
    • Initial: 0.1
    • Reduced by 10× at epochs 80 and 120
    • Pooling layers use 0.01 (10× lower than base)
  • Weight Decay: 5×10⁻⁴ (disabled for pooling layers)
  • Batch Size: 100
  • Epochs: 160
  • Data Augmentation:
    • Random crop (32×32 from 40×40 with 4-pixel zero padding)
    • Normalization to zero mean and unit variance

Mistakes in Original Implementation

Critical Errors ❌

  1. Not Implementing the Paper's Method

    • Used standard nn.MaxPool2d instead of perceptron-based pooling
    • This was the paper's entire contribution!
  2. Wrong Optimizer

    • Used: Adam with lr=0.001
    • Should be: SGD with momentum=0.9, lr=0.1
  3. Incorrect Normalization

    • Used: Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    • Should be: CIFAR-100 statistics mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761]
  4. Wrong Learning Rate Schedule

    • Used: ReduceLROnPlateau (adaptive)
    • Should be: Step decay at epochs 80 and 120
  5. Missing Data Augmentation

    • Used: Only normalization
    • Should be: Padding + random crop as specified in paper
  6. Wrong Training Duration

    • Used: 100 epochs
    • Should be: 160 epochs

Implementation Bugs 🐛

  1. Incorrect Epoch Loss Calculation
   # WRONG: running_loss was reset every 200 iterations
   epoch_loss = running_loss / len(trainloader)
   
   # CORRECT: Track total loss separately
   total_loss += loss.item()
   epoch_loss = total_loss / len(trainloader)
  2. Missing Bias Initialization
   # WRONG: Only initialized weights
   torch.nn.init.kaiming_normal_(m.weight)
   
   # CORRECT: Initialize bias too
   if m.bias is not None:
       torch.nn.init.constant_(m.bias, 0)
  3. No Validation/Testing

    • Original code only trained, never evaluated accuracy
  4. Fragile Flattening

    # WRONG: Hardcoded dimensions
    x = x.view(-1, 128 * 4 * 4)
    
    # CORRECT: Auto-adjust
    x = torch.flatten(x, 1)

Corrections Made ✅

1. Implemented Custom Pooling Layers

class PerceptronPool2d(nn.Module):
    """Single perceptron pooling (10 params)"""
    # Key: NO activation function!
    
class NeuralNetPool2d(nn.Module):
    """NN-16-1 pooling (194 params)"""
    # 16 neurons → 1 neuron, no ReLU
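
For reference, one possible NN-16-1 implementation along these lines (a sketch only, assuming a 2×2 window with stride 2 and weights shared across channels; the layer in train.py and the paper's exact parameter counts may differ depending on window size and weight sharing):

import torch.nn as nn
import torch.nn.functional as F

class NeuralNetPool2d(nn.Module):
    """NN-16-1 pooling: each k×k window passes through a tiny
    fully connected network (k*k → 16 → 1) with no activation in between."""
    def __init__(self, kernel_size=2, stride=2, hidden=16):
        super().__init__()
        self.k, self.s = kernel_size, stride
        self.fc1 = nn.Linear(kernel_size * kernel_size, hidden)
        self.fc2 = nn.Linear(hidden, 1)

    def forward(self, x):
        n, c, h, w = x.shape
        patches = F.unfold(x, kernel_size=self.k, stride=self.s)               # (N, C*k*k, L)
        patches = patches.view(n, c, self.k * self.k, -1).permute(0, 1, 3, 2)  # (N, C, L, k*k)
        out = self.fc2(self.fc1(patches)).squeeze(-1)                          # note: no ReLU
        out_h = (h - self.k) // self.s + 1
        return out.view(n, c, out_h, -1)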

2. Correct Training Setup

# Paper-specified hyperparameters
optimizer = torch.optim.SGD([
    {'params': other_params},
    {'params': pooling_params, 'lr': 0.01, 'weight_decay': 0.0}  # 10× lower LR, no weight decay
], lr=0.1, momentum=0.9, weight_decay=5e-4)

# LR schedule: reduce at epochs 80, 120
if epoch == 80 or epoch == 120:
    for param_group in optimizer.param_groups:
        param_group['lr'] *= 0.1
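
The snippet above assumes pooling_params and other_params have already been separated. One way to build those groups, given custom pooling classes such as PerceptronPool2d / NeuralNetPool2d (a sketch, not necessarily how train.py does it):

# Collect the pooling layers' parameters so they get the lower learning rate
# and no weight decay; everything else goes into the default group.
pooling_ids = {id(p) for m in net.modules()
               if isinstance(m, (PerceptronPool2d, NeuralNetPool2d))
               for p in m.parameters()}
pooling_params = [p for p in net.parameters() if id(p) in pooling_ids]
other_params = [p for p in net.parameters() if id(p) not in pooling_ids]

The manual learning-rate drop can equivalently be written as torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 120], gamma=0.1), which scales every parameter group by 0.1 at those epochs.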

3. Proper Data Preprocessing

transform_train = transforms.Compose([
    transforms.Pad(4, fill=0),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5071, 0.4867, 0.4408],
        std=[0.2675, 0.2565, 0.2761]
    )
])

4. Added Validation

def evaluate(model, dataloader, device):
    """Test accuracy evaluation"""
    # Evaluates every 5 epochs
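
A minimal version of such an evaluation loop, for reference (the helper in train.py may differ in details):

import torch

def evaluate(model, dataloader, device):
    """Return top-1 accuracy (%) of `model` on `dataloader`."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    model.train()
    return 100.0 * correct / total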

5. Fixed All Bugs

  • ✅ Proper epoch loss tracking
  • ✅ Bias initialization
  • ✅ Robust tensor flattening
  • ✅ Better timing (time.perf_counter())

Key Insights from the Paper

  1. No Activation Functions: The paper discovered that perceptron pooling works better WITHOUT ReLU/activation functions
  2. Minimal Parameters: Single perceptron (10 params) performs nearly as well as networks with thousands of parameters
  3. Reduced Learning Rate: Pooling layers need 10× lower learning rate for stable training
  4. Initialization Matters: Initialize close to average pooling (0.25 for 2×2) with small random variations
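
As an example of insight 4, the weights of the perceptron-pooling sketch above could be initialized near average pooling like this (the 0.01 noise scale is an illustrative choice, not taken from the paper):

import torch

pool = PerceptronPool2d(kernel_size=2, stride=2)  # from the sketch earlier in this README

# 2×2 window → 4 weights; start close to average pooling (each weight ≈ 0.25)
with torch.no_grad():
    pool.weight.copy_(torch.full_like(pool.weight, 0.25)
                      + 0.01 * torch.randn_like(pool.weight))
    pool.bias.zero_()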

Usage

# Install dependencies
pip install torch torchvision

# Run training
python train.py

# Choose pooling type in main():
net = Net(pool_type='nn-16-1')  # Options: 'perceptron', 'nn-4-1', 'nn-16-1', 'max'

Expected Results

After 160 epochs on CIFAR-100:

  • Perceptron Pooling: ~76.06% accuracy (10 additional params)
  • NN-4-1 Pooling: ~76.21% accuracy (50 additional params)
  • NN-16-1 Pooling: ~77.14% accuracy (194 additional params)

Compare to:

  • Max Pooling baseline: ~75.36%
  • Strided Conv: ~77.53% (but requires 184,608 params!)

File Structure

.
├── train.py              # Main training script with all implementations
├── README.md             # This file
├── data/                 # CIFAR-100 dataset (auto-downloaded)
└── best_model.pth        # Saved best model (optional)

References

@article{fuhl2021multi,
  title={Multi Layer Neural Networks as Replacement for Pooling Operations},
  author={Fuhl, Wolfgang and Kasneci, Enkelejda},
  journal={arXiv preprint arXiv:2006.06969v4},
  year={2021},
  note={AAAI Conference on Artificial Intelligence}
}

Creator

Gideon Ayodeji Ogunbanjo
Founder & CEO, Nuvo AI

License

This implementation is for educational and research purposes. Please cite the original paper if you use this code in your research.


Note: Training for 160 epochs on CIFAR-100 takes approximately 2-4 hours on a modern GPU (depending on hardware). Results may vary slightly due to random initialization and data shuffling.
