Commit 1719847
Update linting configurations and improve code formatting in deepseek_v32 example scripts
- Added per-file ignores for the inference directory in `pyproject.toml`.
- Refactored code in `topk_selector.py`, `convert.py`, `generate.py`, `kernel.py`, and `model.py` to enhance readability by adjusting spacing and line breaks.
- Ensured consistent formatting across function definitions and assertions for better clarity.
1 parent 453d442 commit 1719847

6 files changed: 193 additions, 122 deletions

examples/deepseek_v32/inference/convert.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,6 @@
 import torch
 from safetensors.torch import safe_open, save_file

-
 mapping = {
     "embed_tokens": ("embed", 0),
     "input_layernorm": ("attn_norm", None),
@@ -74,7 +73,8 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
                     if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
                         continue
                 elif dim is not None:
-                    assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
+                    assert param.size(
+                        dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
                     shard_size = param.size(dim) // mp
                     new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
                 state_dicts[i][name] = new_param
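For context, the reflowed assertion guards the tensor-parallel split performed right after it. A minimal, self-contained sketch of that sharding step, with illustrative shapes that are not taken from the repository:

import torch

# Hypothetical values for illustration only.
mp = 4                         # number of model-parallel shards
dim = 1                        # dimension being split across shards
param = torch.randn(8, 1024)   # a weight tensor to shard

assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
shard_size = param.size(dim) // mp   # 1024 / 4 = 256 columns per shard
shards = [param.narrow(dim, i * shard_size, shard_size).contiguous() for i in range(mp)]
assert all(s.shape == (8, 256) for s in shards)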

examples/deepseek_v32/inference/generate.py

Lines changed: 25 additions & 14 deletions
@@ -28,13 +28,11 @@ def sample(logits, temperature: float = 1.0):


 @torch.inference_mode()
-def generate(
-    model: Transformer,
-    prompt_tokens: List[List[int]],
-    max_new_tokens: int,
-    eos_id: int,
-    temperature: float = 1.0
-) -> List[List[int]]:
+def generate(model: Transformer,
+             prompt_tokens: List[List[int]],
+             max_new_tokens: int,
+             eos_id: int,
+             temperature: float = 1.0) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.

@@ -49,7 +47,9 @@ def generate(
         List[List[int]]: A list of lists containing the generated tokens for each sequence.
     """
     prompt_lens = [len(t) for t in prompt_tokens]
-    assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
+    assert max(
+        prompt_lens
+    ) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
     tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
     for i, t in enumerate(prompt_tokens):
@@ -71,7 +71,7 @@ def generate(
             break
     completion_tokens = []
     for i, toks in enumerate(tokens.tolist()):
-        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
+        toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
         if eos_id in toks:
             toks = toks[:toks.index(eos_id)]
         completion_tokens.append(toks)
@@ -139,16 +139,26 @@ def main(
                 continue
             messages.append({"role": "user", "content": prompt})
             prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
+            completion_tokens = generate(model, [prompt_tokens], max_new_tokens,
+                                         tokenizer.eos_token_id, temperature)
             completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
             print(completion)
             messages.append({"role": "assistant", "content": completion})
     else:
         with open(input_file) as f:
             prompts = f.read().split("\n\n")
-        assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
-        prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
-        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
+        assert len(
+            prompts
+        ) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
+        prompt_tokens = [
+            tokenizer.apply_chat_template([{
+                "role": "user",
+                "content": prompt
+            }],
+                                          add_generation_prompt=True) for prompt in prompts
+        ]
+        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id,
+                                     temperature)
         completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
         for prompt, completion in zip(prompts, completions):
             print("Prompt:", prompt)
@@ -183,4 +193,5 @@ def main(
     parser.add_argument("--temperature", type=float, default=0.6)
     args = parser.parse_args()
     assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
-    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
+    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens,
+         args.temperature)
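The spacing tweak in the third hunk touches the completion-trimming loop, which is easiest to check in isolation. A self-contained sketch of that loop, using made-up token values rather than real model output:

# Illustrative values only; in generate() the rows come from the padded token matrix.
eos_id = 2
max_new_tokens = 4
prompt_lens = [3]
tokens = [[5, 6, 7, 11, 12, 2, -1]]  # prompt tokens, then generated tokens, then padding

completion_tokens = []
for i, toks in enumerate(tokens):
    toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]  # drop the prompt prefix
    if eos_id in toks:
        toks = toks[:toks.index(eos_id)]  # cut at the first end-of-sequence token
    completion_tokens.append(toks)

assert completion_tokens == [[11, 12]]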

examples/deepseek_v32/inference/kernel.py

Lines changed: 19 additions & 25 deletions
@@ -3,7 +3,6 @@
 import tilelang.language as T
 from typing import Tuple, Optional

-
 tilelang.set_log_level("WARNING")

 pass_configs = {
@@ -34,9 +33,7 @@ def fast_round_scale(amax, fp8_max_inv):


 @tilelang.jit(pass_configs=pass_configs)
-def act_quant_kernel(
-    N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
-):
+def act_quant_kernel(N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False):
     M = T.symbolic("M")
     fp8_min = -448.0
     fp8_max = 448.0
@@ -51,10 +48,11 @@ def act_quant_kernel_(
         Y: T.Tensor[(M, N), out_dtype],
         S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
     ):
-        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
-            pid_m,
-            pid_n,
-        ):
+        with T.Kernel(
+                T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
+                    pid_m,
+                    pid_n,
+                ):
             x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
             x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
             amax_local = T.alloc_fragment((blk_m,), scale_dtype)
@@ -73,9 +71,7 @@ def act_quant_kernel_(
                 else:
                     s_local[i] = amax_local[i] * fp8_max_inv
             for i, j in T.Parallel(blk_m, group_size):
-                y_local[i, j] = T.clamp(
-                    x_local[i, j] / s_local[i], fp8_min, fp8_max
-                )
+                y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], fp8_min, fp8_max)
             for i in T.Parallel(blk_m):
                 S[pid_m * blk_m + i, pid_n] = s_local[i]
             T.copy(y_local, y_shared)
@@ -84,9 +80,9 @@ def act_quant_kernel_(
     return act_quant_kernel_


-def act_quant(
-    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def act_quant(x: torch.Tensor,
+              block_size: int = 128,
+              scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantizes the input tensor `x` using block-wise quantization.

@@ -101,8 +97,7 @@ def act_quant(
     """
     assert x.is_contiguous(), "Input tensor must be contiguous"
     assert x.size(-1) % block_size == 0, (
-        f"Last dimension size must be divisible by block_size (block_size={block_size})"
-    )
+        f"Last dimension size must be divisible by block_size (block_size={block_size})")
     N = x.size(-1)
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
     s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
@@ -129,10 +124,11 @@ def fp8_gemm_kernel_(
         scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32],
         scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32],
     ):
-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
-            bx,
-            by,
-        ):
+        with T.Kernel(
+                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
+                    bx,
+                    by,
+                ):
             A_shared = T.alloc_shared((block_M, block_K), FP8)
             B_shared = T.alloc_shared((block_N, block_K), FP8)
             C_shared = T.alloc_shared((block_M, block_N), out_dtype)
@@ -168,9 +164,8 @@ def fp8_gemm_kernel_(
     return fp8_gemm_kernel_


-def fp8_gemm(
-    a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor
-) -> torch.Tensor:
+def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor,
+             b_s: torch.Tensor) -> torch.Tensor:
     """
     Perform a matrix multiplication using FP8 precision.

@@ -185,8 +180,7 @@ def fp8_gemm(
     """
     assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
     assert a_s.is_contiguous() and b_s.is_contiguous(), (
-        "Scaling factor tensors must be contiguous"
-    )
+        "Scaling factor tensors must be contiguous")
     K = a.size(-1)
     M = a.numel() // K
     N = b.size(0)
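The reflowed signatures leave act_quant and fp8_gemm functionally unchanged. A hedged usage sketch, assuming a CUDA device, the tilelang dependency, the default 128-wide quantization groups, and that the script is run from the inference directory; shapes follow the tensor signatures visible in the hunks above:

import torch
from kernel import act_quant, fp8_gemm  # examples/deepseek_v32/inference/kernel.py

M, K, N = 128, 1024, 512
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")

# Block-wise activation quantization: FP8 values plus one FP32 scale per 128-wide group.
x_fp8, x_scale = act_quant(x, block_size=128)
print(x_fp8.dtype, x_fp8.shape)      # torch.float8_e4m3fn, (128, 1024)
print(x_scale.dtype, x_scale.shape)  # torch.float32, (128, 8)

# Weight side (illustrative): FP8 data with (N/128, K/128) block scales, matching scales_b.
w_fp8 = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
w_scale = torch.ones(N // 128, K // 128, device="cuda", dtype=torch.float32)

out = fp8_gemm(x_fp8, x_scale, w_fp8, w_scale)  # expected shape (M, N)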
