Context-Aware Personalized Federated Learning framework for LLMs via Variational Bayesian LoRA, RFF-based personalization, and Sparse Optimization.
This repository contains the official PyTorch implementation of CA-PFL. CA-PFL is a novel framework designed to address the challenges of deploying Large Language Models (LLMs) in heterogeneous federated environments. It dynamically generates personalization strategies and optimizes communication efficiency through variational sparse control and random Fourier features.
- Client-Adaptive Sparsity: Dynamically assigns LoRA rank based on local data feature distributions to establish sparsity.
- Variational Bayesian LoRA: Utilizes variational Bayesian priors to estimate optimal rank configurations compatible with federated aggregation.
- RFF-based Personalization: Reduces communication cost and enables efficient personalization through Random Fourier Feature (RFF) projection.
- Dynamic Sparsification & Error Compensation: Implements dynamic pruning based on learned $\kappa$ values and maintains an error-compensation buffer to ensure convergence.
- Communication Efficiency: Reduces communication overhead by approximately 78% compared to baselines while maintaining high accuracy.
The codebase is organized as follows:
```
.
├── main.py                            # Entry point: argument parsing, data partitioning, and federated training loop
├── capfl_client_training.py           # Client-side logic: local training, loss calculation, and sparse updates
├── capfl_server_aggregation.py        # Server-side logic: aggregating sparse model updates and variational priors
├── capfl_integrated_model.py          # Model definition: wraps Llama with RFF and variational heads
├── capfl_variational.py               # Variational inference: KL divergence calculation and kappa sampling
├── capfl_dynamic_sparsification.py    # Sparsification logic: dynamic pruning and error-compensation buffer
├── capfl_rff_personalization.py       # Personalization: Random Fourier Features (RFF) projection module
└── capfl_final_test.py                # Evaluation script for accuracy (logits method)
```
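For orientation, the server-side aggregation of sparse updates handled by `capfl_server_aggregation.py` can be pictured roughly as in the sketch below. The `{name: (indices, values)}` representation and the `aggregate_sparse_updates` helper are illustrative assumptions, not the repository's actual API.

```python
# Minimal, self-contained sketch (assumed representation, not the repo API):
# each client sends its retained LoRA-delta entries as {name: (indices, values)},
# and the server averages them back into dense global deltas.
import torch

def aggregate_sparse_updates(param_shapes, client_updates):
    """Average per-client sparse updates into dense global deltas."""
    aggregated = {name: torch.zeros(shape) for name, shape in param_shapes.items()}
    for update in client_updates:
        for name, (idx, vals) in update.items():
            flat = aggregated[name].view(-1)
            flat[idx] += vals          # scatter the client's retained entries
    for name in aggregated:
        aggregated[name] /= len(client_updates)
    return aggregated

# Toy usage: two clients, one LoRA matrix of shape (4, 4)
shapes = {"lora_A": (4, 4)}
c1 = {"lora_A": (torch.tensor([0, 5]), torch.tensor([0.2, -0.1]))}
c2 = {"lora_A": (torch.tensor([5, 9]), torch.tensor([0.4, 0.3]))}
print(aggregate_sparse_updates(shapes, [c1, c2])["lora_A"])
```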
- Clone the repository
```bash
git clone https://github.com/YanDang/CA-PFL.git
cd CA-PFL
```
- Install dependencies (using a virtual environment is recommended)
```bash
pip install -r requirements.txt
```
To run the federated learning simulation with default settings (Llama-3.2-3B on OpenBookQA):
```bash
python main.py \
    --model_path "meta-llama/Llama-3.2-3B" \
    --num_clients 10 \
    --server_epochs 10 \
    --lora_rank 8 \
    --alpha_s 0.01 \
    --alpha_p 0.01 \
    --rff_output_dim 256
```
| Argument | Default | Description |
|---|---|---|
| --model_path | meta-llama/Llama-3.2-3B | Path to the base LLM |
| --alpha_s | 0.01 | Weight for Smooth-L1 regularization (controls sparsity) |
| --alpha_p | 0.01 | Weight for KL divergence loss (variational constraint) |
| --alpha_f | 0.1 | Weight for FedProx constraint term |
| --rff_output_dim | 256 | Dimension for Random Fourier Features projection |
| --base_sparsity_ratio | 0.15 | Base retention ratio for dynamic sparsification |
| --num_clients | 10 | Total number of clients in the federated system |
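As a rough illustration of how the three `alpha_*` weights enter the client objective, the sketch below combines a task loss with the Smooth-L1 sparsity regularizer, the KL term, and a FedProx proximal term; the exact composition used in `capfl_client_training.py` may differ.

```python
# Illustrative composition of the client loss (assumed form, not the repo's exact code).
import torch
import torch.nn.functional as F

def client_loss(task_loss, kappa, kl_term, local_params, global_params,
                alpha_s=0.01, alpha_p=0.01, alpha_f=0.1):
    # --alpha_s: Smooth-L1 regularization pushing the sparsity parameters toward zero
    sparsity_reg = F.smooth_l1_loss(kappa, torch.zeros_like(kappa))
    # --alpha_f: FedProx proximal term keeping local weights close to the global model
    prox = sum((lp - gp.detach()).pow(2).sum() for lp, gp in zip(local_params, global_params))
    # --alpha_p: KL divergence between the variational posterior and the prior
    return task_loss + alpha_s * sparsity_reg + alpha_p * kl_term + alpha_f * prox
```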
CA-PFL demonstrates superior performance on heterogeneous non-IID data setups compared to state-of-the-art baselines like FedEx-LoRA and LoRA-FAIR.
| Model | Method | Accuracy (Avg) | Comm. Cost (MB/round) |
|---|---|---|---|
| Llama-3.2 3B | FedEx-LoRA | 76.59% | 41.29 |
| Llama-3.2 3B | CA-PFL (Ours) | 77.77% | 7.19 |
| Mistral-7B | FedEx-LoRA | 78.66% | 95.6 |
| Mistral-7B | CA-PFL (Ours) | 79.89% | 21.4 |
We model LoRA rank selection as a variational inference problem. The model learns a distribution over the sparsity parameter $\kappa$ for each LoRA module and constrains it with a KL divergence term (weighted by `--alpha_p`) against the prior aggregated on the server.
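A minimal sketch of what this module might look like, assuming a Gaussian posterior over $\kappa$ trained with the reparameterization trick; the actual parameterization lives in `capfl_variational.py`.

```python
# Illustrative sketch: Gaussian posterior over kappa with reparameterized sampling
# and a closed-form KL against a Gaussian prior (assumed form, not the repo's code).
import torch
import torch.nn as nn

class VariationalKappa(nn.Module):
    """Learns q(kappa) = N(mu, sigma^2) per LoRA module and samples kappa in (0, 1)."""
    def __init__(self, num_modules, prior_mu=0.0, prior_logvar=0.0):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(num_modules))
        self.logvar = nn.Parameter(torch.zeros(num_modules))
        self.register_buffer("prior_mu", torch.full((num_modules,), prior_mu))
        self.register_buffer("prior_logvar", torch.full((num_modules,), prior_logvar))

    def sample(self):
        # Reparameterization: kappa = sigmoid(mu + sigma * eps)
        eps = torch.randn_like(self.mu)
        return torch.sigmoid(self.mu + torch.exp(0.5 * self.logvar) * eps)

    def kl_divergence(self):
        # KL( N(mu, sigma^2) || N(prior_mu, prior_sigma^2) ), summed over modules
        var, prior_var = self.logvar.exp(), self.prior_logvar.exp()
        kl = 0.5 * (self.prior_logvar - self.logvar
                    + (var + (self.mu - self.prior_mu) ** 2) / prior_var - 1.0)
        return kl.sum()
```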
To handle the "cold start" problem and enable personalization on edge devices, we use Random Fourier Features to project hidden states $h$ into a low-dimensional manifold via the standard RFF map

$$\phi(h) = \sqrt{\tfrac{2}{D}}\,\cos(\Omega h + b), \qquad \Omega_{ij} \sim \mathcal{N}(0, \gamma^2), \quad b_j \sim \mathcal{U}[0, 2\pi],$$

where $D$ is the projection dimension set by `--rff_output_dim`.
This acts as a lightweight "Personalized Modulator" that generates the client-specific adaptation on-device while keeping the communicated payload small.
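A minimal sketch of the RFF projection under the assumptions above (fixed random projection, cosine features); see `capfl_rff_personalization.py` for the actual module.

```python
# Illustrative RFF projection (assumed form): the random projection is frozen,
# so only the lightweight head placed on top of it needs to be trained per client.
import math
import torch
import torch.nn as nn

class RFFProjection(nn.Module):
    """Projects hidden states to a low-dimensional RFF feature space."""
    def __init__(self, hidden_dim, rff_output_dim=256, gamma=1.0):
        super().__init__()
        # Fixed random projection: Omega ~ N(0, gamma^2), b ~ U[0, 2*pi]
        self.register_buffer("omega", torch.randn(hidden_dim, rff_output_dim) * gamma)
        self.register_buffer("bias", torch.rand(rff_output_dim) * 2 * math.pi)
        self.scale = math.sqrt(2.0 / rff_output_dim)

    def forward(self, hidden_states):
        # phi(h) = sqrt(2/D) * cos(h @ Omega + b)
        return self.scale * torch.cos(hidden_states @ self.omega + self.bias)

# Toy usage: project a batch of hidden states from the base model
h = torch.randn(4, 3072)              # e.g. Llama-3.2-3B hidden size
phi = RFFProjection(3072, rff_output_dim=256)(h)
print(phi.shape)                      # torch.Size([4, 256])
```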
Parameters are updated based on the learned importance score: entries below the dynamic threshold derived from $\kappa$ and `--base_sparsity_ratio` are pruned before transmission, and the pruned residual is accumulated in the error-compensation buffer and re-applied in later rounds to preserve convergence.
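The sketch below illustrates error-compensated top-k pruning under these assumptions, using plain magnitude as a stand-in for the learned importance score; see `capfl_dynamic_sparsification.py` for the actual logic.

```python
# Illustrative error-compensated pruning (assumed mechanics, not the repo's exact code).
import torch

class ErrorCompensatedPruner:
    """Keeps the top-k entries of each update and buffers the dropped residual."""
    def __init__(self, retention_ratio=0.15):
        self.retention_ratio = retention_ratio
        self.error_buffer = {}

    def sparsify(self, name, update):
        # Re-add the residual accumulated from previous rounds
        update = update + self.error_buffer.get(name, torch.zeros_like(update))
        k = max(1, int(update.numel() * self.retention_ratio))
        # Magnitude stands in here for the learned importance score
        _, idx = update.abs().view(-1).topk(k)
        mask = torch.zeros_like(update.view(-1))
        mask[idx] = 1.0
        sparse_update = update * mask.view_as(update)
        # Whatever was dropped is remembered for the next round
        self.error_buffer[name] = update - sparse_update
        return sparse_update

pruner = ErrorCompensatedPruner(retention_ratio=0.15)
delta = torch.randn(8, 8)
print(pruner.sparsify("lora_B", delta).count_nonzero())  # ~15% of 64 entries retained
```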
If you find this code useful for your research, please cite our paper.