A high-performance neural network library implemented in Rust with automatic differentiation capabilities. This project provides a flexible and efficient framework for building and training neural networks, featuring a custom automatic differentiation system and support for MNIST digit classification.
- Automatic Differentiation: Custom implementation of reverse-mode automatic differentiation using Wengert lists (computational graph)
- Flexible Network Architecture: Build neural networks with configurable layers, activation functions, and weight initialization strategies
- MNIST Support: Built-in support for loading and training on the MNIST handwritten digit dataset
- High Performance: Leverages optimized BLAS libraries (Intel MKL) for efficient matrix operations
- Serialization: Save and load trained network parameters using efficient binary serialization
- Interactive Training: Command-line interface with interactive prompts for testing and evaluation
- Progress Tracking: Real-time progress bars and logging during training
- Python Integration: Uses PyO3 for seamless access to Python's `datasets` library
- Rust: Install Rust using rustup
- Python: Python 3 with the `datasets` library installed
- Clone the repository:

```sh
git clone <repository-url>
cd neural
```

- Install Python dependencies:

```sh
pip install datasets numpy pillow
```

- For Nix users, you can use the provided `shell.nix`:

```sh
nix-shell
```

- Build the project:

```sh
cargo build --release
```

Train a neural network on the MNIST dataset:

```sh
cargo run --release -- \
--dataset-path /path/to/mnist/dataset \
--epoches 20 \
--batch-size 256 \
    --learning-rate 0.1
```

Command-line options:

- `--dataset-path <PATH>`: Path to the MNIST dataset (required)
- `--cache-path <PATH>`: Path for caching datasets (default: `.cache`)
- `--parameters-path <PATH>`: Path to save/load network parameters (default: `params.dat`)
- `-p, --load-parameters-from-cache`: Load pre-trained parameters instead of training
- `-e, --epoches <NUM>`: Number of training epochs (default: 20)
- `-b, --batch-size <SIZE>`: Batch size for training (default: 256)
- `-l, --learning-rate <RATE>`: Learning rate for gradient descent (default: 0.1)
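For example, to skip training and evaluate with previously saved parameters:

```sh
cargo run --release -- \
    --dataset-path /path/to/mnist/dataset \
    --load-parameters-from-cache
```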
After training (or when loading parameters), the program enters an interactive mode that accepts the following commands:
- `test <index>`: Test the network on a specific test sample
- `test all`: Evaluate the network on all test samples and show statistics
- `save`: Save the current network parameters to disk
- `quit`: Exit the program
```
$ cargo run --release -- --dataset-path ./mnist --epoches 10
[Training progress bars...]
> test 42
Expected target probabilities: [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, 0.00, 0.00]
Predicted target probabilities: [0.01, 0.02, 0.01, 0.01, 0.01, 0.01, 0.01, 0.89, 0.02, 0.01]
Loss: 0.123
> save
> quit
```

The project includes a custom automatic differentiation (AD) system in the `auto-differentiation` crate:
- Reverse Mode AD: Implemented using Wengert lists (computational graphs) for efficient backpropagation
- Forward Mode AD: Implemented using trace-based differentiation
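As an illustration of the reverse-mode approach (a minimal sketch, not the crate's actual API), a Wengert list can be as simple as a vector of records, each holding parent indices and local partial derivatives; backpropagation is then a single reverse sweep that accumulates adjoints:

```rust
// Minimal Wengert-list (tape) sketch. Each record stores the indices of
// its two parents and the local partial derivatives with respect to them.
struct Record {
    parents: [usize; 2],
    partials: [f64; 2],
}

#[derive(Default)]
struct Tape {
    records: Vec<Record>,
}

impl Tape {
    // Register an input variable (a leaf with no parents).
    fn var(&mut self) -> usize {
        self.records.push(Record { parents: [0, 0], partials: [0.0, 0.0] });
        self.records.len() - 1
    }

    // z = a + b: dz/da = 1, dz/db = 1.
    fn add(&mut self, a: usize, b: usize) -> usize {
        self.records.push(Record { parents: [a, b], partials: [1.0, 1.0] });
        self.records.len() - 1
    }

    // z = a * b: dz/da = b, dz/db = a (values are passed in because this
    // simplified tape stores only derivatives, not values).
    fn mul(&mut self, a: usize, b: usize, a_val: f64, b_val: f64) -> usize {
        self.records.push(Record { parents: [a, b], partials: [b_val, a_val] });
        self.records.len() - 1
    }

    // Reverse sweep: seed the output adjoint with 1 and propagate back.
    fn gradients(&self, output: usize) -> Vec<f64> {
        let mut adjoints = vec![0.0; self.records.len()];
        adjoints[output] = 1.0;
        for i in (0..self.records.len()).rev() {
            let r = &self.records[i];
            adjoints[r.parents[0]] += r.partials[0] * adjoints[i];
            adjoints[r.parents[1]] += r.partials[1] * adjoints[i];
        }
        adjoints
    }
}
```

For `z = x * y`, `tape.gradients(z)[x]` recovers ∂z/∂x = y in one backward pass; this is why reverse mode can compute gradients with respect to every parameter at roughly the cost of one extra forward pass.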
The `Network` type uses a builder pattern for constructing networks:

```rust
let network = Network::new(28 * 28)          // input size (784 for MNIST)
    .push_hidden_layer(32, sigmoid_fn())     // hidden layer with 32 neurons
    .push_output_layer(10, linear_fn())      // output layer with 10 classes
    .map_output(Softmax);                    // apply softmax to the output
```

Available activation functions:
- `sigmoid_fn()`: Sigmoid activation
- `relu_fn()`: Rectified Linear Unit
- `leaky_relu()`: Leaky ReLU
- `linear_fn()`: Linear (identity) activation
- `softplus()`: Softplus activation
- `elu(alpha)`: Exponential Linear Unit
- `gaussian_fn()`: Gaussian activation
- `silu()`: Sigmoid Linear Unit (SiLU/Swish)
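For reference, a couple of these have simple closed forms; the following is illustrative math, not the crate's implementation:

```rust
// Sigmoid and its derivative (used during backpropagation), plus SiLU,
// which is defined in terms of sigmoid. Illustrative only.
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn sigmoid_derivative(x: f64) -> f64 {
    let s = sigmoid(x);
    s * (1.0 - s) // d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
}

fn silu(x: f64) -> f64 {
    x * sigmoid(x) // SiLU/Swish: x * sigmoid(x)
}
```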
Supported initialization strategies:
- `He`: He initialization
- `Xavier`: Xavier/Glorot initialization
- `Standard`: Standard normal distribution
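The usual scale factors behind these strategies are sketched below; the crate's exact constants may differ:

```rust
// He: variance 2/fan_in, suited to ReLU-family activations.
fn he_std(fan_in: usize) -> f64 {
    (2.0 / fan_in as f64).sqrt()
}

// Xavier/Glorot: variance 2/(fan_in + fan_out), suited to sigmoid/tanh.
fn xavier_std(fan_in: usize, fan_out: usize) -> f64 {
    (2.0 / (fan_in + fan_out) as f64).sqrt()
}

// `Standard` draws weights from N(0, 1) with no fan-based scaling.
```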
Output transformations:

- `Softmax`: Softmax normalization for multi-class classification
- `Linear`: No transformation (for regression)
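Softmax is conventionally computed in the numerically stable form below (illustrative, not the crate's code):

```rust
// softmax(x)_i = exp(x_i) / sum_j exp(x_j); subtracting the maximum
// logit first avoids overflow without changing the result.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```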
Each training iteration performs the following steps (see the sketch after this list):

- Forward Pass: Compute predictions for a batch of inputs
- Loss Calculation: Compute cross-entropy loss between predictions and targets
- Backward Pass: Use automatic differentiation to compute gradients
- Parameter Update: Apply gradients using gradient descent with the specified learning rate
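The following self-contained sketch shows one such step for a single dense layer with softmax and cross-entropy, using plain `Vec`s instead of the crate's `Network` and AD types. It relies on the identity that with softmax outputs `p` and a one-hot target `y`, the gradient of the loss with respect to the logits is `p - y`:

```rust
// One SGD step on a single dense layer (illustrative, not the crate's code).
fn train_step(
    w: &mut Vec<Vec<f64>>, // weights[out][in]
    b: &mut Vec<f64>,      // biases[out]
    x: &[f64],             // one input sample
    y: &[f64],             // one-hot target
    lr: f64,               // learning rate
) -> f64 {
    // Forward pass: logits = W x + b, then numerically stable softmax.
    let logits: Vec<f64> = w
        .iter()
        .zip(b.iter())
        .map(|(row, &bi)| row.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f64>() + bi)
        .collect();
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    let p: Vec<f64> = exps.iter().map(|&e| e / sum).collect();

    // Cross-entropy loss: -sum_i y_i * ln(p_i).
    let loss: f64 = y.iter().zip(&p).map(|(&yi, &pi)| -yi * pi.ln()).sum();

    // Backward pass + gradient descent update: dL/dlogit_i = p_i - y_i.
    for i in 0..w.len() {
        let g = p[i] - y[i];
        for j in 0..x.len() {
            w[i][j] -= lr * g * x[j];
        }
        b[i] -= lr * g;
    }
    loss
}
```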
- Datasets are loaded via Python's `datasets` library
- The first load attempts to use cached serialized data
- If the cache is missing, data is loaded from Python and cached for future use (see the sketch below)
- Supports efficient zero-copy operations when possible
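A sketch of that cache-or-load flow is below; `Split` and `load_via_python` are hypothetical stand-ins for the project's actual types and PyO3 glue:

```rust
use std::path::Path;
use serde::{Deserialize, Serialize};

// Hypothetical dataset container; the real type differs.
#[derive(Serialize, Deserialize)]
struct Split {
    images: Vec<Vec<u8>>,
    labels: Vec<u8>,
}

// Hypothetical: fetch the split through PyO3 and the `datasets` library.
fn load_via_python() -> Result<Split, Box<dyn std::error::Error>> {
    unimplemented!()
}

fn load_split(cache: &Path) -> Result<Split, Box<dyn std::error::Error>> {
    if cache.exists() {
        // Fast path: deserialize the cached binary blob.
        Ok(postcard::from_bytes(&std::fs::read(cache)?)?)
    } else {
        // Slow path: go through Python once, then cache the result.
        let split = load_via_python()?;
        std::fs::write(cache, postcard::to_allocvec(&split)?)?;
        Ok(split)
    }
}
```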
Network parameters are serialized using `postcard`, a compact binary format:
- Efficient storage of weights and biases
- Fast loading and saving
- Version-independent format
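The round-trip with postcard reduces to two calls, `to_allocvec` and `from_bytes` (assuming postcard's `alloc` feature); the parameter struct here is a hypothetical stand-in:

```rust
use serde::{Deserialize, Serialize};

// Hypothetical parameter container; the real type differs, but any
// `Serialize + Deserialize` struct round-trips the same way.
#[derive(Serialize, Deserialize)]
struct Params {
    weights: Vec<Vec<f64>>,
    biases: Vec<Vec<f64>>,
}

fn save(params: &Params, path: &std::path::Path) -> Result<(), Box<dyn std::error::Error>> {
    let bytes = postcard::to_allocvec(params)?; // compact, non-self-describing encoding
    std::fs::write(path, bytes)?;
    Ok(())
}

fn load(path: &std::path::Path) -> Result<Params, Box<dyn std::error::Error>> {
    Ok(postcard::from_bytes(&std::fs::read(path)?)?)
}
```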
Run the test suite with:

```sh
cargo test
```

The project also includes a `heavy` profile for maximum optimization:

```sh
cargo build --profile heavy
```

This profile enables:
- Link-time optimization (LTO)
- Maximum optimization level
- Stripped symbols
- Abort on panic
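A representative definition of such a profile in `Cargo.toml` (the project's actual settings may differ):

```toml
[profile.heavy]
inherits = "release"
lto = true          # link-time optimization
opt-level = 3       # maximum optimization level
strip = "symbols"   # strip symbols from the binary
panic = "abort"     # abort on panic instead of unwinding
```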
The project uses `tracing` for structured logging:

- Progress bars via `indicatif`
- CSV output files for loss and predictions
- Configurable log levels via the `RUST_LOG` environment variable
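A typical initialization that honors `RUST_LOG` looks like the following sketch (assuming `tracing-subscriber` with its `env-filter` feature; the project's setup may differ):

```rust
use tracing_subscriber::EnvFilter;

// Install a global subscriber whose level filter is read from RUST_LOG,
// e.g. `RUST_LOG=debug cargo run --release -- ...`.
fn init_logging() {
    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .init();
}
```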
- `ndarray`: N-dimensional arrays and linear algebra
- `ndarray-linalg`: Linear algebra operations (BLAS/LAPACK)
- `ndarray-rand`: Random array generation
- `num-traits`: Numeric traits
- `smallvec`: Stack-allocated small vectors
- `object-pool`: Object pooling for efficient memory management
- `pyo3`: Rust-Python bindings
- `numpy`: NumPy array support
- `clap`: Command-line argument parsing
- `serde`: Serialization framework
- `postcard`: Compact binary serialization
- `indicatif`: Progress bars
- `tracing`: Structured logging
- Uses Intel MKL for optimized BLAS operations
- Efficient memory layout with `ndarray`
- Batch processing for better cache locality
- Object pooling in the AD system to reduce allocations
- `SmallVec` for stack-allocated small collections
- Currently focused on fully-connected (dense) layers
- Single-threaded training (no parallel batch processing)
- Fixed batch size during training
- Limited to MNIST dataset structure (though extensible)
Potential enhancements:
- Convolutional layers
- Recurrent layers (RNN, LSTM, GRU)
- Multi-threaded training
- GPU acceleration
- More optimizers (Adam, RMSprop, etc.)
- Additional loss functions
The automatic differentiation implementation is based on concepts from the easy-ml library and implements reverse-mode automatic differentiation using Wengert lists (computational graphs).