
CoupledAdam: Better Embeddings with Coupled Adam

This repository implements the Coupled Adam optimizer described in the research paper "Better Embeddings with Coupled Adam". The implementation provides a PyTorch-compatible optimizer that can be used as a drop-in replacement for the standard Adam optimizer.

Overview

Coupled Adam is a novel optimization algorithm that improves on the traditional Adam optimizer by introducing coupling between the first and second moment estimates. This coupling yields better embeddings, particularly when training large language models and other embedding learning tasks.

Key Features

  • Improved convergence compared to standard Adam
  • Better handling of sparse gradients
  • Enhanced stability in training deep neural networks
  • Compatible with PyTorch's optimizer interface

Detailed Methodology

The Coupled Adam optimizer modifies the traditional Adam optimizer by coupling the first and second moment estimates. Here is the detailed methodology; a minimal code sketch of a single update step follows the four steps below:

1. Traditional Adam Review

Standard Adam maintains two moments for each parameter θ:

  • First moment (m): Exponential moving average of gradients
  • Second moment (v): Exponential moving average of squared gradients

2. Coupled Adam Modifications

The key innovation in Coupled Adam is the introduction of coupling between these moments:

  1. First Moment Update (m_t):

    m_t = β₁m_{t-1} + (1-β₁)g_t
    

    Where:

    • β₁ is the first moment decay rate
    • g_t is the current gradient
  2. Second Moment Update with Coupling (v_t):

    v_t = β₂v_{t-1} + (1-β₂)(g_t² + λ|m_t|²)
    

    Where:

    • β₂ is the second moment decay rate
    • λ is the coupling factor
    • |m_t|² represents the squared magnitude of the first moment
  3. Bias Correction:

    m̂_t = m_t / (1-β₁ᵗ)
    v̂_t = v_t / (1-β₂ᵗ)
    
  4. Parameter Update:

    θ_t = θ_{t-1} - α·m̂_t / (√v̂_t + ε)
    

    Where:

    • α is the learning rate
    • ε is a small constant for numerical stability
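
Putting the four steps together, the following is a minimal sketch of a single update for one parameter tensor, written directly from the equations above. The function name and signature are illustrative, not this repository's actual API.

import torch

def coupled_adam_step(param, grad, m, v, t, lr=1e-3,
                      beta1=0.9, beta2=0.999, eps=1e-8, lam=0.1):
    """One in-place update of a single parameter tensor, per steps 1-4."""
    # 1. First moment: EMA of gradients
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    # 2. Second moment with the coupling term λ·|m_t|²
    v.mul_(beta2).add_(grad.pow(2) + lam * m.pow(2), alpha=1 - beta2)
    # 3. Bias correction
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # 4. Parameter update
    param.sub_(lr * m_hat / (v_hat.sqrt() + eps))

# One step on toy tensors (the step counter t starts at 1)
p, g = torch.randn(3), torch.randn(3)
m, v = torch.zeros(3), torch.zeros(3)
coupled_adam_step(p, g, m, v, t=1)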

3. Key Components

  1. Coupling Factor (λ):

    • Controls the strength of coupling between moments
    • Default value: 0.1
    • Range: [0, 1]
    • Higher values increase the influence of the first moment on the variance estimate (see the sketch after this list)
  2. Moment Decay Rates:

    • β₁ = 0.9 (first moment)
    • β₂ = 0.999 (second moment)
    • These values are empirically chosen for good performance
  3. Adaptive Learning Rate:

    • The effective step size is automatically adjusted based on the coupled moments
    • Improves handling of varying gradient magnitudes
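
To see the effect of the coupling factor concretely, the plain-Python snippet below (illustrative, not repository code) runs the recurrences from the methodology on a constant gradient g = 1 for 100 steps: larger λ inflates v and damps the step, roughly by a factor of 1/√(1+λ).

import math

g = 1.0
for lam in (0.0, 0.1, 1.0):
    m = v = 0.0
    for t in range(1, 101):  # 100 identical gradients
        m = 0.9 * m + 0.1 * g
        v = 0.999 * v + 0.001 * (g ** 2 + lam * m ** 2)
    m_hat = m / (1 - 0.9 ** 100)   # bias correction
    v_hat = v / (1 - 0.999 ** 100)
    step = 1e-3 * m_hat / (math.sqrt(v_hat) + 1e-8)
    print(f"lambda={lam}: step = {step:.2e}")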

4. Advantages

  1. Improved Stability:

    • The coupling between moments provides more stable updates
    • Helps prevent oscillations in optimization
  2. Better Embeddings:

    • Particularly effective for embedding learning tasks
    • Improved convergence in high-dimensional spaces
  3. Adaptive Momentum:

    • The coupling mechanism allows for better adaptation to local geometry
    • More effective handling of varying curvature

Hardware Acceleration

GPU Support

The implementation automatically detects and uses available NVIDIA GPUs. You can (see the sketch after this list):

  • Use all GPUs with DataParallel (automatic)
  • Specify a particular GPU using the --gpu-id argument
  • Fall back to CPU if no GPU is available
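
A hedged sketch of the device-selection behavior described above; select_device is an illustrative helper, not part of the package, and the example script's actual code may differ.

import torch

def select_device(gpu_id=None):
    # --gpu-id -1 (or no CUDA) forces CPU; a specific id selects that GPU;
    # otherwise fall back to the first available GPU.
    if gpu_id == -1 or not torch.cuda.is_available():
        return torch.device("cpu")
    if gpu_id is not None:
        return torch.device(f"cuda:{gpu_id}")
    return torch.device("cuda")

model = torch.nn.Linear(784, 10).to(select_device())
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)  # spread batches across all GPUs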

CPU Optimization

When no GPU is available, the implementation automatically does the following (see the sketch after this list):

  • Uses all available CPU cores for data loading
  • Optimizes memory layout for CPU computation
  • Configures PyTorch for optimal multi-threading
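
Illustrative only: the kind of CPU configuration the list describes, using standard PyTorch calls (the repository's actual settings may differ).

import os
import torch
from torch.utils.data import DataLoader, TensorDataset

n_cores = os.cpu_count() or 1
torch.set_num_threads(n_cores)  # intra-op parallelism across all cores

# One data-loading worker per core; the dataset here is a toy placeholder
data = TensorDataset(torch.randn(256, 784), torch.randint(0, 10, (256,)))
loader = DataLoader(data, batch_size=64, num_workers=n_cores)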

Installation

pip install -r requirements.txt

Usage

from coupled_adam import CoupledAdam

# Initialize your model
model = YourModel()

# Create the optimizer
optimizer = CoupledAdam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
    coupling_factor=0.1  # The coupling factor between first and second moments
)

# Training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        loss = compute_loss(model, batch)  # your forward pass + loss (placeholder)
        loss.backward()
        optimizer.step()

Configuration

The optimizer can be configured using the following parameters; a quick sanity check follows the list:

  • lr: Learning rate (default: 1e-3)
  • betas: Coefficients for computing running averages of gradient and its square (default: (0.9, 0.999))
  • eps: Term added to the denominator to improve numerical stability (default: 1e-8)
  • weight_decay: Weight decay coefficient (default: 0)
  • coupling_factor: Factor controlling the coupling between first and second moments (default: 0.1)
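
Per the update equations above, coupling_factor=0 removes the coupling term, so the optimizer should reduce to standard Adam; that makes a convenient sanity check (import path as in Usage):

import torch
from coupled_adam import CoupledAdam

params = [torch.nn.Parameter(torch.randn(8))]
baseline = CoupledAdam(params, lr=1e-3, coupling_factor=0.0)  # ≈ standard Adam
coupled = CoupledAdam(params, lr=1e-3, coupling_factor=0.1)   # default coupling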

Testing

To run the test suite:

python -m pytest tests/

Citation

If you use this implementation in your research, please cite:

@article{stollenwerk2025better,
  title={Better Embeddings with Coupled Adam},
  author={Stollenwerk, Felix and Stollenwerk, Tobias},
  journal={arXiv preprint arXiv:2502.08441},
  year={2025}
}

License

MIT License

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Running the Example

The MNIST example script supports both GPU and CPU execution with various configuration options:

python examples/train_mnist.py [OPTIONS]

Options:
  --batch-size INT        Batch size for training (default: 64)
  --test-batch-size INT   Batch size for testing (default: 1000)
  --epochs INT            Number of epochs to train (default: 10)
  --lr FLOAT              Learning rate (default: 0.001)
  --gpu-id INT            Specific GPU to use (default: None, uses first available)
  --seed INT              Random seed (default: 42)

Example Usage:

  1. Use all available GPUs:
python examples/train_mnist.py
  2. Use a specific GPU:
python examples/train_mnist.py --gpu-id 1
  3. Force CPU usage with optimized settings:
python examples/train_mnist.py --gpu-id -1
