This project builds a deep learning model that classifies fruits from images using Convolutional Neural Networks (CNNs).
The dataset used in this project is the Fruits-360 dataset. It contains images of various fruits that are grouped into different categories. The dataset is divided into three folders:
- Training data: Contains images for training the model.
- Validation data: Used to validate the model during training.
- Test data: Used to evaluate the model after training.
You can download the full dataset from the Fruits-360 dataset page on Kaggle.
fruit-classification/
│
├── data/
│   ├── train/                    # Training images
│   ├── val/                      # Validation images
│   └── test/                     # Test images
│
├── notebooks/
│   ├── fruit.ipynb
│   └── pineaipple-train.ipynb
│
├── src/
│   ├── data_preprocessing.py     # Data augmentation and preprocessing
│   ├── model.py                  # Define the CNN model
│   ├── train.py                  # Train the model
│   └── evaluate.py               # Evaluate the model
│
├── models/
│   └── fruit_classifier_final.keras
│
├── app.py
├── requirements.txt
└── README.md
In this project, we:
- Preprocess and augment the images using TensorFlow's `ImageDataGenerator`.
- Define a Convolutional Neural Network (CNN) model for image classification.
- Train the model using the training and validation sets.
- Evaluate the model using the test set to measure its accuracy.
- Visualize results including prediction examples.
Data Preprocessing:
- Augment the images (rotate, shift, zoom, etc.) to increase dataset variety.
- Split the data into training, validation, and test sets.
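
The augmentation step above could look roughly like the sketch below. It assumes that `data/train` and `data/val` each contain one subfolder per fruit class. The specific ranges, the 100x100 image size, the batch size, and the `build_generators` name are illustrative assumptions, not values taken from `src/data_preprocessing.py`. Note that `ImageDataGenerator` is marked as deprecated in recent TensorFlow releases.

```python
# Augmentation settings applied to training images only (assumed values).
AUGMENTATION = {
    "rotation_range": 20,       # rotate by up to +/- 20 degrees
    "width_shift_range": 0.1,   # shift horizontally by up to 10% of the width
    "height_shift_range": 0.1,  # shift vertically by up to 10% of the height
    "zoom_range": 0.2,          # zoom in or out by up to 20%
    "horizontal_flip": True,    # mirror images at random
}


def build_generators(data_dir="data", image_size=(100, 100), batch_size=32):
    """Return (train, val) directory iterators that yield image batches."""
    # Imported inside the function because loading TensorFlow is slow.
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    # Both sets are rescaled to [0, 1]; only the training set is augmented.
    train_datagen = ImageDataGenerator(rescale=1.0 / 255, **AUGMENTATION)
    val_datagen = ImageDataGenerator(rescale=1.0 / 255)

    train = train_datagen.flow_from_directory(
        f"{data_dir}/train",
        target_size=image_size,
        batch_size=batch_size,
        class_mode="categorical",
    )
    val = val_datagen.flow_from_directory(
        f"{data_dir}/val",
        target_size=image_size,
        batch_size=batch_size,
        class_mode="categorical",
        shuffle=False,  # keep a fixed order so validation runs are comparable
    )
    return train, val
```

Keeping the validation generator free of augmentation means validation scores reflect unmodified images.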
Model Definition:
- A CNN model is built using TensorFlow/Keras layers to recognize fruit images.
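
As a rough illustration of what such a model might look like, the sketch below stacks three convolution/pooling stages in front of a dense classifier. The layer sizes, the dropout rate, the 100x100x3 input shape, and the `build_model` name are assumptions made for this example; the actual architecture in `src/model.py` may differ.

```python
def build_model(num_classes, input_shape=(100, 100, 3)):
    """Build and compile a small CNN for multi-class image classification."""
    # Imported inside the function because loading TensorFlow is slow.
    from tensorflow import keras
    from tensorflow.keras import layers

    model = keras.Sequential([
        keras.Input(shape=input_shape),
        # Three convolution + pooling stages extract increasingly abstract features.
        layers.Conv2D(32, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(128, 3, activation="relu"),
        layers.MaxPooling2D(),
        # The classifier head maps the features to one probability per class.
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",  # expects one-hot encoded labels
        metrics=["accuracy"],
    )
    return model
```

Pass the number of fruit categories as `num_classes`; calling `.summary()` on the returned model prints the layer shapes.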
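Training:
- The model is trained on the augmented training data and validated on the validation set.
- Early stopping halts training once validation performance stops improving, and model checkpointing saves the best model seen so far; together they help limit overfitting.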
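
A minimal sketch of this training setup, assuming a compiled Keras model and two data iterators are already available. The `train` function name, the patience of 5 epochs, the 30-epoch limit, and the checkpoint path are hypothetical choices for illustration, not settings taken from `src/train.py`.

```python
def train(model, train_data, val_data, epochs=30,
          checkpoint_path="models/fruit_classifier_best.keras"):
    """Fit the model with early stopping and best-model checkpointing."""
    # Imported inside the function because loading TensorFlow is slow.
    from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

    callbacks = [
        # Stop when validation loss has not improved for 5 consecutive epochs
        # and roll the model back to its best weights.
        EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
        # Write the model to disk each time validation loss reaches a new minimum.
        ModelCheckpoint(checkpoint_path, monitor="val_loss", save_best_only=True),
    ]
    return model.fit(
        train_data,
        validation_data=val_data,
        epochs=epochs,
        callbacks=callbacks,
    )
```

The returned `History` object holds per-epoch loss and accuracy in its `history` attribute, which is convenient for plotting learning curves.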
Evaluation:
- The model is evaluated on the test dataset, and performance metrics such as accuracy and the classification report are displayed.
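
The evaluation step might be sketched as follows. It assumes the model was compiled with a single `accuracy` metric and that `test_data` is a directory iterator created with `shuffle=False`, so the prediction order matches `test_data.classes`. The `evaluate` function name is hypothetical.

```python
import numpy as np


def evaluate(model, test_data):
    """Print test loss, accuracy, and a per-class report; return the accuracy."""
    from sklearn.metrics import classification_report

    loss, accuracy = model.evaluate(test_data, verbose=0)
    print(f"Test loss: {loss:.4f}  Test accuracy: {accuracy:.4f}")

    # Convert per-class probabilities into predicted class indices.
    probabilities = model.predict(test_data, verbose=0)
    predicted = np.argmax(probabilities, axis=1)

    # class_indices maps class name -> index, in index order.
    class_names = list(test_data.class_indices)
    print(classification_report(
        test_data.classes,
        predicted,
        labels=list(range(len(class_names))),
        target_names=class_names,
        zero_division=0,  # report 0 instead of warning for classes never predicted
    ))
    return accuracy
```

The per-class precision and recall in the report can reveal fruit categories that the overall accuracy hides.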
Prediction Visualization:
- Predictions are made on test images, and the results are visualized.
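
One possible way to visualize predictions is a grid of test images titled with the predicted and true labels. The sketch below assumes `images` is an array of shape `(n, height, width, 3)` with values in `[0, 1]` and `true_labels` holds integer class indices. The `show_predictions` name and the three-column layout are illustrative choices.

```python
import numpy as np


def show_predictions(model, images, true_labels, class_names, count=9):
    """Plot the first `count` images with their predicted and true labels."""
    import matplotlib.pyplot as plt

    probabilities = model.predict(images[:count], verbose=0)
    predicted = np.argmax(probabilities, axis=1)

    columns = 3
    rows = int(np.ceil(count / columns))
    fig, axes = plt.subplots(rows, columns, figsize=(3 * columns, 3 * rows))
    flat_axes = np.ravel(axes)
    for ax in flat_axes:
        ax.axis("off")  # hide ticks, including on unused grid cells

    for ax, image, pred, true in zip(flat_axes, images, predicted, true_labels):
        ax.imshow(image)
        # Green title for a correct prediction, red for a mistake.
        color = "green" if pred == true else "red"
        ax.set_title(f"pred: {class_names[pred]}\ntrue: {class_names[true]}",
                     color=color)
    fig.tight_layout()
    return fig
```

Call `plt.show()` or `fig.savefig(...)` on the returned figure to display or store the grid.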
- Python: Programming language used for the project.
- TensorFlow/Keras: For building and training the CNN model.
- Matplotlib/Seaborn: For data visualization (plots and graphs).
- Scikit-learn: For additional evaluation metrics such as the classification report.
- Python 3.x
- TensorFlow
- Keras
- Matplotlib
- Seaborn
- Scikit-learn
To install the required dependencies, run:
pip install -r requirements.txt