1 change: 1 addition & 0 deletions .gitignore
@@ -1 +1,2 @@
/target
/*.txt
13 changes: 7 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default.

5 changes: 3 additions & 2 deletions llama.cu/Cargo.toml
@@ -8,9 +8,9 @@ cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" }
cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" }
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" }
flash-attn = { git = "https://github.com/YdrMaster/learn-flash-attn", rev = "616bbac" }
nn = { git = "https://github.com/YdrMaster/InfiniNN", rev = "fa8aaf6" }
nn = { git = "https://github.com/CearX/InfiniNN", rev = "6caef2"}
ggus = { git = "https://github.com/InfiniTensor/gguf", rev = "23c362f" }
tokeneer = { git = "https://github.com/InfiniTensor/tokeneer", rev = "c48f39f" }
tokeneer = { git = "https://github.com/CearX/tokeneer.git", rev = "2546d72" }

bytesize = "2.0"
log.workspace = true
@@ -19,6 +19,7 @@ serde.workspace = true
memmap2 = "0.9"
lru = "0.14"
rand = "0.9"
half = "2.3"
minijinja = { version = "2.11", default-features = false, features = [
"loader",
"builtins",
134 changes: 133 additions & 1 deletion llama.cu/src/exec/group.rs
@@ -1,8 +1,10 @@
use super::mamba_cache::MambaCache;
use super::{CacheParts, Progress, model::ModelExec, upos};
use crate::{batch::Req, handle::Handle, load::load_weight, memory::MemPages};
use cuda::{DevByte, DevMem, Stream, VirByte};
use nn::{
-    Distribution, Graph, GraphBuilder, LLaMA, NNGraph, Tensor, TensorMeta, digit_layout::types, op,
+    Distribution, Graph, GraphBuilder, LLaMA, Mamba, NNGraph, Tensor, TensorMeta,
+    digit_layout::types, op,
};
use std::{
collections::BTreeMap,
@@ -240,3 +242,133 @@ fn builder() -> GraphBuilder {
.register_op("all-reduce", op::all_reduce::AllReduce);
ans
}

// Mamba GraphBuilder
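// Registers the ops a Mamba block lowers to: alongside the shared
// embedding / rms-norm / linear / activation ops, Mamba needs the
// causal-conv1d and selective-scan kernels, and registers no attention op.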
fn builder_mamba() -> GraphBuilder {
let mut ans = GraphBuilder::default();
ans.register_op("embedding", op::embedding::Embedding)
.register_op("rms-norm", op::normalization::RmsNorm)
.register_op("linear", op::linear::Linear)
.register_op("silu", op::activation::SiLU)
.register_op("element-mul", op::element_mul::ElementMul)
.register_op("split", op::split::Split)
.register_op("mamba-causal-conv1d", op::mamba::CausalConv1d)
.register_op("mamba-selective-scan", op::mamba::SelectiveScan);
ans
}

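/// Execution group for a Mamba model. `internal` caches compiled executors
/// per input length, `_weight` keeps the device weight buffer alive, and
/// `next_pos` carries the sequence position from prefill into decode steps.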
pub(crate) struct ModelGroupMamba<'ctx> {
internal: Internal<'ctx>,
pages: MemPages,
_weight: DevMem<'ctx>,
next_pos: u32,
}

impl<'ctx> ModelGroupMamba<'ctx> {
pub fn new<T: IntoIterator<Item = usize>>(
mamba: Mamba<Tensor<&[u8], 2>>,
dist: Distribution,
progress: Option<Arc<Progress>>,
config: ModelGroupConfig<T>,
handle: &mut Handle<'ctx>,
barrier: Option<&Barrier>,
) -> Self {
let ModelGroupConfig {
static_model_keys,
mut dyn_cache_size,
use_cuda_graph,
} = config;

let NNGraph(Graph { topo, nodes, edges }) = builder_mamba()
.build(
mamba.tensor_parallel(dist),
[
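// three U32 graph inputs, each of length n_tok:
// token ids, positions, and output row indices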
TensorMeta::new(types::U32, ["n_tok".into()]),
TensorMeta::new(types::U32, ["n_tok".into()]),
TensorMeta::new(types::U32, ["n_tok".into()]),
],
)
.unwrap();
handle.ctx.stream().synchronize();

let dev = handle.ctx.dev();
let mut pages = MemPages::new(dev);
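// Upload the weights to device memory; `edges` is rewritten to reference
// the device copies, which `_weight` keeps alive for the group's lifetime.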
let (_weight, edges) = load_weight(edges, progress, handle.ctx);
let graph = NNGraph(Graph { topo, nodes, edges });
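// With CUDA graphs enabled, pre-compile one executor per static token
// count; otherwise those sizes fall through to the dynamic executor
// cache, whose budget is `dyn_cache_size`.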
let static_models = if use_cuda_graph {
static_model_keys
.into_iter()
.map(|n_tok| {
if let Some(b) = barrier {
b.wait();
}
let key = NonZeroUsize::new(n_tok).unwrap();
let exec = ModelExec::new(graph.clone(), n_tok, handle, &mut pages, true);
(key, exec)
})
.collect::<BTreeMap<_, _>>()
} else {
dyn_cache_size += static_model_keys.into_iter().count();
Default::default()
};

let internal = Internal::new(graph, static_models, dyn_cache_size);
Self {
internal,
pages,
_weight,
next_pos: 0,
}
}

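/// Prefill: uploads the prompt tokens, positions `0..len`, and one output
/// index per token, then records `next_pos` for the decode steps that follow.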
pub fn load_inputs_mamba_prefill(
&mut self,
handle: &mut Handle<'ctx>,
len: usize,
tok: &[utok],
stream: &Stream<'ctx>,
) -> (NonZeroUsize, &mut [DevByte]) {
let key = self.internal.get_key(NonZeroUsize::new(len).unwrap());
let model = self.internal.map_exec(key, handle, &mut self.pages, stream);
stream.memcpy_h2d(model.tok_buf(), &tok[..key.get()]);
let pos: Vec<upos> = (0..key.get()).map(|i| i as upos).collect();
stream.memcpy_h2d(model.pos_buf(), &pos);
self.next_pos = key.get() as u32;
let out_idx: Vec<utok> = (0..key.get()).map(|i| i as utok).collect();
let buf = model.input_buf_at(2);
stream.memcpy_h2d(buf, &out_idx);
(key, model.tok_buf())
}

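/// Decode: uploads a single token at position `next_pos` and advances it by
/// one; only one output row is produced per step.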
pub fn load_input_mamba_decode(
&mut self,
handle: &mut Handle<'ctx>,
tok: utok,
stream: &Stream<'ctx>,
) -> (NonZeroUsize, &mut [DevByte]) {
let key = self.internal.get_key(NonZeroUsize::new(1).unwrap());
let model = self.internal.map_exec(key, handle, &mut self.pages, stream);
let tok_buf = model.tok_buf();
stream.memcpy_h2d(tok_buf, &[tok]);
let pos_buf = model.pos_buf();
let cur = self.next_pos;
stream.memcpy_h2d(pos_buf, &[cur]);
// advance next_pos for the next decode step
self.next_pos = cur.saturating_add(1);
// out_idx is fixed at 0 during decode
let out_idx_buf = model.input_buf_at(2);
stream.memcpy_h2d(out_idx_buf, &[0u32]);
(key, model.tok_buf())
}

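/// Runs the executor selected by `key`, threading the per-sequence Mamba
/// cache (its conv/SSM state) through the launch.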
pub fn launch_mamba(
&mut self,
key: NonZeroUsize,
cache: &mut MambaCache,
handle: &mut Handle,
stream: &Stream<'ctx>,
) -> Tensor<*const VirByte, 2> {
let model = self.internal.get_mut(&key).unwrap();
model.launch_with_mamba_cache(handle, cache, stream)
}
}
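
A minimal decode-loop sketch for the API added above (not part of this PR). It assumes `mamba`, `dist`, `config`, `handle`, `stream`, and `prompt` are already set up; `sample_argmax` and the `MambaCache` constructor arguments are hypothetical placeholders.

// Prefill the whole prompt in one launch; next_pos becomes prompt.len().
let mut group = ModelGroupMamba::new(mamba, dist, None, config, &mut handle, None);
let mut cache = MambaCache::new(/* model dims elided */);
let (key, _) = group.load_inputs_mamba_prefill(&mut handle, prompt.len(), &prompt, &stream);
let out = group.launch_mamba(key, &mut cache, &mut handle, &stream);
let mut tok = sample_argmax(out); // hypothetical sampler over the output tensor

// Then decode token by token; positions advance internally via next_pos.
for _ in 0..max_new_tokens {
    let (key, _) = group.load_input_mamba_decode(&mut handle, tok, &stream);
    let out = group.launch_mamba(key, &mut cache, &mut handle, &stream);
    tok = sample_argmax(out);
}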