A privacy-preserving next word prediction system using federated learning, inspired by Google's Gboard implementation.
This project implements a federated learning system for next word prediction, where multiple clients can train a shared model without sharing their raw data. The system uses LSTM-based neural networks to predict the next word in a sequence while maintaining user privacy.
- Privacy-Preserving Training: Train models without sharing raw text data
- Federated Learning: Collaborative model training across multiple clients
- LSTM Architecture: Advanced neural network for sequence prediction
- Web Interface: Flask-based UI for training and prediction
- Real-time Predictions: Get the top 5 most likely next words
- Training Visualization: Track and visualize training progress
- Data Processing: Text tokenization, vocabulary building, and sequence creation
- Model Architecture: LSTM with embedding layer for word representation
- Federated Learning: Weight aggregation and model distribution
- Web Application: User interface for training and prediction
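The data-processing step above (tokenization, vocabulary building, and sequence creation) can be sketched as a sliding window over token ids. `make_sequences` is a hypothetical helper for illustration, not the actual function in `data_processor.py`:

```python
import numpy as np

def make_sequences(token_ids, seq_len=10):
    """Slide a window over token ids: each input X is seq_len tokens,
    the target y is the token that immediately follows the window."""
    X, y = [], []
    for i in range(len(token_ids) - seq_len):
        X.append(token_ids[i:i + seq_len])
        y.append(token_ids[i + seq_len])
    return np.array(X), np.array(y)
```

With the default sequence length of 10, a text of N tokens yields N − 10 training pairs.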
- Clone the repository:

  ```bash
  git clone https://github.com/AdityaC784/PrivacyType.git
  cd PrivacyType
  ```

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```

- Download NLTK data:

  ```python
  import nltk
  nltk.download('punkt')
  ```
- Prepare Data:
  - Place training data in CSV format in the `client_data` directory
  - Each CSV should have a 'Text' or 'text' column
- Train Model:
  - Run the web application: `python app.py`
  - Navigate to the training page
  - Configure training parameters (clients, rounds, epochs)
  - Start training
- Make Predictions:
  - Use the prediction interface to enter text
  - Get the top 5 most likely next words
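Under the hood, the top-5 list can be derived by sorting the model's output distribution. A sketch, assuming the model emits a softmax vector over the vocabulary and an `index_to_word` lookup exists (both names are assumptions, not the repo's actual identifiers):

```python
import numpy as np

def top_k_words(probs, index_to_word, k=5):
    """Return the k highest-probability words with their scores."""
    top = np.argsort(probs)[::-1][:k]  # indices sorted by descending probability
    return [(index_to_word[i], float(probs[i])) for i in top]
```

The returned (word, probability) pairs are what the prediction interface would display.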
```
privacytype/
├── app.py                 # Flask web application
├── data_processor.py      # Text processing and sequence creation
├── lstm_model.py          # LSTM model implementation
├── federated_learning.py  # Federated learning logic
├── main.py                # Main training script
├── run_federated.py       # Federated learning runner
├── client_data/           # Directory for client datasets
├── models/                # Saved models and metadata
└── visualizations/        # Training visualization outputs
```
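The weight aggregation in `federated_learning.py` is, per the project description, a federated averaging step. A minimal sketch of that idea, assuming each client returns a list of NumPy weight arrays (as Keras' `get_weights()` does) — this is an illustrative reimplementation, not the repo's code:

```python
import numpy as np

def federated_average(client_weights, client_sizes):
    """FedAvg: average each layer's weights across clients,
    weighted by each client's dataset size."""
    total = float(sum(client_sizes))
    averaged = []
    for layer_weights in zip(*client_weights):  # iterate layer by layer
        averaged.append(
            sum(w * (n / total) for w, n in zip(layer_weights, client_sizes))
        )
    return averaged
```

After each round, the averaged weights would be distributed back to every client (in Keras terms, `model.set_weights(averaged)`), so only weights, never raw text, leave a client.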
The following files and directories should not be uploaded to version control:
- Model files and directories:
  - `models/` directory (contains trained models and weights)
  - `*.h5` files (model weight files)
  - `*.pkl` files (pickle files with model state)
- Generated data:
  - `visualizations/` directory (contains generated plots)
  - `training_logs.json` (runtime logs)
- Environment and system files:
  - `__pycache__/` directories
  - `*.pyc` files
  - `.DS_Store` (Mac system files)
  - `venv/` or `env/` (virtual environment directories)
- Large data files:
  - `client_data/` directory (contains your training datasets)
  - Any large CSV files
Instead, include sample data files or instructions for obtaining them in the repository.
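The exclusions above translate into a `.gitignore` along these lines (a suggested starting point, not necessarily the file shipped with the repository):

```gitignore
models/
visualizations/
training_logs.json
client_data/
__pycache__/
*.pyc
*.h5
*.pkl
.DS_Store
venv/
env/
```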
- Vocabulary Size: 11,418 words (configurable)
- Embedding Dimension: 128
- LSTM Units: 256
- Sequence Length: 10 tokens
- Dropout Rate: 0.2
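With these defaults, model size is dominated by the embedding and output layers. A back-of-envelope parameter count, assuming a single LSTM layer feeding a dense softmax over the vocabulary (consistent with the description above, but an assumption about the exact topology):

```python
vocab, embed, units = 11418, 128, 256

# Embedding table: one embed-dim vector per vocabulary word
embedding_params = vocab * embed

# LSTM: 4 gates, each with input kernel, recurrent kernel, and bias
lstm_params = 4 * (units * (units + embed) + units)

# Dense softmax over the full vocabulary (weights + biases)
dense_params = units * vocab + vocab

total = embedding_params + lstm_params + dense_params
print(total)  # roughly 4.8M parameters
```

Because both the embedding and output layers scale with vocabulary size, shrinking the configurable vocabulary is the most direct way to reduce the weight payload exchanged each federated round.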
Contributions are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License - see the LICENSE file for details.
- Inspired by Google's Gboard implementation of federated learning
- Uses TensorFlow and Keras for deep learning
- Flask for web interface