This project contains a machine learning model and a Gradio web application to classify images as either AI-generated or real. It utilizes Python, the Hugging Face Transformers library, PyTorch for model training and inference, and Gradio for the interactive user interface. The system is designed to help distinguish between photographic images and images created by artificial intelligence.
- Python
- PyTorch
- Hugging Face Transformers
- Gradio
- Scikit-learn
- Pillow (PIL)
- Matplotlib & Seaborn (for metrics visualization)
Here's a brief overview of the key files and directories:
- `main.py`: The main script for training the image classification model. It handles data loading, preprocessing, model training using PyTorch/Transformers, evaluation, and saving the trained model.
- `app.py`: A Gradio web application that loads the trained model and provides an interactive interface for users to upload images and get predictions (AI-generated or real).
- `preprocess_images.py`: A utility script to convert images in the dataset (especially HEIC/HEIF formats) to JPEG format. This is often a necessary step before training.
- `inference.py`: A simple Python script for performing command-line inference with the trained model on a single image.
- `ai_vs_real_classifier/`: This directory is created during training (by `main.py`) and stores the trained model files. `app.py` loads the model from here.
- `人工智慧第四組資料夾/`: The expected root directory for the training dataset. It should contain two subdirectories:
  - `ai/`: For AI-generated images.
  - `real/`: For real photographic images.
- `confusion_matrix.png`: An image file generated after training by `main.py`, showing the confusion matrix of the model's performance on the test set.
- `pyproject.toml`: Defines project metadata and dependencies.
- `.python-version`: Specifies the Python version for the project.
- `README.md`: This file, providing information about the project.
- Clone the repository:

  ```bash
  git clone <repository-url>  # Replace <repository-url> with the actual URL
  cd <repository-directory>
  ```
- Python Version: This project uses Python 3.12. It's recommended to use a virtual environment (e.g., venv, conda).

  ```bash
  python -m venv venv
  source venv/bin/activate  # On Windows: venv\Scripts\activate
  ```
- Install Dependencies: The project dependencies are listed in `pyproject.toml`. You can install them using `pip`:

  ```bash
  pip install transformers torch torchvision torchaudio gradio Pillow scikit-learn matplotlib seaborn
  ```

  For GPU support with PyTorch (recommended for training and faster inference), follow the instructions on the official PyTorch website to install the version compatible with your CUDA setup (a quick GPU check snippet appears after this list).
- Optional: HEIC/HEIF Support: The `preprocess_images.py` script can convert HEIC/HEIF images. To enable this, you need to install `pillow-heif`:

  ```bash
  pip install pillow-heif
  ```

  You might also need to install system libraries for `libheif`:
  - On Debian/Ubuntu: `sudo apt-get install libheif1 libde265-0`
  - On macOS (using Homebrew): `brew install libheif`
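To confirm that the installed PyTorch build can actually see your GPU, a quick check like the following can help. This is a generic PyTorch check, not part of the project's scripts:

```python
# Quick sanity check that PyTorch sees a CUDA-capable GPU (optional, not part of the project scripts).
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    # bf16 mixed precision (used by the training script) needs supported hardware, e.g. Ampere or newer.
    print("bf16 supported:", torch.cuda.is_bf16_supported())
```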
The model is trained to classify images as either AI-generated or real. You'll need to prepare your dataset accordingly:
- Dataset Directory Structure: Create a root folder for your dataset. By default, the training script (`main.py`) expects this folder to be named `人工智慧第四組資料夾` (as seen in the script) and located in the project's root directory. Inside this folder, create two subdirectories:
  - `ai/`: Place all your AI-generated images in this folder.
  - `real/`: Place all your real photographic images in this folder.

  The structure should look like this:

  ```
  <project-root>/
  ├── 人工智慧第四組資料夾/
  │   ├── ai/
  │   │   ├── image1.jpg
  │   │   ├── image2.png
  │   │   └── ...
  │   └── real/
  │       ├── image3.jpg
  │       ├── image4.jpeg
  │       └── ...
  ├── main.py
  ├── app.py
  └── ...
  ```
- Image Formats & Preprocessing: The training script processes images using the Pillow library. While it can handle various formats, it's recommended to have images in common web formats like JPEG or PNG. If you have images in HEIC/HEIF format, or other formats that need conversion, you can use the `preprocess_images.py` script.
  - Ensure you have installed `pillow-heif` as described in the "Setup and Installation" section if you need to process HEIC/HEIF files.
  - Modify the `DATA_DIRECTORY` variable at the bottom of `preprocess_images.py` if your dataset folder is named differently or located elsewhere.
  - Run the script from the project root:

    ```bash
    python preprocess_images.py
    ```

  This will convert images (e.g., HEIC to JPEG) and place them in the same directory structure, potentially overwriting originals if configured to do so (the script defaults to deleting originals after a successful conversion if the new file has a different name or the source was a HEIF file). Review the script's behavior if you have concerns about keeping original files.
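To make the conversion step concrete, here is a minimal, hypothetical sketch of a HEIC/HEIF-to-JPEG pass over the dataset folder. The function name, the keep-the-original behavior, and the traversal shown here are assumptions for illustration; `preprocess_images.py` is the authoritative implementation.

```python
# Illustrative HEIC/HEIF -> JPEG conversion sketch (see preprocess_images.py for the real behavior).
from pathlib import Path

from PIL import Image
from pillow_heif import register_heif_opener

register_heif_opener()  # lets Pillow open .heic/.heif files

DATA_DIRECTORY = "人工智慧第四組資料夾"  # dataset root used by the training script


def convert_heic_to_jpeg(data_dir: str) -> None:
    """Convert every HEIC/HEIF image under data_dir to a JPEG next to the original."""
    for heic_path in Path(data_dir).rglob("*"):
        if heic_path.suffix.lower() not in {".heic", ".heif"}:
            continue
        jpeg_path = heic_path.with_suffix(".jpg")
        with Image.open(heic_path) as img:
            img.convert("RGB").save(jpeg_path, "JPEG")
        # The real script may delete the original after a successful conversion; this sketch keeps it.


if __name__ == "__main__":
    convert_heic_to_jpeg(DATA_DIRECTORY)
```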
Once your dataset is prepared and dependencies are installed, you can train the image classification model:
- Run the Training Script: Execute the `main.py` script from the project's root directory:

  ```bash
  python main.py
  ```
- Training Process:
  - The script will load images from the specified data directory (`./人工智慧第四組資料夾/` by default).
  - It uses a Vision Transformer model (e.g., `google/vit-large-patch16-224-in21k`) pre-trained on ImageNet, and fine-tunes it on your dataset.
  - A Focal Loss function is implemented to help address potential class imbalance between 'AI' and 'real' images.
  - The script will output training progress, including loss and evaluation metrics (accuracy, F1-score, precision, recall) for each epoch if a validation set is created.
  - The training arguments (like number of epochs, batch size, etc.) are defined within `main.py` and can be adjusted if needed (see the illustrative training-setup sketch after the resource note below).
- Output:
  - Trained Model: The best-performing model (based on F1-score on the validation set) will be saved to the `./ai_vs_real_classifier/` directory. This directory will contain files like `pytorch_model.bin`, `config.json`, etc.
  - Confusion Matrix: A visual representation of the model's performance on the test set, `confusion_matrix.png`, will be saved in the project's root directory.
Note on Resources: Training deep learning models, especially large ones like Vision Transformers, can be computationally intensive and may require a GPU for reasonable training times. The script is configured to use mixed-precision training (`bf16=True`) and PyTorch 2.0 compilation (`torch_compile=True`) if available, which can speed up training and reduce memory usage on compatible hardware.
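As a rough, hypothetical illustration of how these settings could fit together with the Hugging Face `Trainer`, the sketch below shows a metrics function and `TrainingArguments` with `bf16` and `torch_compile` enabled. All numeric values, the metric-averaging choice, and the strategy argument names are assumptions; the actual configuration lives in `main.py`.

```python
# Hypothetical sketch of the training setup described above; see main.py for the real configuration.
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import TrainingArguments


def compute_metrics(eval_pred):
    """Compute accuracy, F1, precision, and recall from Trainer predictions."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="weighted", zero_division=0  # averaging choice is an assumption
    )
    return {"accuracy": accuracy_score(labels, preds), "f1": f1, "precision": precision, "recall": recall}


# Example arguments only -- epochs, batch size, etc. are placeholders, not the values in main.py.
training_args = TrainingArguments(
    output_dir="./ai_vs_real_classifier",
    num_train_epochs=5,                  # placeholder value
    per_device_train_batch_size=16,      # placeholder value
    eval_strategy="epoch",               # `evaluation_strategy` in older transformers versions
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",          # best model selected by F1-score, as described above
    bf16=True,                           # mixed-precision training, if the hardware supports it
    torch_compile=True,                  # PyTorch 2.0 compilation
)
```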
Once the model is trained and saved, you can use it to classify new images.
The primary way to interact with the classifier is through the Gradio web application:
- Run the App: Execute the `app.py` script from the project's root directory:

  ```bash
  python app.py
  ```
- Using the Interface:
  - The script will typically print a local URL to your console (e.g., `Running on local URL: http://127.0.0.1:7860`). Open this URL in your web browser.
  - You'll see an interface titled "Custom Image Classifier".
  - Upload an image using the provided upload box.
  - The model will process the image, and the predicted label ("AI" or "Real") along with confidence scores for the top classes will be displayed.
  - The app includes some example image paths (`test_image_0.jpg`, `test_image_1.png`, etc.). Ensure these example images exist if you want to use the example buttons in the Gradio interface, or update the paths in `app.py`.
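For orientation, here is a minimal, hypothetical sketch of how an app like `app.py` can wire the saved model to a Gradio upload box via the Transformers `pipeline` API. The function name and interface options are assumptions; the real interface is defined in `app.py`.

```python
# Minimal sketch of a Gradio front end for the saved classifier (illustrative only; see app.py).
import gradio as gr
from transformers import pipeline

# Load the fine-tuned model saved by main.py.
classifier = pipeline("image-classification", model="./ai_vs_real_classifier")


def predict(image):
    """Return {label: score} for the uploaded PIL image."""
    results = classifier(image)
    return {r["label"]: float(r["score"]) for r in results}


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=2),
    title="Custom Image Classifier",
)

if __name__ == "__main__":
    demo.launch()  # prints the local URL, e.g. http://127.0.0.1:7860
```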
For a simpler, non-interactive way to get a prediction for a single image, you can use `inference.py`:

- Prepare an image: Place an image you want to test (e.g., `my_test_image.jpg`) in the project's root directory, or provide the correct path to it.
- Modify `inference.py` (if needed):
  - The script is currently set to load the model from `./outputs`. Important: The training script `main.py` saves the model to `./ai_vs_real_classifier`. You will likely need to change the model path in `inference.py` from `"./outputs"` to `"./ai_vs_real_classifier"` to use your trained model.
  - Update the image path `Image.open("test_image.jpg")` to `Image.open("your_image_name.jpg")`.
- Run the script:

  ```bash
  python inference.py
  ```
- The script will print the raw prediction output (a list of dictionaries with labels and scores) to the console.
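For reference, a minimal, hypothetical version of such a single-image script might look like the sketch below, already pointed at the directory that `main.py` saves to. The placeholder image name and variable names are assumptions; the actual script is `inference.py`.

```python
# Hypothetical single-image inference sketch (the actual script is inference.py).
from PIL import Image
from transformers import pipeline

# Load the model from the directory written by main.py (not "./outputs").
classifier = pipeline("image-classification", model="./ai_vs_real_classifier")

image = Image.open("my_test_image.jpg")  # placeholder path; replace with your image
predictions = classifier(image)

# Prints a list of dicts such as [{'label': 'AI', 'score': 0.98}, {'label': 'Real', 'score': 0.02}]
print(predictions)
```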
Note: The Gradio application (`app.py`) is the recommended way for most users to perform inference due to its user-friendly interface and correct model path configuration by default (it loads from `./ai_vs_real_classifier`).
- Base Model: The classification model is based on the Vision Transformer (ViT) architecture. Specifically, the `main.py` script is configured to use `"google/vit-large-patch16-224-in21k"` as the starting point. This model is pre-trained on the ImageNet-21k dataset.
- Fine-Tuning: The pre-trained ViT model is then fine-tuned on the custom "AI vs. Real" image dataset provided by the user.
- Output Layer: The number of labels for the classification head is set to 2 ("AI" and "Real"). The `id2label` mapping is `{0: "AI", 1: "Real"}` and `label2id` is `{"AI": 0, "Real": 1}`.
- Loss Function: To handle potential class imbalance in the dataset (where one class might have significantly more images than the other), a Focal Loss function is implemented in `main.py` and used during training. This helps the model pay more attention to hard-to-classify examples and down-weights the loss assigned to well-classified examples. The `alpha` parameter of the Focal Loss is dynamically calculated based on the inverse class frequencies in the training data (see the sketch after this list).
- Performance: The `main.py` script computes and logs various metrics during validation and testing, including accuracy, F1-score, precision, and recall. A confusion matrix is also generated and saved as `confusion_matrix.png`.
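To make the loss-function description concrete, here is a minimal, hypothetical Focal Loss sketch with per-class `alpha` weights derived from inverse class frequencies, as described above. The exact formulation, the `gamma` value, the normalization of `alpha`, and the placeholder class counts are assumptions; the authoritative implementation is in `main.py`.

```python
# Illustrative Focal Loss sketch (the actual implementation lives in main.py).
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, alpha, gamma=2.0):
        """alpha: per-class weights, shape [num_classes]; gamma: focusing parameter."""
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        # Per-example cross-entropy, used to recover the probability of the true class.
        ce_loss = F.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-ce_loss)                          # probability assigned to the true class
        alpha_t = self.alpha.to(logits.device)[targets]   # per-example class weight
        focal = alpha_t * (1 - pt) ** self.gamma * ce_loss  # down-weight easy examples
        return focal.mean()


# Hypothetical alpha from inverse class frequencies, e.g. 300 "AI" images vs. 700 "Real" images.
class_counts = torch.tensor([300.0, 700.0])  # placeholder counts
alpha = 1.0 / class_counts
alpha = alpha / alpha.sum()                  # normalize so the weights sum to 1
loss_fn = FocalLoss(alpha=alpha, gamma=2.0)
```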
After a successful training run using `python main.py`, a confusion matrix image named `confusion_matrix.png` is saved to the project's root directory.

This image visually represents the model's performance by showing the counts of true positive, true negative, false positive, and false negative predictions on the test set.

Viewing the Confusion Matrix:

- You can open the `confusion_matrix.png` file directly from your file explorer.
- If this README is being viewed on a platform that renders Markdown and can access local images (e.g., when the repository is hosted on GitHub), the image can be embedded directly.

(Note: The image will only be visible below if `confusion_matrix.png` exists in the root of the repository and the Markdown viewer supports relative image paths.)

![Confusion Matrix](confusion_matrix.png)
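For reference, a confusion-matrix PNG like this can be produced with scikit-learn, Matplotlib, and Seaborn roughly as sketched below. The label order, figure size, and color map are assumptions; `main.py` may style its plot differently.

```python
# Illustrative sketch of generating a confusion-matrix PNG (main.py's version may differ).
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix


def save_confusion_matrix(y_true, y_pred, path="confusion_matrix.png"):
    """Plot test-set predictions against true labels and save the figure as a PNG."""
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])  # 0 = "AI", 1 = "Real"
    plt.figure(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=["AI", "Real"], yticklabels=["AI", "Real"])
    plt.xlabel("Predicted label")
    plt.ylabel("True label")
    plt.tight_layout()
    plt.savefig(path)
    plt.close()
```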
Here are some potential areas for future improvement and development:

- `requirements.txt`: Create a `requirements.txt` file for easier pip-based dependency management, in addition to `pyproject.toml`.
- Configuration File: Move hardcoded paths (like the dataset directory in `main.py` and `preprocess_images.py`) and training parameters (epochs, batch size in `main.py`) into a configuration file (e.g., YAML or JSON) for easier modification without altering the code (a hypothetical sketch of this idea appears after this list).
- Expanded Model Evaluation: Implement more detailed evaluation, such as per-class metrics if more categories are added, or ROC curves and AUC scores.
- Model Checkpointing Options: Provide more flexible model checkpointing strategies during training (e.g., save best N models, save every K epochs).
- Data Augmentation: Explore and implement more sophisticated data augmentation techniques during training to improve model generalization.
- Error Handling: Enhance error handling and logging in the scripts.
- Testing: Add unit tests for preprocessing, training, and inference components.
- Packaging: Package the application for easier distribution (e.g., using Docker).
- UI Enhancements: Add more features to the Gradio UI, such as displaying example predictions or allowing adjustment of inference parameters.
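To illustrate the configuration-file idea above, here is one possible, hypothetical shape for such a setup: a small JSON file read at the top of the scripts. The file name, keys, and default values are all assumptions, not part of the current code.

```python
# Hypothetical example of reading a config file instead of hardcoding paths and hyperparameters.
import json
from pathlib import Path

DEFAULTS = {
    "data_dir": "人工智慧第四組資料夾",     # currently hardcoded dataset directory
    "output_dir": "./ai_vs_real_classifier",
    "num_train_epochs": 5,                 # placeholder
    "per_device_train_batch_size": 16,     # placeholder
}


def load_config(path: str = "config.json") -> dict:
    """Merge config.json (if present) over the defaults above."""
    config = dict(DEFAULTS)
    config_path = Path(path)
    if config_path.exists():
        config.update(json.loads(config_path.read_text(encoding="utf-8")))
    return config


config = load_config()
print(config["data_dir"], config["num_train_epochs"])
```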