diff --git a/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_dp_sft.sh b/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_dp_sft.sh
index d6c143b270..ffd7676d85 100644
--- a/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_dp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_dp_sft.sh
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 # torchacc dp
 export USE_TORCHACC=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/Baichuan2-13B-Chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
 MASTER_PORT=27829 \
diff --git a/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_fsdp_sft.sh
index 5277721c61..649031c55e 100644
--- a/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_fsdp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_fsdp_sft.sh
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 # torchacc fsdp
 export USE_TORCHACC=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/Baichuan2-13B-Chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
 swift sft \
diff --git a/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_dp_sft.sh b/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_dp_sft.sh
index d059dc7c11..6d28e1371a 100644
--- a/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_dp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_dp_sft.sh
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 # torchacc dp
 export USE_TORCHACC=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/chatglm3-6b
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
diff --git a/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_fsdp_sft.sh
index 5993ab8531..53221f00ca 100644
--- a/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_fsdp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_fsdp_sft.sh
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 # torchacc fsdp
 export USE_TORCHACC=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/chatglm3-6b
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
diff --git a/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_dp_sft.sh b/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_dp_sft.sh
index c5df3e81fb..e84d806540 100644
--- a/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_dp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_dp_sft.sh
@@ -4,11 +4,13 @@
 export USE_TORCHACC=1
 export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/Llama-2-13b-chat-ms
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
 swift sft \
@@ -20,7 +22,7 @@ swift sft \
     --output_dir output \
     --num_train_epochs 1 \
     --max_length 2048 \
-    --batch_size 16 \
+    --batch_size 14 \
     --use_flash_attn true \
     --gradient_accumulation_steps 1 \
     --gradient_checkpointing no \
diff --git a/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_fsdp_sft.sh
index 5d84ea90f6..b00a710732 100644
--- a/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_fsdp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_fsdp_sft.sh
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 export USE_TORCHACC=1
 export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/Llama-2-13b-chat-ms
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
 MASTER_PORT=27829 \
@@ -20,7 +22,7 @@ swift sft \
     --output_dir output \
     --num_train_epochs 1 \
     --max_length 2048 \
-    --batch_size 24 \
+    --batch_size 20 \
     --use_flash_attn true \
     --gradient_accumulation_steps 1 \
     --gradient_checkpointing no \
diff --git a/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_dp_sft.sh b/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_dp_sft.sh
index f86b55436b..f9fd263f23 100644
--- a/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_dp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_dp_sft.sh
@@ -4,11 +4,12 @@
 export USE_TORCHACC=1
 export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
-export XLA_COORDINATOR_PORT=12457
+
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/Meta-Llama-3-8B-Instruct
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
 
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
diff --git a/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_fsdp_sft.sh
index 36f69b1792..e161dc873a 100644
--- a/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_fsdp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_fsdp_sft.sh
@@ -3,11 +3,12 @@
 # Note: TorchAcc is currently only available internally.
 export USE_TORCHACC=1
 export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
-# export XLA_COORDINATOR_PORT=12457
+
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/Meta-Llama-3-8B-Instruct
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
 
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
diff --git a/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_dp_sft.sh b/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_dp_sft.sh
index 4a6f0894a7..f76554e472 100644
--- a/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_dp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_dp_sft.sh
@@ -2,12 +2,14 @@
 # 80GB GPU memory
 # Note: TorchAcc is currently only available internally.
 export USE_TORCHACC=1
-# export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
+export TORCHACC_TRIM_GRAPH=1
 export XLA_IR_SHAPE_CACHE_SIZE=1000000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/qwen1half-14b-chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=2,3 \
 MASTER_PORT=23797 \
diff --git a/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_fsdp_sft.sh
index 6e57d9c5e3..2dfad5832d 100644
--- a/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_fsdp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_fsdp_sft.sh
@@ -5,10 +5,13 @@ DEBUG_PREFIX=qwen15_14b
 DEBUG_PATH=torchacc_debug/qwen15/
 export USE_TORCHACC=1
 # export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=1000000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
+
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/qwen1half-14b-chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 MASTER_PORT=23783 \
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
diff --git a/examples/pytorch/llm/scripts/torchacc/qwen1half_32b_chat/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/qwen1half_32b_chat/acc_lora_fsdp_sft.sh
index 2a5d6644f4..089558262b 100644
--- a/examples/pytorch/llm/scripts/torchacc/qwen1half_32b_chat/acc_lora_fsdp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/qwen1half_32b_chat/acc_lora_fsdp_sft.sh
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 export USE_TORCHACC=1
 # export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=1000000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/qwen1half-32b-chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 
 NPROC_PER_NODE=4 \
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 swift sft \
diff --git a/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_full_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_full_fsdp_sft.sh
index c819d3d446..ed1901dab9 100644
--- a/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_full_fsdp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_full_fsdp_sft.sh
@@ -3,10 +3,11 @@
 # Note: TorchAcc is currently only available internally.
 export USE_TORCHACC=1
-export XLA_FLAGS='--xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.97
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/qwen-72b-chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
 
 # Note: You need to set the correct MASTER_ADDR, MASTER_PORT and NODE_RANK for each node.
 MASTER_ADDR=127.0.0.1 \
diff --git a/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_lora_fsdp_sft.sh
index df3cdf35f0..2ced208aa5 100644
--- a/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_lora_fsdp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_lora_fsdp_sft.sh
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 export USE_TORCHACC=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=100000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/qwen-72b-chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 
 NPROC_PER_NODE=4 \
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 swift sft \
@@ -18,7 +20,7 @@ swift sft \
     --output_dir output_qwen_72b \
     --num_train_epochs 1 \
     --max_length 2048 \
-    --batch_size 4 \
+    --batch_size 8 \
     --use_flash_attn true \
     --gradient_accumulation_steps 1 \
     --gradient_checkpointing no \
diff --git a/examples/pytorch/llm/scripts/torchacc/yi_34b_chat/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/yi_34b_chat/acc_lora_fsdp_sft.sh
index 17bb17dbd1..cd29a3b9e7 100644
--- a/examples/pytorch/llm/scripts/torchacc/yi_34b_chat/acc_lora_fsdp_sft.sh
+++ b/examples/pytorch/llm/scripts/torchacc/yi_34b_chat/acc_lora_fsdp_sft.sh
@@ -3,11 +3,13 @@
 # Note: TorchAcc is currently only available internally.
 export USE_TORCHACC=1
 export TORCHACC_TRIM_GRAPH=1
-export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
 export XLA_IR_SHAPE_CACHE_SIZE=1000000000
 export XLA_ALLOCATOR_FRACTION=0.95
 export XLA_EXPERIMENTAL=nonzero:masked_select
+export XLA_PERSISTENT_CACHE_PATH=./output/compiled_cache/yi-34b-chat
+mkdir -p $XLA_PERSISTENT_CACHE_PATH
+
 
 NPROC_PER_NODE=4 \
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 swift sft \
@@ -18,7 +20,7 @@ swift sft \
     --output_dir output \
     --num_train_epochs 1 \
     --max_length 2048 \
-    --batch_size 12 \
+    --batch_size 10 \
     --use_flash_attn true \
     --gradient_accumulation_steps 1 \
     --gradient_checkpointing no \
diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py
index 1db2527d8a..66dc22dd42 100644
--- a/swift/trainers/trainers.py
+++ b/swift/trainers/trainers.py
@@ -201,20 +201,19 @@ def compute_loss(self, model, inputs, return_outputs=None):
             loss = self.label_smoother(outputs, labels)
         else:
             loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]
 
-        if use_torchacc():
-            ta_trim_graph()
-            if labels is None:
-                labels = inputs['labels']
         if self.sequence_parallel_size > 1:
             from swift.trainers.xtuner import reduce_xtuner_sequence_parallel_loss
             loss = reduce_xtuner_sequence_parallel_loss(loss, labels)
+        if labels is None:
+            labels = inputs['labels']
 
         preds = outputs.logits.argmax(dim=2)[..., :-1]
         labels = labels[..., 1:]
         masks = labels != -100
         acc_strategy = getattr(self.args, 'acc_strategy', 'token')
         acc: Optional[Tensor] = None
+
         if preds.shape != labels.shape:
             pass
         elif acc_strategy == 'sentence':
@@ -223,6 +222,11 @@ def compute_loss(self, model, inputs, return_outputs=None):
                 acc_list.append(torch.all(preds[i, m] == labels[i, m]).to(torch.int64).item())
             acc = torch.tensor(acc_list, device=preds.device).float().mean()
         else:
+            if use_torchacc():
+                ta_trim_graph()
+                preds = preds.to('cpu')
+                masks = masks.to('cpu')
+                labels = labels.to('cpu')
             acc = (torch.masked_select(preds, masks) == torch.masked_select(labels, masks)).float().mean()
         if model.training and acc is not None:
             if 'acc' not in self._custom_metrics:
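
For reference (not part of the patch): a minimal standalone sketch of the token-accuracy path that the trainers.py hunk moves into the `acc_strategy == 'token'` branch. `use_torchacc` and `ta_trim_graph` are stubbed here so the snippet runs on plain PyTorch; the real helpers come from swift. The point it illustrates is that under TorchAcc/XLA the lazy graph is trimmed and the tensors are moved to CPU before the data-dependent masked_select.

import os

import torch


def use_torchacc() -> bool:
    # Stub standing in for swift's helper; assumes the same env-var switch as the scripts above.
    return os.environ.get('USE_TORCHACC', '0') == '1'


def ta_trim_graph() -> None:
    # No-op stand-in for TorchAcc's graph-trimming hook.
    pass


def token_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token accuracy over positions whose label is not -100."""
    preds = logits.argmax(dim=2)[..., :-1]
    shifted_labels = labels[..., 1:]
    masks = shifted_labels != -100
    if use_torchacc():
        # Cut the lazy trace and compute the dynamic-shaped masked_select on CPU,
        # so it does not force extra XLA recompilation.
        ta_trim_graph()
        preds, masks, shifted_labels = preds.cpu(), masks.cpu(), shifted_labels.cpu()
    return (torch.masked_select(preds, masks) == torch.masked_select(shifted_labels, masks)).float().mean()


if __name__ == '__main__':
    logits = torch.randn(2, 8, 32)        # (batch, seq_len, vocab)
    labels = torch.randint(0, 32, (2, 8))
    labels[:, :3] = -100                  # ignored prompt positions
    print(token_accuracy(logits, labels))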