nabla-ml/nabla

Nabla is a Python library that provides three key features:

  • Multidimensional Array computation (like NumPy) with strong GPU acceleration
  • Composable Function Transformations: vmap, grad, jit, and other Automatic Differentiation tools
  • Deep integration with MAX and (custom) Mojo kernels

For tutorials and API reference, visit: nablaml.com

Installation

Now available on PyPI!

pip install nabla-ml

Quick Start

import nabla as nb

# Example function using Nabla's array operations
# (the parameter is named x to avoid shadowing Python's built-in input())
def foo(x):
    return nb.sum(x * x, axes=0)

# Vectorize, differentiate, accelerate
foo_grads = nb.jit(nb.vmap(nb.grad(foo)))
gradients = foo_grads(nb.randn((10, 5)))
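To make the composition above concrete, here is a plain-Python sketch of what `vmap(grad(foo))` should produce for this particular `foo`, assuming `grad` differentiates with respect to the input and `vmap` maps over the leading axis. Under those assumptions each row `x` yields the gradient of `sum(x * x)`, which is `2 * x`. The sketch does not use Nabla: the helpers `foo_row`, `numerical_grad`, and `batched_grad` are hypothetical names, and a central finite difference stands in for automatic differentiation.

```python
import random


def foo_row(x):
    """Plain-Python analogue of foo for a single row: sum of squares."""
    return sum(v * v for v in x)


def numerical_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar-valued f (stand-in for grad)."""
    grads = []
    for i in range(len(x)):
        upper = list(x)
        lower = list(x)
        upper[i] += eps
        lower[i] -= eps
        grads.append((f(upper) - f(lower)) / (2 * eps))
    return grads


def batched_grad(f, batch):
    """Apply the per-row gradient to every row (stand-in for vmap over axis 0)."""
    return [numerical_grad(f, row) for row in batch]


random.seed(0)
batch = [[random.gauss(0.0, 1.0) for _ in range(5)] for _ in range(10)]
gradients = batched_grad(foo_row, batch)

# For a sum of squares, the analytic gradient of each row is 2 * x.
max_error = max(
    abs(g - 2 * v)
    for g_row, row in zip(gradients, batch)
    for g, v in zip(g_row, row)
)
print(len(gradients), len(gradients[0]), max_error < 1e-6)
```

The result has the same (10, 5) shape as the input batch, which is the shape one would expect `gradients` to have in the Quick Start snippet.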

Development Setup and Reproducibility

This guide is for contributors and for anyone who wants to reproduce the validation and benchmark results.

1. Initial Setup

First, clone the repository and set up a virtual environment with all necessary dependencies.

# Clone the repository
git clone https://github.com/nabla-ml/nabla.git
cd nabla

# Create and activate a virtual environment (recommended)
python3 -m venv venv
source venv/bin/activate

# Install all core and development dependencies
pip install -r requirements-dev.txt
pip install -e ".[dev]"
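After the install commands above, a quick check from Python can confirm that the package is visible in the active virtual environment. This is a minimal sketch using only the standard library. It assumes the distribution is named `nabla-ml` and the import name is `nabla`, as the install command and Quick Start suggest. The helper `describe_install` is a hypothetical name, not part of Nabla.

```python
from importlib import metadata, util


def describe_install(dist_name="nabla-ml", module_name="nabla"):
    """Report whether the distribution and its import module are visible."""
    try:
        version = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        version = None  # distribution metadata not found in this environment
    importable = util.find_spec(module_name) is not None
    return {"version": version, "importable": importable}


print(describe_install())
```

If `importable` is `False`, the most likely cause is that the virtual environment was not activated before running `pip`.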

2. Run the Correctness Validation Suite

This runs the full test suite to verify Nabla's correctness against JAX.

# Navigate to the unit test directory from the project root
cd nabla/tests/unit

# Execute the unified test script
python unified.py all -all-configs

3. Run the Performance Benchmarks

This script reproduces the performance benchmarks for Nabla, JAX, and PyTorch.

# Navigate to the benchmark directory from the project root
cd nabla/tests/benchmarks

# Run the benchmark script
python benchmark1.py
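For readers who want to time their own functions alongside the benchmark script, the sketch below shows one common pattern: warm-up calls followed by a median over repeated runs. It is not taken from `benchmark1.py`, whose internals are not described here. `time_callable` and `sum_of_squares` are hypothetical names, and the workload is a pure-Python placeholder.

```python
import statistics
import time


def time_callable(fn, *args, warmup=3, repeats=10):
    """Return the median wall-clock time of fn(*args) in seconds."""
    for _ in range(warmup):
        fn(*args)  # warm-up calls absorb one-off costs such as JIT compilation
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def sum_of_squares(values):
    """Placeholder workload; swap in the function under test."""
    return sum(v * v for v in values)


median_seconds = time_callable(sum_of_squares, list(range(10_000)))
print(f"median: {median_seconds * 1e6:.1f} us")
```

Warm-up matters for compiled code because the first call typically includes compilation time. Frameworks that dispatch work asynchronously, as JAX does, also need the result synchronized inside the timed region (for example with `block_until_ready()`), otherwise only the dispatch is measured.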

Repository Structure

nabla/
├── nabla/                     # Core Python library
│   ├── core/                  # Array class and MAX compiler integration
│   ├── nn/                    # Neural network modules and models
│   ├── ops/                   # Mathematical operations (binary, unary, linalg, etc.)
│   ├── transforms/            # Function transformations (vmap, grad, jit, etc.)
│   └── utils/                 # Utilities (formatting, types, MAX-interop, etc.)
├── tests/                     # Comprehensive test suite
├── tutorials/                 # Notebooks on Nabla usage for ML tasks
├── examples/                  # Example scripts for common use cases
└── experimental/              # Core (pure) Mojo library (WIP!)

Contributing

Contributions welcome! Discuss significant changes in Issues first. Submit PRs for bugs, docs, and smaller features.

License

Nabla is licensed under the Apache-2.0 license.


Thank you for checking out Nabla!

Requirements: Python 3.12+