A machine learning system for stock market prediction built on a Transformer architecture. The project includes data collection, preprocessing, model training, and evaluation components.
This project consists of three main components:
- Data Collection (collect.py)
- Model Architecture (model.py)
- Model Testing (model_test.py)
- Comprehensive stock data collection including:
  - Basic stock information (OHLCV data)
  - Financial metrics
  - Historical financials (quarterly and annual)
  - Options data
  - Analyst ratings
  - Industry data
  - Peer comparison data
- Advanced Transformer-based prediction model
- Cross-validation testing framework
- Performance visualization tools
The StockDataCollector class handles all data collection operations:
- Basic stock information
- Historical financial data
- Options data
- Analyst ratings
- Industry data
- Peer comparison
Key methods:
collector = StockDataCollector(symbol, mongodb_uri)
collector.collect_all_data() # Collects all available data
Contains the core prediction model and supporting classes:
- TransformerPredictor: Main model architecture using Transformer
- StockDataset: Custom dataset class for stock data
- FocalLoss: Custom loss function for imbalanced data
- StockPredictor: High-level class for model training and prediction
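The custom dataset class listed above is not shown in this README. As a rough illustration only, the sketch below slices a feature table into fixed-length windows and attaches a binary up/down label to each. The function name make_windows, the seq_len parameter, and the labeling rule (next close above the last close in the window) are all assumptions made for this example, not the project's actual implementation.

```python
from typing import List, Tuple


def make_windows(
    features: List[List[float]],
    closes: List[float],
    seq_len: int,
) -> Tuple[List[List[List[float]]], List[int]]:
    """Build overlapping windows of seq_len rows with next-step up/down labels.

    Hypothetical helper: the labeling rule is an assumption for illustration.
    """
    if len(features) != len(closes):
        raise ValueError("features and closes must have the same length")
    if seq_len < 1:
        raise ValueError("seq_len must be at least 1")

    windows: List[List[List[float]]] = []
    labels: List[int] = []
    # A window covers rows [end - seq_len, end); row `end` supplies the label,
    # so the label never leaks into the window it belongs to.
    for end in range(seq_len, len(features)):
        windows.append(features[end - seq_len:end])
        labels.append(1 if closes[end] > closes[end - 1] else 0)
    return windows, labels


if __name__ == "__main__":
    rows = [[float(i)] for i in range(6)]
    closes = [1.0, 2.0, 1.0, 3.0, 3.0, 4.0]
    windows, labels = make_windows(rows, closes, seq_len=3)
    print(len(windows), labels)
```

Each window has shape (seq_len, number of features), which is the kind of input a sequence model consumes one sample at a time.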
Key features:
- Multi-head attention mechanism
- Focal loss for handling class imbalance
- Technical indicators (MA5, MA20, RSI, MACD)
- Sequence-based prediction
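To make the focal loss idea concrete, here is a minimal pure-Python sketch of the binary form, FL = -alpha_t * (1 - p_t)^gamma * log(p_t). It is not the project's FocalLoss class, whose code is not shown here. The function name and the gamma=2.0 and alpha=0.25 defaults are illustrative assumptions (they are commonly used values, not settings confirmed by this project).

```python
import math
from typing import Sequence


def binary_focal_loss(
    probs: Sequence[float],
    targets: Sequence[int],
    gamma: float = 2.0,   # assumed default, controls how strongly easy samples are down-weighted
    alpha: float = 0.25,  # assumed default, weight applied to the positive class
    eps: float = 1e-7,
) -> float:
    """Mean binary focal loss over predicted positive-class probabilities."""
    if len(probs) != len(targets):
        raise ValueError("probs and targets must have the same length")
    if not probs:
        raise ValueError("at least one sample is required")

    total = 0.0
    for p, y in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)          # keep log() finite
        p_t = p if y == 1 else 1.0 - p           # probability given to the true class
        alpha_t = alpha if y == 1 else 1.0 - alpha
        total += -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)
    return total / len(probs)


if __name__ == "__main__":
    # A confident correct prediction contributes far less than an uncertain one.
    print(binary_focal_loss([0.9], [1]), binary_focal_loss([0.6], [1]))
```

With gamma set to 0 the modulating factor disappears and the expression reduces to alpha-weighted cross-entropy, which is a quick way to check an implementation.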
Comprehensive testing framework including:
- Time series cross-validation
- Performance metrics calculation
- Confusion matrix visualization
- Metrics plotting across folds
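Time series cross-validation keeps every training index earlier than every test index. The sketch below shows one common scheme, an expanding training window with equal-sized test blocks, similar in spirit to scikit-learn's TimeSeriesSplit. Whether model_test.py uses this exact scheme is an assumption; the function name time_series_splits is hypothetical.

```python
from typing import List, Tuple


def time_series_splits(n_samples: int, n_splits: int) -> List[Tuple[List[int], List[int]]]:
    """Return (train_indices, test_indices) pairs with an expanding training window."""
    if n_splits < 1:
        raise ValueError("n_splits must be at least 1")
    if n_samples < n_splits + 1:
        raise ValueError("not enough samples for the requested number of splits")

    test_size = n_samples // (n_splits + 1)
    splits = []
    for k in range(n_splits):
        # Test blocks are laid out back to back and end at the last sample.
        test_start = n_samples - (n_splits - k) * test_size
        train_idx = list(range(test_start))                          # everything before the test block
        test_idx = list(range(test_start, test_start + test_size))   # the next contiguous block
        splits.append((train_idx, test_idx))
    return splits


if __name__ == "__main__":
    for fold, (train_idx, test_idx) in enumerate(time_series_splits(1620, 5), start=1):
        print(fold, len(train_idx), test_idx[0], test_idx[-1])
```

With 1,620 samples and 5 splits, each test block under this scheme holds 270 samples, and the training window grows from 270 to 1,350 samples.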
- Set up MongoDB:
# Start MongoDB service
mongod --dbpath <your_db_path>
- Collect data:
from collect import StockDataCollector
collector = StockDataCollector("AAPL", "mongodb://localhost:27017/")
collector.collect_all_data()
- Train the model:
from model import StockPredictor
# Define parameters
model_params = {
    'input_dim': 9,
    'num_heads': 4,
    'num_layers': 2,
    'dropout': 0.1
}
training_params = {
    'learning_rate': 0.0005,
    'batch_size': 64,
    'num_epochs': 50
}
# train_loader and val_loader must already exist; they are not created in this snippet
predictor = StockPredictor(model_params, training_params)
predictor.train(train_loader, val_loader)
- Test the model:
from model_test import ModelTester
tester = ModelTester("AAPL", n_splits=5)
fold_metrics, avg_metrics = tester.run_cross_validation()
The Transformer-based predictor includes:
- Input embedding layer
- Multi-head attention layers
- Position-wise feed-forward networks
- Dropout for regularization
- Final classification layer
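The core operation inside each attention layer is scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V. The pure-Python sketch below shows a single head on plain lists so the arithmetic is visible. It is an illustration of the general technique, not code from model.py; the learned projections, multiple heads, and masking that a full layer would include are omitted.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(queries: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Single-head attention: each query returns a weighted average of the value rows."""
    if len(keys) != len(values):
        raise ValueError("keys and values must have the same number of rows")
    d_k = len(keys[0])
    out_dim = len(values[0])
    output = []
    for q in queries:
        # Similarity of this query to every key, scaled by sqrt(d_k).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)
        # Weighted sum of the value rows.
        output.append([sum(w * v[j] for w, v in zip(weights, values)) for j in range(out_dim)])
    return output


if __name__ == "__main__":
    q = [[1.0, 0.0]]
    k = [[1.0, 1.0], [1.0, 1.0]]
    v = [[0.0, 2.0], [4.0, 6.0]]
    print(scaled_dot_product_attention(q, k, v))
```

When all keys are identical, as in the usage example, the weights are uniform and the output is the plain mean of the value rows.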
The system processes the following features:
- Price data (Open, High, Low, Close)
- Volume
- Moving averages (5-day and 20-day)
- RSI (Relative Strength Index)
- MACD (Moving Average Convergence Divergence)
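The indicators in this list can be derived from the closing price alone. The sketch below is one plausible way to compute them in pure Python. The window lengths other than 5 and 20 (RSI period 14, MACD spans 12 and 26) are conventional defaults assumed here, and the RSI uses simple averages of gains and losses; the project's own calculation may use different periods or Wilder's smoothing.

```python
from typing import List, Optional


def sma(values: List[float], window: int) -> List[Optional[float]]:
    """Simple moving average; None until a full window is available."""
    out: List[Optional[float]] = [None] * len(values)
    for i in range(window - 1, len(values)):
        out[i] = sum(values[i - window + 1:i + 1]) / window
    return out


def ema(values: List[float], span: int) -> List[float]:
    """Exponential moving average seeded with the first value."""
    if not values:
        return []
    alpha = 2.0 / (span + 1)
    out = [values[0]]
    for v in values[1:]:
        out.append(alpha * v + (1.0 - alpha) * out[-1])
    return out


def macd(closes: List[float], fast: int = 12, slow: int = 26) -> List[float]:
    """MACD line: fast EMA minus slow EMA (spans are assumed defaults)."""
    return [f - s for f, s in zip(ema(closes, fast), ema(closes, slow))]


def rsi(closes: List[float], period: int = 14) -> List[Optional[float]]:
    """RSI from simple averages of gains and losses over `period` changes."""
    out: List[Optional[float]] = [None] * len(closes)
    changes = [closes[i] - closes[i - 1] for i in range(1, len(closes))]
    for i in range(period, len(closes)):
        window = changes[i - period:i]  # the `period` changes ending at closes[i]
        avg_gain = sum(c for c in window if c > 0) / period
        avg_loss = sum(-c for c in window if c < 0) / period
        if avg_loss == 0:
            out[i] = 100.0  # no losses in the window
        else:
            out[i] = 100.0 - 100.0 / (1.0 + avg_gain / avg_loss)
    return out


if __name__ == "__main__":
    closes = [float(c) for c in range(100, 140)]
    print(sma(closes, 5)[-1], sma(closes, 20)[-1], rsi(closes)[-1], macd(closes)[-1])
```

Together with open, high, low, close, and volume, these four indicator columns (MA5, MA20, RSI, MACD) give nine features per time step, which matches the input_dim of 9 used in the training example.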
The model is evaluated using:
- Accuracy
- Precision
- Recall
- F1 Score
- Confusion matrices
- Cross-validation performance plots
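For reference, the four scalar metrics and the confusion matrix all follow from the counts of true/false positives and negatives. The helper below is a minimal pure-Python sketch for binary labels; the name binary_metrics and the returned dictionary layout are assumptions for illustration, not the interface of model_test.py.

```python
from typing import Dict, List, Union

Metric = Union[float, List[List[int]]]


def binary_metrics(y_true: List[int], y_pred: List[int]) -> Dict[str, Metric]:
    """Accuracy, precision, recall, F1, and confusion matrix for 0/1 labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("at least one sample is required")

    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {
        "accuracy": (tp + tn) / len(y_true),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        # Rows are true classes (0, 1); columns are predicted classes (0, 1).
        "confusion_matrix": [[tn, fp], [fn, tp]],
    }


if __name__ == "__main__":
    print(binary_metrics([1, 0, 1, 1, 0, 0], [1, 0, 0, 1, 1, 0]))
```

Undefined ratios (for example, precision when nothing is predicted positive) are reported as 0.0 here, which is a design choice of this sketch.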
The model was trained and evaluated on stock market data with the following results:
Train Loss: 0.0401
Train Accuracy: 62.99%
Validation Loss: 0.0573
Validation Accuracy: 52.49%
- Total samples in dataset: 1,620
Sample fold (Fold 5) metrics:
Accuracy: 0.6019 (60.19%)
Precision: 0.7137 (71.37%)
Recall: 0.5651 (56.51%)
F1 Score: 0.6308 (63.08%)
Average metrics across all folds:
Accuracy: 0.5784 (57.84%)
Precision: 0.6287 (62.87%)
Recall: 0.6576 (65.76%)
F1 Score: 0.6348 (63.48%)
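As a quick sanity check on the reported numbers, F1 is the harmonic mean of precision and recall, so it can be recomputed from the other two values. The short sketch below does this for the figures above; the helper name f1_from is hypothetical.

```python
def f1_from(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


if __name__ == "__main__":
    fold5 = f1_from(0.7137, 0.5651)      # Fold 5 precision and recall
    averaged = f1_from(0.6287, 0.6576)   # cross-fold average precision and recall
    print(round(fold5, 4), round(averaged, 4))
```

The Fold 5 values reproduce the reported F1 of about 0.6308. The averaged precision and recall give about 0.6428 rather than the reported 0.6348, which is expected if the average F1 is the mean of the per-fold F1 scores rather than the F1 of the averaged precision and recall.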
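These results show that:
- Training accuracy (62.99%) is about 10 points above validation accuracy (52.49%), which points to some overfitting
- In Fold 5, precision (71.37%) is notably higher than recall (56.51%), so the model is conservative in that fold; averaged across folds, recall (65.76%) is slightly higher than precision (62.87%)
- Fold 5 accuracy (60.19%) is within about 2.5 points of the cross-fold average (57.84%); results for the other folds are not listed here
- Average accuracy of 57.84% is modestly above the 50% expected from random guessing on a balanced binary task, with an average F1 score of ~63%; how this compares with a majority-class baseline depends on the class balance, which is not reported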
- The model uses time series cross-validation to prevent future data leakage
- Focal Loss is implemented to handle class imbalance
- The system includes extensive error handling and logging
- Technical indicators are calculated automatically during data preparation
- Fork the repository
- Create your feature branch
- Commit your changes
- Push to the branch
- Create a new Pull Request
This project is for educational purposes only. Trading stocks carries significant risks, and past performance does not guarantee future results. Always conduct thorough research and consider consulting with a financial advisor before making investment decisions.