This is a project in the Deep Learning course DD2424 at the Royal Institute of Technology.
Deep networks are growing deeper and require more energy resources. Previous research has shown that the deep learning field will stop progressing if contributors do not produce lighter and more efficient network architectures. This study uses MobileNet, the state-of-the-art network for mobile devices, to classify images of fruits and vegetables, and compares its performance to that of networks with larger architectures. The networks were trained both from randomly initialized weights and with Transfer Learning from ImageNet weights. MobileNet trained with Transfer Learning reached a top-1 accuracy of 96.8%, on par with the larger network architectures. Finally, this study uses the MobileNet model with the highest top-1 accuracy to build a real-time image classification app.
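Top-1 accuracy counts a prediction as correct only when the class with the highest score is the true class. The following is a minimal standard-library sketch, assuming the model's per-class scores are available as plain lists; top1_accuracy and the example scores are hypothetical and not taken from this repository:

```python
from typing import Sequence


def top1_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the true label."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    correct = 0
    for row, label in zip(scores, labels):
        # argmax over the class scores of one image
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += predicted == label
    return correct / len(labels)


# Three images, three hypothetical classes; the first two predictions are correct.
scores = [[0.1, 0.8, 0.1], [0.7, 0.2, 0.1], [0.3, 0.3, 0.4]]
labels = [1, 0, 1]
print(top1_accuracy(scores, labels))  # 2 of 3 correct
```

A reported top-1 accuracy of 96.8% means this fraction was 0.968 on the evaluation images.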
Read the paper here.
(Figure 1. Graphical User Interface. The first window shows photos captured live from the user's webcam. The second window shows the prediction bars; in this demonstration the app predicts Green Apple at 99%. The purple button at the bottom closes the app.)
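Prediction bars like those in Figure 1 are typically drawn from class probabilities. As a hedged sketch, assuming the model outputs raw scores (logits) that are turned into probabilities with a softmax, the following standard-library code renders one text bar per class; softmax, prediction_bars, the class names other than Green Apple, and the logit values are all hypothetical:

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def prediction_bars(class_names, logits, width=20):
    """Return one text bar per class, sorted from most to least likely."""
    probs = softmax(logits)
    rows = sorted(zip(class_names, probs), key=lambda item: item[1], reverse=True)
    lines = []
    for name, p in rows:
        bar = "#" * round(p * width)  # bar length proportional to probability
        lines.append(f"{name:<12} {bar:<{width}} {p:6.1%}")
    return lines


for line in prediction_bars(["Green Apple", "Banana", "Tomato"], [6.0, 1.0, 0.5]):
    print(line)
```

With these made-up logits, Green Apple is ranked first at roughly 99%, similar to the demonstration in Figure 1.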
- Navigate to the repository
- Set up a virtual environment
python3 -m venv fruits
source fruits/bin/activate
- Install the required packages
pip3 install -r requirements.txt
- Run the real-time classification app:
python3 main.py --realtime
- Classify a single image:
python3 main.py --predict --img_path 'MyPath/ToThe/Image'
NOTE: Replace MyPath/ToThe/Image with the path to your image to classify. Default: data/test_data/apple.jpg
This part requires an NVIDIA GPU.
The dataset is larger than 100 MB, so request the dataset folder by email from majdj@kth.se.
- Replace the data/FRUITS folder with the one obtained from majdj@kth.se.
- Generate the data:
python3 main.py --generate_data
- Train MobileNet from randomly initialized weights:
python3 main.py --train --model "mobilenet"
- Train MobileNet with Transfer Learning from ImageNet weights:
python3 main.py --train --model "mobilenet" --transfer_learning
NOTE: Optional training flags: pass --scheduler to use a learning-rate scheduler; pass --fine_tune to fine-tune a pre-trained model.
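This README does not say which schedule the --scheduler flag enables. For illustration only, the sketch below shows step decay, one common learning-rate schedule that is assumed here and not confirmed by the source: the learning rate is multiplied by a fixed factor every few epochs. step_decay and its default values are hypothetical:

```python
def step_decay(initial_lr: float, epoch: int, drop: float = 0.5, epochs_per_drop: int = 10) -> float:
    """Learning rate multiplied by `drop` once every `epochs_per_drop` epochs."""
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    return initial_lr * drop ** (epoch // epochs_per_drop)


# The rate stays flat within each 10-epoch window and halves at each boundary.
for epoch in (0, 9, 10, 25):
    print(epoch, step_decay(0.01, epoch))
```

Lowering the learning rate later in training lets the optimizer take smaller steps once it is close to a good solution.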
- Evaluate a trained model:
python3 main.py --evaluate --model "mobilenet" --path 'MyPath/To/Weights'
NOTE: If the model was trained with Transfer Learning, also pass the --transfer_learning flag.
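All of the commands above go through a single entry point, main.py. As a hedged illustration of how the documented flags fit together, here is a hypothetical argparse parser that accepts exactly the flags listed in this README. It is not the repository's main.py; treating the modes as mutually exclusive and the help texts are assumptions:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags documented in this README."""
    parser = argparse.ArgumentParser(description="Fruit and vegetable classifier (sketch)")
    # Assumption: exactly one mode is chosen per invocation.
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument("--realtime", action="store_true", help="run the webcam app")
    mode.add_argument("--predict", action="store_true", help="classify one image")
    mode.add_argument("--generate_data", action="store_true", help="generate the data")
    mode.add_argument("--train", action="store_true", help="train a model")
    mode.add_argument("--evaluate", action="store_true", help="evaluate saved weights")
    # Options used by the modes above; the default image path comes from this README.
    parser.add_argument("--img_path", default="data/test_data/apple.jpg")
    parser.add_argument("--model", help="model name, e.g. mobilenet")
    parser.add_argument("--path", help="path to trained weights")
    parser.add_argument("--transfer_learning", action="store_true")
    parser.add_argument("--scheduler", action="store_true")
    parser.add_argument("--fine_tune", action="store_true")
    return parser


args = build_parser().parse_args(
    ["--evaluate", "--model", "mobilenet", "--path", "MyPath/To/Weights", "--transfer_learning"]
)
print(args.evaluate, args.model, args.path, args.transfer_learning)
```

A mutually exclusive group makes argparse reject an invocation such as --train --evaluate with a usage error instead of silently running both.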