Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support) and ease of use. Try our online demos: whisper, LLaMA2, T5, yolo, Segment Anything.
Make sure that you have candle-core correctly installed as described in Installation.
Let's see how to run a simple matrix multiplication.
Write the following to your myapp/src/main.rs file:
use candle_core::{Device, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;

    let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
    let b = Tensor::randn(0f32, 1., (3, 4), &device)?;

    let c = a.matmul(&b)?;
    println!("{c}");
    Ok(())
}
Running cargo run should display a tensor of shape Tensor[[2, 4], f32].
Having installed candle with CUDA support, simply define the device to be on GPU:
- let device = Device::Cpu;
+ let device = Device::new_cuda(0)?;
For more advanced examples, please have a look at the following section.
These online demos run entirely in your browser:
- yolo: pose estimation and object recognition.
- whisper: speech recognition.
- LLaMA2: text generation.
- T5: text generation.
- Phi-1.5, and Phi-2: text generation.
- Segment Anything Model: Image segmentation.
- BLIP: image captioning.
We also provide some command-line examples using state-of-the-art models:
- LLaMA and LLaMA-v2: general LLM, includes the SOLAR-10.7B variant.
- Falcon: general LLM.
- Gemma: 2b and 7b general LLMs from Google Deepmind.
- Phi-1, Phi-1.5, and Phi-2: 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- StableLM-3B-4E1T: a 3b general LLM pre-trained on 1T tokens of English and code datasets. Also supports StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
- Mamba: an inference only implementation of the Mamba state space model.
- Mistral7b-v0.1: a 7b general LLM with better performance than all publicly available 13b models as of 2023-09-28.
- Mixtral8x7b-v0.1: a sparse mixture of experts 8x7b general LLM with better performance than a Llama 2 70B model with much faster inference.
- StarCoder and StarCoder2: LLMs specialized in code generation.
- Qwen1.5: Bilingual (English/Chinese) LLMs.
- RWKV v5 and v6: an RNN with transformer-level LLM performance.
- Replit-code-v1.5: a 3.3b LLM specialized for code completion.
- Yi-6B / Yi-34B: two bilingual (English/Chinese) general LLMs with 6b and 34b parameters.
- Quantized LLaMA: quantized version of the LLaMA model using the same quantization techniques as llama.cpp.
- Stable Diffusion: text to image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.
- Wuerstchen: another text to image generative model.
- segment-anything: image segmentation model with prompt.
- SegFormer: transformer-based semantic segmentation model.
- Whisper: speech recognition model.
- EnCodec: high-quality audio compression model using residual vector quantization.
- MetaVoice: foundational model for text-to-speech.
- T5, Bert, JinaBert: useful for sentence embeddings.
- DINOv2: computer vision model trained using self-supervision (can be used for imagenet classification, depth evaluation, segmentation).
- VGG, RepVGG: computer vision models.
- BLIP: image to text model, can be used to generate captions for an image.
- TrOCR: a transformer OCR model, with dedicated submodels for handwriting and printed-text recognition.
- Marian-MT: neural machine translation model, generates the translated text from the input text.
Run them using commands like:
cargo run --example quantized --release
In order to use CUDA, add --features cuda to the example command line. If you have cuDNN installed, use --features cudnn for even more speedups.
There are also some wasm examples for whisper and llama2.c. You can either build them with trunk or try them online: whisper, llama2, T5, Phi-1.5 and Phi-2, Segment Anything Model.
For LLaMA2, run the following commands to retrieve the weight files and start a test server:
cd candle-wasm-examples/llama2-c
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
trunk serve --release --port 8081
And then head over to http://localhost:8081/.
- candle-tutorial: A very detailed tutorial showing how to convert a PyTorch model to Candle.
- candle-lora: Efficient and ergonomic LoRA implementation for Candle. candle-lora has out-of-the-box LoRA support for many models from Candle, which can be found here.
- optimisers: A collection of optimisers including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- candle-vllm: Efficient platform for inference and serving local LLMs including an OpenAI-compatible API server.
- candle-ext: An extension library to Candle that provides PyTorch functions not currently available in Candle.
- kalosm: A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- candle-sampling: Sampling techniques for Candle.
- gpt-from-scratch-rs: A port of Andrej Karpathy's Let's build GPT tutorial on YouTube showcasing the Candle API on a toy problem.
- candle-einops: A pure Rust implementation of the Python einops library.
If you have an addition to this list, please submit a pull request.
- Simple syntax, looks and feels like PyTorch.
    - Model training.
    - Embed user-defined ops/kernels, such as flash-attention v2.
- Backends.
    - Optimized CPU backend with optional MKL support for x86 and Accelerate for Macs.
    - CUDA backend for efficiently running on GPUs, multi-GPU distribution via NCCL.
    - WASM support, run your models in a browser.
- Included models.
    - Language Models.
        - LLaMA v1 and v2 with variants such as SOLAR-10.7B.
        - Falcon.
        - StarCoder, StarCoder2.
        - Phi 1, 1.5, and 2.
        - Mamba, Minimal Mamba.
        - Gemma 2b and 7b.
        - Mistral 7b v0.1.
        - Mixtral 8x7b v0.1.
        - StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
        - Replit-code-v1.5-3B.
        - Bert.
        - Yi-6B and Yi-34B.
        - Qwen1.5.
        - RWKV v5 and v6.
    - Quantized LLMs.
        - Llama 7b, 13b, 70b, as well as the chat and code variants.
        - Mistral 7b and 7b instruct.
        - Mixtral 8x7b.
        - Zephyr 7b a and b (Mistral-7b based).
        - OpenChat 3.5 (Mistral-7b based).
    - Text to text.
        - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (grammar correction).
        - Marian MT (machine translation).
    - Text to image.
        - Stable Diffusion v1.5, v2.1, XL v1.0.
        - Wurstchen v2.
    - Image to text.
        - BLIP.
        - TrOCR.
    - Audio.
        - Whisper, multi-lingual speech-to-text.
        - EnCodec, audio compression model.
        - MetaVoice-1B, text-to-speech model.
    - Computer Vision Models.
        - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT, ConvNeXTv2, MobileOne, EfficientVit (MSRA).
        - yolo-v3, yolo-v8.
        - Segment-Anything Model (SAM).
        - SegFormer.
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types.
Cheatsheet:
| | Using PyTorch | Using Candle |
|---|---|---|
| Creation | torch.Tensor([[1, 2], [3, 4]]) | Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)? |
| Creation | torch.zeros((2, 2)) | Tensor::zeros((2, 2), DType::F32, &Device::Cpu)? |
| Indexing | tensor[:, :4] | tensor.i((.., ..4))? |
| Operations | tensor.view((2, 2)) | tensor.reshape((2, 2))? |
| Operations | a.matmul(b) | a.matmul(&b)? |
| Arithmetic | a + b | (&a + &b)? |
| Device | tensor.to(device="cuda") | tensor.to_device(&Device::new_cuda(0)?)? |
| Dtype | tensor.to(dtype=torch.float16) | tensor.to_dtype(DType::F16)? |
| Saving | torch.save({"A": A}, "model.bin") | candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")? |
| Loading | weights = torch.load("model.bin") | candle::safetensors::load("model.safetensors", &device)? |
- candle-core: Core ops, devices, and Tensor struct definition.
- candle-nn: Tools to build real models.
- candle-examples: Examples of using the library in realistic settings
- candle-kernels: CUDA custom kernels
- candle-datasets: Datasets and data loaders.
- candle-transformers: transformers-related utilities.
- candle-flash-attn: Flash attention v2 layer.
- candle-onnx: ONNX model evaluation.
Candle's core goal is to make serverless inference possible. Full machine learning frameworks like PyTorch are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight binaries.
Secondly, Candle lets you remove Python from production workloads. Python overhead can seriously hurt performance, and the GIL is a notorious source of headaches.
Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like safetensors and tokenizers.
- dfdx is a formidable crate, with shapes being included in types. This prevents a lot of headaches by getting the compiler to complain about shape mismatches right off the bat. However, we found that some features still require nightly, and writing code can be a bit daunting for non-Rust experts.
  We're leveraging and contributing to other core crates for the runtime, so hopefully both crates can benefit from each other.
- burn is a general crate that can leverage multiple backends, so you can choose the best engine for your workload.
- tch-rs: bindings to the torch library in Rust. Extremely versatile, but they bring the entire torch library into the runtime. The main contributor of tch-rs is also involved in the development of candle.
If you get some missing symbols when compiling binaries/tests using the mkl or accelerate features, e.g. for mkl you get:
= note: /usr/bin/ld: (....o): in function `blas::sgemm':
.../blas-0.22.0/src/lib.rs:1944: undefined reference to `sgemm_'
collect2: error: ld returned 1 exit status
= note: some `extern` functions couldn't be found; some native libraries may need to be installed or have their path specified
= note: use the `-l` flag to specify native libraries to link
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo
or for accelerate:
Undefined symbols for architecture arm64:
"_dgemm_", referenced from:
candle_core::accelerate::dgemm::h1b71a038552bcabe in libcandle_core...
"_sgemm_", referenced from:
candle_core::accelerate::sgemm::h2cf21c592cba3c47 in libcandle_core...
ld: symbol(s) not found for architecture arm64
This is likely due to a missing linker flag that was needed to enable the mkl library. You can try adding the following for mkl at the top of your binary:
extern crate intel_mkl_src;
or for accelerate:
extern crate accelerate_src;
Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401
This is likely because you do not have access to the LLaMA-v2 model. To fix this, you have to register on the Hugging Face Hub, accept the LLaMA-v2 model conditions, and set up your authentication token. See issue #350 for more details.
In file included from kernels/flash_fwd_launch_template.h:11:0,
from kernels/flash_fwd_hdim224_fp16_sm80.cu:5:
kernels/flash_fwd_kernel.h:8:10: fatal error: cute/algorithm/copy.hpp: No such file or directory
#include <cute/algorithm/copy.hpp>
^~~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
Error: nvcc error while compiling:
cutlass is provided as a git submodule, so you may want to run the following command to check it out properly:
git submodule update --init
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
This is a bug in gcc-11 triggered by the CUDA compiler. To fix this, install a different, supported gcc version (for example gcc-10) and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
Couldn't compile the test.
---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
error: linking with `link.exe` failed: exit code: 1181
//very long chain of linking
= note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'
Make sure you link all native libraries that might be located outside the project's target directory. For example, to run the mdbook tests, you should run:
mdbook test candle-book -L .\target\debug\deps\ `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
If model loading is extremely slow under WSL, this may be caused by the models being loaded from /mnt/c; more details on stackoverflow.
You can set RUST_BACKTRACE=1 to be provided with backtraces when a candle error is generated.