This project demonstrates a deep, practical understanding of the Keras framework by building a custom 2D Convolutional (Conv2D) layer from scratch. The custom layer is then successfully integrated into a CNN and trained on a filtered subset of the CIFAR-10 dataset, achieving a final test accuracy of 83.13%.
While standard deep learning libraries offer a comprehensive suite of layers, the ability to build custom components is a crucial skill for implementing novel research papers and developing unique model architectures. This project's goal is to showcase that capability by:
- Implementing a Custom MyConv2D Layer: Building a functional equivalent of the standard tf.keras.layers.Conv2D layer by inheriting from the base tf.keras.layers.Layer class.
- Demonstrating Practical Application: Using the custom layer within a larger Convolutional Neural Network (CNN) to solve a real-world image classification task.
- Validating Performance: Training the model using best practices and evaluating its effectiveness through accuracy metrics, training history visualization, and a confusion matrix.
The solution is contained within a single Jupyter Notebook that methodically builds the custom layer and then uses it in a complete training and evaluation pipeline.
- Custom MyConv2D Layer: The core of this project is the custom-built convolutional layer. Its implementation follows Keras API design patterns:
  - __init__(): Initializes layer-specific hyperparameters like the number of filters, kernel size, and activation function.
  - build(): Creates the layer's trainable weights (kernel w and bias b) using self.add_weight(). This method infers the input channel depth from the input shape, making the layer flexible and reusable.
  - call(): Defines the layer's forward pass logic, which consists of the core tf.nn.conv2d operation, the addition of the bias term, and the application of the specified activation function.
- Data Filtering and Preprocessing: The CIFAR-10 dataset was loaded and then filtered to create a simpler, three-class classification problem using only images of planes, cars, and birds. This focused approach allowed for faster iteration and a clear validation of the custom layer's functionality. The labels were remapped to 0, 1, 2, and pixel values were normalized.
- Model Architecture: A Sequential CNN model was constructed to test the custom layer in a realistic setting:
  - A standard Conv2D and MaxPooling2D layer to extract initial features.
  - The custom MyConv2D layer to perform a subsequent convolution.
  - Flatten and Dense layers for final classification.
- Training and Evaluation: The model was compiled with the Adam optimizer and SparseCategoricalCrossentropy loss. Training was managed using Keras callbacks, including EarlyStopping to prevent overfitting and ReduceLROnPlateau to fine-tune the learning rate. The model's performance was then visualized with accuracy/loss curves and a detailed confusion matrix.
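The three methods described for the custom layer can be sketched roughly as follows. This is a minimal illustration, not the notebook's actual code: the constructor arguments (strides, padding), the initializers, and the channels-last assumption are choices made here for the example.

```python
import tensorflow as tf


class MyConv2D(tf.keras.layers.Layer):
    """Sketch of a Conv2D-like layer; argument names and defaults are assumptions."""

    def __init__(self, filters, kernel_size, strides=1, padding="VALID",
                 activation=None, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        # Accept either an int or an (height, width) pair for the kernel size.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = tuple(kernel_size)
        self.strides = strides
        self.padding = padding.upper()
        # Resolves a name such as "relu"; None becomes the identity (linear) function.
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        # Infer the input channel depth from the last axis (channels-last assumed).
        in_channels = int(input_shape[-1])
        self.w = self.add_weight(
            name="kernel",
            shape=(*self.kernel_size, in_channels, self.filters),
            initializer="glorot_uniform",
            trainable=True,
        )
        self.b = self.add_weight(
            name="bias",
            shape=(self.filters,),
            initializer="zeros",
            trainable=True,
        )

    def call(self, inputs):
        # Convolution, then bias, then activation.
        x = tf.nn.conv2d(inputs, self.w, strides=self.strides, padding=self.padding)
        x = tf.nn.bias_add(x, self.b)
        return self.activation(x)


# Feed a dummy batch shaped like the pooled feature map in the model summary.
layer = MyConv2D(64, 3, activation="relu")
out = layer(tf.zeros((2, 15, 15, 32)))
print(out.shape)             # (2, 13, 13, 64)
print(layer.count_params())  # 18496
```

Because the weights are created in build() rather than __init__(), the same class works for any input depth without the caller specifying it.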
- Primary Framework: TensorFlow 2.10
- Core Libraries: Keras, NumPy
- Data Visualization: Matplotlib
- Metrics & Analysis: scikit-learn
The project uses a custom subset of the well-known CIFAR-10 dataset. From the original 10 classes of 32x32 color images, only three were selected: plane, car, and bird. This resulted in a balanced dataset of 15,000 training images and 3,000 test images, ideal for focused model validation.
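The filtering step might look something like the sketch below. It runs on synthetic arrays shaped like CIFAR-10 so that it does not need the dataset download. The function name, the division by 255, and the class indices (0, 1, 2 for airplane, automobile, bird in the standard CIFAR-10 ordering) are assumptions made for illustration.

```python
import numpy as np

# Standard CIFAR-10 indices for airplane, automobile, bird (assumed ordering).
KEEP_CLASSES = (0, 1, 2)


def filter_and_preprocess(x, y, keep=KEEP_CLASSES):
    """Keep only the `keep` classes, remap labels to 0..len(keep)-1, scale pixels to [0, 1]."""
    y = y.reshape(-1)  # CIFAR-10 labels arrive with shape (N, 1)
    mask = np.isin(y, keep)
    x_sub, y_sub = x[mask], y[mask]
    # Map each original class id to its position in `keep`.
    remap = {orig: new for new, orig in enumerate(keep)}
    y_new = np.array([remap[int(label)] for label in y_sub], dtype=np.int64)
    return x_sub.astype("float32") / 255.0, y_new


# Synthetic stand-in for (x_train, y_train); real data would come from
# tf.keras.datasets.cifar10.load_data().
rng = np.random.default_rng(0)
x_fake = rng.integers(0, 256, size=(100, 32, 32, 3), dtype=np.uint8)
y_fake = rng.integers(0, 10, size=(100, 1))

x_small, y_small = filter_and_preprocess(x_fake, y_fake)
print(x_small.shape[1:], x_small.dtype, np.unique(y_small))
```

The explicit remap is the identity for these three classes, but it keeps the function correct if a different class subset is chosen.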
To run this project locally, please follow these steps:
- Clone the Repository:
    git clone https://github.com/imehranasgari/your-repo-name.git
    cd your-repo-name
- Install Dependencies: It is recommended to use a virtual environment.
    pip install -r requirements.txt
  (Note: A requirements.txt file should be created containing tensorflow, numpy, matplotlib, and scikit-learn.)
- Run the Notebook: Launch Jupyter and open the cumtom_conv_with_limit_class.ipynb notebook to execute the cells sequentially.
    jupyter notebook
This project successfully demonstrates the ability to create and validate a complex custom layer within the Keras ecosystem, achieving strong performance on the defined task.
- Final Test Accuracy: 83.13%
- Model Convergence: The training and validation curves show smooth convergence, with the ReduceLROnPlateau callback adjusting the learning rate as validation performance plateaued.
- Functionality Proof: The accuracy reached indicates that the custom MyConv2D layer trains as expected and integrates cleanly into a standard Keras model.
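For reference, a compile-and-callbacks setup of the kind described could be sketched as below. The helper name, the monitored metric, the patience values, the reduction factor, and the from_logits flag are all assumptions for illustration; the notebook's actual settings are not stated here.

```python
import tensorflow as tf


def compile_for_training(model, from_logits=False, learning_rate=1e-3):
    """Compile `model` with Adam + sparse categorical cross-entropy and return callbacks.

    Hypothetical helper: every numeric setting below is an illustrative guess.
    """
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits),
        metrics=["accuracy"],
    )
    return [
        # Stop when validation loss stops improving and keep the best weights.
        tf.keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=5, restore_best_weights=True
        ),
        # Halve the learning rate after two epochs without improvement.
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.5, patience=2, min_lr=1e-6
        ),
    ]


# Tiny stand-in model, only to show the call pattern.
demo_model = tf.keras.Sequential([tf.keras.layers.Dense(3, activation="softmax")])
callbacks = compile_for_training(demo_model)

# Training would then look roughly like:
# history = demo_model.fit(x_train, y_train, validation_split=0.1,
#                          epochs=50, callbacks=callbacks)
```

Set from_logits=True if the final Dense layer has no softmax activation.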
This file was intentionally created to demonstrate skills in implementing and explaining machine learning models, rather than solely focusing on achieving the highest evaluation metrics. The simple approach is for learning, benchmarking, and illustrating fundamental concepts.
Model Architecture with Custom Layer
Model: "sequential"
_________________________________________________________________
Layer (type)                    Output Shape          Param #
=================================================================
conv2d (Conv2D)                 (None, 30, 30, 32)    896
max_pooling2d (MaxPooling2D)    (None, 15, 15, 32)    0
my_conv2d (MyConv2D)            (None, 13, 13, 64)    18496
flatten (Flatten)               (None, 10816)         0
dense (Dense)                   (None, 3)             32451
=================================================================
Total params: 51,843
Trainable params: 51,843
Non-trainable params: 0
_________________________________________________________________
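The shapes and parameter counts in the summary can be checked by hand. The short script below reproduces them, assuming 3x3 kernels with "valid" padding and a 2x2 max-pool. These settings are inferred from the numbers in the summary rather than stated explicitly.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # One kernel_h x kernel_w x in_channels kernel per filter, plus one bias per filter.
    return kernel_h * kernel_w * in_channels * filters + filters


def dense_params(in_features, units):
    # Full weight matrix plus one bias per unit.
    return in_features * units + units


def valid_conv_size(size, kernel):
    # Spatial size after a stride-1 convolution with no padding.
    return size - kernel + 1


side = valid_conv_size(32, 3)    # 32x32 input -> 30x30
side = side // 2                 # 2x2 max-pool -> 15x15
side = valid_conv_size(side, 3)  # custom conv  -> 13x13
flat = side * side * 64          # flatten      -> 10816

p_conv = conv2d_params(3, 3, 3, 32)      # 896
p_custom = conv2d_params(3, 3, 32, 64)   # 18496
p_dense = dense_params(flat, 3)          # 32451
total = p_conv + p_custom + p_dense      # 51843

print(side, flat, p_conv, p_custom, p_dense, total)
```

The custom layer's 18,496 parameters are exactly what a built-in Conv2D with the same configuration would report.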
Training & Validation Performance Curves
Test Set Confusion Matrix
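A confusion matrix of this kind can be computed with scikit-learn along the following lines. The labels and predictions here are invented purely to show the call; they are not the project's results.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

CLASS_NAMES = ["plane", "car", "bird"]

# Made-up labels and predictions; in practice y_pred would come from
# np.argmax(model.predict(x_test), axis=1).
y_true = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 2, 1, 1, 2, 0, 2])

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
per_class_recall = cm.diagonal() / cm.sum(axis=1)

for name, row, recall in zip(CLASS_NAMES, cm, per_class_recall):
    print(f"{name:>5}: {row}  recall={recall:.2f}")
```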
Building a custom Conv2D layer from scratch was a significant step beyond simply using a framework's built-in tools. It required a solid understanding of the underlying mechanics of convolution, including kernel and bias management, and the use of low-level TensorFlow operations like tf.nn.conv2d.
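To make those mechanics concrete, here is a deliberately naive NumPy version of the operation for a single image with stride 1 and no padding. It is an illustrative sketch, not the layer's implementation. As in tf.nn.conv2d, the kernel is applied without flipping, so the operation is technically cross-correlation.

```python
import numpy as np


def conv2d_valid(x, w, b):
    """Naive stride-1, no-padding convolution.

    x: (H, W, C_in) image, w: (kh, kw, C_in, C_out) kernel, b: (C_out,) bias.
    """
    kh, kw, _, c_out = w.shape
    out_h = x.shape[0] - kh + 1
    out_w = x.shape[1] - kw + 1
    out = np.zeros((out_h, out_w, c_out), dtype=np.float64)
    for i in range(out_h):
        for j in range(out_w):
            patch = x[i:i + kh, j:j + kw, :]  # (kh, kw, C_in) window
            # Contract the window against every filter at once, then add the bias.
            out[i, j, :] = np.tensordot(patch, w, axes=([0, 1, 2], [0, 1, 2])) + b
    return out


# All-ones input and kernel: each output is the window size (3*3*3 = 27) plus the bias.
x_demo = np.ones((5, 5, 3))
w_demo = np.ones((3, 3, 3, 2))
b_demo = np.array([0.0, 1.0])
y_demo = conv2d_valid(x_demo, w_demo, b_demo)
print(y_demo.shape)  # (3, 3, 2)
print(y_demo[0, 0])  # [27. 28.]
```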
This project proves my ability to not only use but also extend the Keras API, a critical skill for implementing cutting-edge models or creating specialized layers for unique problem domains. It showcases a deeper level of expertise that is essential for advanced AI development and research.
Mehran Asgari
- Email: imehranasgari@gmail.com
- GitHub: https://github.com/imehranasgari
This project is licensed under the Apache 2.0 License – see the LICENSE file for details.
💡 Some interactive outputs (e.g., plots, widgets) may not display correctly on GitHub. If so, please view this notebook via nbviewer.org for full rendering.