This project demonstrates a complete machine learning pipeline for predicting the likelihood of stroke in patients using clinical and demographic data. It showcases key data science skills including data exploration, preprocessing, handling imbalanced data, model training, evaluation, threshold tuning, and model serialisation — all implemented in Python with best practices.
The primary goal is to develop and evaluate models that can predict stroke cases despite the challenges of a highly imbalanced dataset.
The dataset used in this project is sourced from Kaggle:
Stroke Prediction Dataset
It contains demographic, lifestyle, and medical information along with a binary label indicating stroke occurrence.
```
stroke-prediction-ml/
├── data/
│   └── stroke_data.csv                    # Raw dataset file
├── models/
│   └── final_model_histgbc.joblib         # Serialised trained model and threshold
├── notebooks/
│   └── stroke_prediction_pipeline.ipynb   # Jupyter notebook with full analysis and modelling
├── requirements.txt                       # Project dependencies
├── README.md                              # This file
└── .gitignore                             # Specifies files/folders to ignore in Git
```
- Loaded and inspected the dataset for missing values and data types.
- Visualised feature distributions and relationships.
- Identified class imbalance in the target variable (`stroke`).
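The inspection steps above can be sketched with pandas; the toy frame below stands in for the Kaggle CSV, borrowing its `bmi` and `stroke` column names:

```python
import pandas as pd

# Toy data standing in for the Kaggle stroke dataset.
df = pd.DataFrame({
    "bmi": [22.5, None, 31.0, 27.4],
    "stroke": [0, 0, 1, 0],
})

missing = df.isna().sum()                            # missing values per column
balance = df["stroke"].value_counts(normalize=True)  # class proportions
print(missing["bmi"], balance.to_dict())
```

The same two checks on the real data reveal the missing `bmi` values and the heavy skew toward the negative class.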
- Handled missing data and categorical variables.
- Created meaningful feature transformations informed by domain knowledge.
- Investigated imbalanced class distribution (stroke cases are the minority).
- Applied random oversampling to the training data to augment the minority class.
- Discussed alternative sampling techniques (e.g., SMOTE), but faced compatibility constraints.
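Random oversampling needs nothing beyond scikit-learn; a minimal sketch with illustrative data (the notebook applies this to the training split only, never the test set):

```python
import pandas as pd
from sklearn.utils import resample

# Illustrative training frame; the real pipeline uses the preprocessed features.
train = pd.DataFrame({
    "age": [60, 45, 70, 50, 80, 65],
    "stroke": [0, 0, 0, 0, 1, 1],
})

majority = train[train["stroke"] == 0]
minority = train[train["stroke"] == 1]

# Sample the minority class with replacement up to the majority size.
minority_up = resample(minority, replace=True,
                       n_samples=len(majority), random_state=42)
balanced = pd.concat([majority, minority_up])
print(balanced["stroke"].value_counts().to_dict())
```

After resampling, both classes contribute equally to training, at the cost of duplicated minority rows.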
- Trained multiple models on both the original imbalanced and the oversampled datasets:
- Logistic Regression
- Random Forest
- HistGradientBoostingClassifier (HGB)
- LightGBM
- Evaluated models using:
- Precision, Recall, F1-score per class
- Confusion matrices
- ROC AUC score
- Precision-Recall curves
- Found that the HistGradientBoostingClassifier trained on the oversampled data achieved the best balance between precision and recall, especially for the minority class.
- Performed threshold tuning on predicted probabilities to optimise F1-score.
- Selected an optimal threshold of 0.22 instead of the default 0.5 to improve detection of stroke cases.
- Demonstrated the trade-off between false positives and false negatives using classification reports and confusion matrices.
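A minimal sketch of the threshold search, using `precision_recall_curve` on illustrative scores (the notebook runs this on held-out predicted probabilities):

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Illustrative labels and predicted probabilities.
y_true = np.array([0, 0, 0, 0, 1, 1, 0, 1, 0, 0])
y_prob = np.array([0.05, 0.1, 0.2, 0.3, 0.7, 0.25, 0.15, 0.8, 0.4, 0.1])

prec, rec, thresholds = precision_recall_curve(y_true, y_prob)
f1 = 2 * prec * rec / (prec + rec + 1e-12)  # guard against division by zero
best = thresholds[np.argmax(f1[:-1])]       # final curve point has no threshold
print(f"best threshold: {best:.2f}")
```

On the real stroke data this search landed at 0.22, well below the default 0.5, reflecting how few positive cases the model sees.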
- Retrained the best model on the full oversampled training dataset.
- Saved the model along with the optimised threshold using `joblib` for easy loading and inference.
- Demonstrated loading the saved model and threshold.
- Used the model to make predictions on the test set with the custom threshold.
- Recomputed evaluation metrics and visualisations to verify consistent performance.
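Bundling the model and its threshold into one `joblib` artefact keeps them in sync at inference time. A sketch with a small stand-in model (the notebook serialises the HistGradientBoostingClassifier; the dict keys here are illustrative):

```python
import os
import tempfile

import joblib
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Small stand-in model so the example runs quickly.
X, y = make_classification(n_samples=200, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X, y)

# Persist model and tuned threshold together.
path = os.path.join(tempfile.mkdtemp(), "final_model.joblib")
joblib.dump({"model": model, "threshold": 0.22}, path)

# Reload and predict with the custom threshold instead of the default 0.5.
bundle = joblib.load(path)
preds = (bundle["model"].predict_proba(X)[:, 1] >= bundle["threshold"]).astype(int)
print(preds.shape, bundle["threshold"])
```

Storing the threshold alongside the estimator avoids silently falling back to `predict()`'s default 0.5 cut-off in downstream code.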
- Python 3.10+
- Jupyter Notebook
- NumPy
- pandas
- matplotlib
- seaborn
- scikit-learn
- joblib
- LightGBM
- VSCode for development
- Clone the repository:

  ```
  git clone <your-repo-url>
  cd stroke-prediction-ml
  ```
- Create and activate a virtual environment (recommended):

  ```
  python3 -m venv venv
  source venv/bin/activate   # Mac/Linux
  venv\Scripts\activate      # Windows
  ```
- Install dependencies:

  ```
  pip install -r requirements.txt
  ```
- Download the dataset from Kaggle and place the CSV file inside the `data/` folder.
- Launch Jupyter Notebook and run the pipeline:

  ```
  jupyter notebook notebooks/stroke_prediction_pipeline.ipynb
  ```
- Follow the notebook to explore the data, train models, tune thresholds, and save/load models.
- The dataset is highly imbalanced with relatively few stroke cases, posing a significant challenge for predictive modelling.
- Despite oversampling and model tuning, recall and precision for stroke cases remain modest.
- Threshold tuning improved minority class detection but introduced more false positives.
- Future work could explore advanced sampling techniques such as SMOTE or ensemble methods to better balance the classes.
- Additional features or clinical data may be required for substantial improvements in stroke prediction.
This project is licensed under the MIT License.