
Commit ab1d992

[TorchAcc] cache the compiled results and remove some xla flags (#1160)
1 parent 2956815 commit ab1d992

15 files changed (+55, -25 lines)
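Across the shell scripts the change is uniform: most drop the XLA_FLAGS export, and every script now points XLA_PERSISTENT_CACHE_PATH at a per-model directory (created with mkdir -p) so compiled XLA programs are cached on disk and reused on later runs. A minimal sketch of that shared setup, using a hypothetical MODEL_NAME placeholder instead of any one script's hard-coded path:

#!/bin/bash
# Sketch of the persistent compilation cache setup shared by these scripts.
# MODEL_NAME is a placeholder; each script hard-codes its own model directory.
export USE_TORCHACC=1

MODEL_NAME=my-model
export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/${MODEL_NAME}
mkdir -p "$XLA_PERSISTENT_CACHE_PATH"   # cache directory must exist before training starts

The first run still compiles and populates the cache; subsequent runs with the same model and shapes can reuse the compiled results instead of recompiling.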

examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_dp_sft.sh

Lines changed: 3 additions & 1 deletion
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 # torchacc dp
 export USE_TORCHACC=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select

+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/Baichuan2-13B-Chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
 MASTER_PORT=27829 \

examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_fsdp_sft.sh

Lines changed: 3 additions & 1 deletion
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 # torchacc fsdp
 export USE_TORCHACC=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select

+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/Baichuan2-13B-Chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
 swift sft \

examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_dp_sft.sh

Lines changed: 3 additions & 1 deletion
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 # torchacc dp
 export USE_TORCHACC=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select

+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/chatglm3-6b
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+

 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \

examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_fsdp_sft.sh

Lines changed: 3 additions & 1 deletion
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 # torchacc fsdp
 export USE_TORCHACC=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select

+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/chatglm3-6b
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+

 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \

examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_dp_sft.sh

Lines changed: 4 additions & 2 deletions
@@ -4,11 +4,13 @@

 export USE_TORCHACC=1
 export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select

+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/Llama-2-13b-chat-ms
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
 swift sft \
@@ -20,7 +22,7 @@ swift sft \
     --output_dir output \
     --num_train_epochs 1 \
     --max_length 2048 \
-    --batch_size 16 \
+    --batch_size 14 \
     --use_flash_attn true \
     --gradient_accumulation_steps 1 \
     --gradient_checkpointing no \

examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_fsdp_sft.sh

Lines changed: 4 additions & 2 deletions
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 export USE_TORCHACC=1
 export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select

+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/Llama-2-13b-chat-ms
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
 MASTER_PORT=27829 \
@@ -20,7 +22,7 @@ swift sft \
     --output_dir output \
     --num_train_epochs 1 \
     --max_length 2048 \
-    --batch_size 24 \
+    --batch_size 20 \
     --use_flash_attn true \
     --gradient_accumulation_steps 1 \
     --gradient_checkpointing no \

examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_dp_sft.sh

Lines changed: 3 additions & 2 deletions
@@ -4,11 +4,12 @@

 export USE_TORCHACC=1
 export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
-export XLA_COORDINATOR_PORT=12457
+
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/Meta-Llama-3-8B-Instruct
+mkdir -p $XLA_PERSISTENT_CACHE_PATH

 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \

examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_fsdp_sft.sh

Lines changed: 3 additions & 2 deletions
@@ -3,11 +3,12 @@
 # Note: TorchAcc is currently only available internally.
 export USE_TORCHACC=1
 export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
-# export XLA_COORDINATOR_PORT=12457
+
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/Meta-Llama-3-8B-Instruct
+mkdir -p $XLA_PERSISTENT_CACHE_PATH

 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \

examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_dp_sft.sh

Lines changed: 4 additions & 2 deletions
@@ -2,12 +2,14 @@
 # 80GB GPU memory
 # Note: TorchAcc is currently only available internally.
 export USE_TORCHACC=1
-# export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
+export TORCHACC_TRIM_GRAPH=1
 export XLA_IR_SHAPE_CACHE_SIZE=1000000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select

+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/qwen1half-14b-chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=2,3 \
 MASTER_PORT=23797 \

examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_fsdp_sft.sh

Lines changed: 4 additions & 1 deletion
@@ -5,10 +5,13 @@ DEBUG_PREFIX=qwen15_14b
 DEBUG_PATH=torchacc_debug/qwen15/
 export USE_TORCHACC=1
 # export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=1000000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
+
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/qwen1half-14b-chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 MASTER_PORT=23783 \
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \

examples/pytorch/llm/scripts/torchacc/qwen1half_32b_chat/acc_lora_fsdp_sft.sh

Lines changed: 3 additions & 1 deletion
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 export USE_TORCHACC=1
 # export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=1000000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select

+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/qwen1half-32b-chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 NPROC_PER_NODE=4 \
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 swift sft \

examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_full_fsdp_sft.sh

Lines changed: 2 additions & 1 deletion
@@ -3,10 +3,11 @@
 # Note: TorchAcc is currently only available internally.

 export USE_TORCHACC=1
-export XLA_FLAGS='--xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.97

+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/qwen-72b-chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
 # Note: You need to set the correct MASTER_ADDR, MASTER_PORT and NODE_RANK for each node.

 MASTER_ADDR=127.0.0.1 \

examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_lora_fsdp_sft.sh

Lines changed: 4 additions & 2 deletions
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.

 export USE_TORCHACC=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select

+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/qwen-72b-chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 NPROC_PER_NODE=4 \
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 swift sft \
@@ -18,7 +20,7 @@ swift sft \
     --output_dir output_qwen_72b \
     --num_train_epochs 1 \
     --max_length 2048 \
-    --batch_size 4 \
+    --batch_size 8 \
     --use_flash_attn true \
     --gradient_accumulation_steps 1 \
     --gradient_checkpointing no \

examples/pytorch/llm/scripts/torchacc/yi_34b_chat/acc_lora_fsdp_sft.sh

Lines changed: 4 additions & 2 deletions
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 export USE_TORCHACC=1
 export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=1000000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select

+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/yi-34b-chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 NPROC_PER_NODE=4 \
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 swift sft \
@@ -18,7 +20,7 @@ swift sft \
     --output_dir output \
     --num_train_epochs 1 \
     --max_length 2048 \
-    --batch_size 12 \
+    --batch_size 10 \
     --use_flash_attn true \
     --gradient_accumulation_steps 1 \
     --gradient_checkpointing no \

swift/trainers/trainers.py

Lines changed: 8 additions & 4 deletions
@@ -201,20 +201,19 @@ def compute_loss(self, model, inputs, return_outputs=None):
             loss = self.label_smoother(outputs, labels)
         else:
             loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]
-        if use_torchacc():
-            ta_trim_graph()
-        if labels is None:
-            labels = inputs['labels']

         if self.sequence_parallel_size > 1:
             from swift.trainers.xtuner import reduce_xtuner_sequence_parallel_loss
             loss = reduce_xtuner_sequence_parallel_loss(loss, labels)

+        if labels is None:
+            labels = inputs['labels']
         preds = outputs.logits.argmax(dim=2)[..., :-1]
         labels = labels[..., 1:]
         masks = labels != -100
         acc_strategy = getattr(self.args, 'acc_strategy', 'token')
         acc: Optional[Tensor] = None
+
         if preds.shape != labels.shape:
             pass
         elif acc_strategy == 'sentence':
@@ -223,6 +222,11 @@ def compute_loss(self, model, inputs, return_outputs=None):
                 acc_list.append(torch.all(preds[i, m] == labels[i, m]).to(torch.int64).item())
             acc = torch.tensor(acc_list, device=preds.device).float().mean()
         else:
+            if use_torchacc():
+                ta_trim_graph()
+            preds = preds.to('cpu')
+            masks = masks.to('cpu')
+            labels = labels.to('cpu')
             acc = (torch.masked_select(preds, masks) == torch.masked_select(labels, masks)).float().mean()
         if model.training and acc is not None:
             if 'acc' not in self._custom_metrics:
