Commit 4e59ef2

stuff

1 parent d591be1 commit 4e59ef2

File tree

4 files changed: +105 -10 lines changed


scripts/run_benchmark_breakdown.sh

Lines changed: 3 additions & 1 deletion

@@ -56,7 +56,9 @@ run () { # run(step, runtime, attn)
     then
         echo "Skipping existing $FILE_NAME"
     else
-        $RUN $COMMON ${RUNTIME[$2]} ${ATTN[$3]} ${STEP[$1]} --save="$FILE_NAME"
+        CMD="$RUN $COMMON ${RUNTIME[$2]} ${ATTN[$3]} ${STEP[$1]} --save=$FILE_NAME"
+        echo "$CMD"
+        $CMD
     fi
 }
 
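The change above logs the exact command line before executing it, so any run can be reproduced by copy-pasting from the output. Note that the quotes around $FILE_NAME are dropped when the command is stored in CMD: quote characters held inside a variable are not re-parsed when the unquoted $CMD expands, so this pattern relies on no argument containing spaces. A minimal Python sketch of the same log-then-run pattern (the subprocess usage is illustrative, not part of the commit):

    import shlex
    import subprocess

    def run_logged(cmd: str) -> None:
        # Print the exact command so it can be re-run from the logs.
        print(cmd)
        # Split into words like the unquoted $CMD expansion does; as in the
        # script, this assumes no argument contains spaces.
        subprocess.run(shlex.split(cmd), check=True)

    run_logged("python3 src/main.py --dtype=float16 --device=cuda")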

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+
+# Santacoder prefill.
+# ./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 5 0
+# Santacoder decode (fewer data points because slower)
+# ./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 11 1
+MODEL_NAME=${1:-"santacoder"}
+MODEL_PATH=${2:-"bigcode/gpt_bigcode-santacoder"}
+BATCH_SIZE=${3:-32}
+MAX_NEW_TOKENS=${4:-2040}
+# Prime number to see key length padding effect.
+TOKEN_STEP=${5:-5}
+STEP_ID=${6:-""}
+FILE_PREFIX=${7:-""}
+CYCLES=${8:-10}
+
+SAVE_DIR=data/benchmarks/v2
+#BATCH_SIZES="1 2 4 8 16 24 32 48 64 96 128 160 224 256"
+RUN="python3 src/main.py --max_log_outputs=0 --dtype=float16 --device=cuda --custom_generate --breakdown_latency --ignore_oom"
+
+
+RUNTIME=("")
+RUNTIME_NAMES=("base")
+
+ATTN=( \
+    "--pipeline_class=TG_Pipeline" \
+)
+ATTN_NAME=( \
+    "textgen" \
+)
+
+
+STEP=("--no_prefill" "--no_cache")
+STEP_NAME=("decode" "prefill")
+
+COMMON="--pretrained_model=$MODEL_PATH --tokenizer=$MODEL_PATH --cycles=$CYCLES --max_input_length=1 --max_new_tokens=$MAX_NEW_TOKENS --key_length_step=$TOKEN_STEP --batch_size=$BATCH_SIZE predict_last_token=True"
+
+run () { # run(step, runtime, attn)
+    FILE_NAME="$SAVE_DIR"/"$MODEL_NAME"_bs_"$BATCH_SIZE"_tok_"$MAX_NEW_TOKENS"_step_"$TOKEN_STEP"_"${STEP_NAME[$1]}"/"$FILE_PREFIX""${RUNTIME_NAMES[$2]}"_"${ATTN_NAME[$3]}".json
+    if [ -f "$FILE_NAME" ];
+    then
+        echo "Skipping existing $FILE_NAME"
+    else
+        CMD="$RUN $COMMON ${RUNTIME[$2]} ${ATTN[$3]} ${STEP[$1]} --save=$FILE_NAME"
+        echo "$CMD"
+        $CMD
+    fi
+}
+
+if [ "${STEP_ID}" -eq "0" ]
+then
+    # Decode (default attn only)
+    run 0 0 0
+else
+    # Prefill
+    run 1 0 0
+fi
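This new script (its path is not shown in the extracted diff) drives src/main.py through parallel arrays: run(step, runtime, attn) indexes STEP, RUNTIME, and ATTN for the flags, and STEP_NAME, RUNTIME_NAMES, and ATTN_NAME for the output path, skipping any configuration whose result file already exists. A rough Python equivalent of the path construction and skip check, with names mirroring the shell variables (a sketch, not part of the commit):

    from pathlib import Path

    SAVE_DIR = Path("data/benchmarks/v2")
    STEP_NAME = ["decode", "prefill"]
    RUNTIME_NAMES = ["base"]
    ATTN_NAME = ["textgen"]

    def result_path(step: int, runtime: int, attn: int, model: str = "santacoder",
                    bs: int = 32, tok: int = 2040, tstep: int = 5, prefix: str = "") -> Path:
        # One directory per benchmark configuration, one JSON file per
        # (runtime, attention) pair, as in FILE_NAME above.
        return (SAVE_DIR
                / f"{model}_bs_{bs}_tok_{tok}_step_{tstep}_{STEP_NAME[step]}"
                / f"{prefix}{RUNTIME_NAMES[runtime]}_{ATTN_NAME[attn]}.json")

    path = result_path(0, 0, 0)  # decode, base runtime, textgen attention
    if path.exists():
        print(f"Skipping existing {path}")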

src/pipeline.py

Lines changed: 41 additions & 7 deletions

@@ -21,6 +21,10 @@
     GPTBigCodeConfig,
 )
 
+from transformers.modeling_outputs import (
+    CausalLMOutputWithCrossAttentions,
+)
+
 
 logger = logging.getLogger(__name__)
 
@@ -413,7 +417,7 @@ def __init__(self, **kwargs):
 
         super().__init__(**kwargs)
 
-        if self.device != torch.device("cuda"):
+        if self.device != torch.device("cuda:0"):
             raise ValueError(f"Deepspeed does not support device {self.device}")
 
         if self.dtype not in (torch.float32, torch.float16, torch.bfloat16):
@@ -433,10 +437,21 @@ def __init__(self, **kwargs):
 
 class TextGenModelWrapper:
     def __init__(self, model):
+        from text_generation_server.models import CausalLM, FlashCausalLM
+
         self.model = model
+        if isinstance(self.model, FlashCausalLM):
+            self._is_flash = True
+        elif isinstance(self.model, CausalLM):
+            self._is_flash = False
+        else:
+            raise NotImplementedError()
 
     def parameters(self):
-        return self.model.parameters()
+        return []
+
+    def eval(self):
+        pass
 
     def __call__(
         self,
@@ -447,16 +462,35 @@ def __call__(
         return_dict,
         use_cache,
     ):
-        return self.model(input_ids, attention_mask, position_ids, past_key_values)
+        if self._is_flash:
+            raise NotImplementedError()
+            logits, past_key_values = self.model.forward(
+                input_ids,
+                position_ids,
+                cu_seqlens,
+                max_s,
+                past_key_values,
+                pre_allocate_past_size,
+            )
+        else:
+            logits, past_key_values = self.model.forward(input_ids, attention_mask, position_ids, past_key_values)
+        return CausalLMOutputWithCrossAttentions(
+            loss=None,
+            logits=logits,
+            past_key_values=past_key_values,
+            hidden_states=None,
+            attentions=None,
+            cross_attentions=None,
+        )
 
 
 class TG_Pipeline(Pipeline):
     def __init__(self, **kwargs):
-        if self.device != torch.device("cuda"):
-            raise ValueError(f"Textgen does not support device {self.device}")
-
         super().__init__(**kwargs)
 
+        if self.device != torch.device("cuda:0"):
+            raise ValueError(f"Textgen does not support device {self.device}")
+
     def _get_config(
         self,
         model_type: Optional[str],
@@ -475,7 +509,7 @@ def _save_pretrained(self, pretrained_model: str):
         raise NotImplementedError()
 
     def _load_pretrained(self, pretrained_model: str):
-        from text_generation_server import get_model
+        from text_generation_server.models import get_model
 
         pretrained_model, revision = parse_revision(pretrained_model)
         return TextGenModelWrapper(get_model(pretrained_model, revision, False, False))
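The TextGenModelWrapper changes make it an interface adapter: text_generation_server models return a bare (logits, past_key_values) tuple from forward(), while the benchmark pipeline expects a transformers-style output object plus parameters() and eval() methods it can call harmlessly. The flash-attention path (FlashCausalLM) is detected but deliberately raises NotImplementedError for now. A self-contained sketch of the repackaging idea, using a dummy backend in place of a real CausalLM (nothing below is from the commit itself):

    import torch
    from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

    class FakeBackend:
        """Stand-in for a text_generation_server CausalLM."""

        def forward(self, input_ids, attention_mask, position_ids, past_key_values):
            batch, seq, vocab = input_ids.shape[0], input_ids.shape[1], 8
            return torch.zeros(batch, seq, vocab), past_key_values

    backend = FakeBackend()
    logits, past = backend.forward(torch.zeros(1, 4, dtype=torch.long), None, None, None)
    # Repackage the tuple so callers can keep using transformers-style fields.
    out = CausalLMOutputWithCrossAttentions(logits=logits, past_key_values=past)
    print(out.logits.shape)  # torch.Size([1, 4, 8])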

src/utils.py

Lines changed: 5 additions & 2 deletions

@@ -149,8 +149,11 @@ def get_inputs_from_tokens(tokens, length, tokenizer):
     raise RuntimeError("Failed to generate stable input sequences")
 
 
-def get_random_inputs(length, tokenizer, random_state):
-    return get_inputs_from_tokens(random_state.randint(0, tokenizer.vocab_size, length).tolist(), length, tokenizer)
+def get_random_inputs(lengths, tokenizer, random_state):
+    return [
+        get_inputs_from_tokens(random_state.randint(0, tokenizer.vocab_size, length).tolist(), length, tokenizer)
+        for length in lengths
+    ]
 
 
 def get_inputs_from_files(files: List[Path], lengths, tokenizer, random_state):
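get_random_inputs now takes a list of lengths and returns one input per requested length, matching the plural signature of the neighboring get_inputs_from_files. A sketch of the corresponding call-site change (tokenizer here is a placeholder; the function only needs random_state to provide numpy-style randint):

    import numpy as np

    random_state = np.random.RandomState(0)

    # Before this commit: a single input of one length.
    # inputs = get_random_inputs(32, tokenizer, random_state)

    # After: one input per length, e.g. for sweeping sequence lengths.
    inputs = get_random_inputs([32, 64, 128], tokenizer, random_state)
    assert len(inputs) == 3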
